You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ru...@apache.org on 2022/08/27 03:20:52 UTC
[tvm] branch main updated: [MetaSchedule][UX] Make `Database` with-able (#12520)
This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 370abe69d2 [MetaSchedule][UX] Make `Database` with-able (#12520)
370abe69d2 is described below
commit 370abe69d24519a5453cead846d328a1c378957f
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Fri Aug 26 20:20:42 2022 -0700
[MetaSchedule][UX] Make `Database` with-able (#12520)
`ApplyHistoryBest` currently acts as an adaptor around the database, used to query tuning records from it.
In fact, the logic can be simplified so that users only have to deal with `Database` directly instead of this
extra object.
- [x] Add `EnterWithScope`/`ExitWithScope`/`Current` to Database
- [x] Migrate `te_filter_func` => "relay.backend.tir_converter" in Relay's pass context
- [x] Migrate `f_take_tuning_record` => "Database.query_tuning_record"
- [x] Migrate `TECompiler` to use `Database`
- [x] Remove apply-history-best
Next PR:
- Migrate `f_direct_dispatch` (potentially unify with `apply_fixed_schedule`?)
---
include/tvm/meta_schedule/apply_history_best.h | 115 --------------
include/tvm/meta_schedule/database.h | 28 ++++
include/tvm/meta_schedule/extracted_task.h | 20 ---
python/tvm/auto_scheduler/testing/tune_relay.py | 93 ++++++------
python/tvm/meta_schedule/__init__.py | 1 -
python/tvm/meta_schedule/apply_history_best.py | 130 ----------------
python/tvm/meta_schedule/database/database.py | 104 ++++++++++++-
python/tvm/meta_schedule/default_config.py | 4 -
python/tvm/meta_schedule/relay_integration.py | 29 ++--
python/tvm/meta_schedule/testing/tune_relay.py | 30 ++--
python/tvm/meta_schedule/testing/utils.py | 26 ++--
python/tvm/meta_schedule/tune.py | 12 +-
src/meta_schedule/apply_history_best.cc | 165 ---------------------
src/meta_schedule/database/database.cc | 64 ++++++++
src/meta_schedule/extracted_task.cc | 70 ---------
src/meta_schedule/utils.h | 1 -
src/relay/backend/task_extraction.cc | 25 ++--
src/relay/backend/te_compiler.cc | 1 +
src/relay/backend/te_compiler_cache.cc | 70 ++++-----
src/relay/backend/utils.cc | 73 +++++++++
src/relay/backend/utils.h | 31 ++++
.../test_meta_schedule_auto_tensorize.py | 25 ++--
tests/python/unittest/test_link_params.py | 19 +--
.../unittest/test_meta_schedule_integration.py | 62 +-------
.../unittest/test_meta_schedule_multi_anchor.py | 2 +-
.../test_meta_schedule_relay_tir_compute.py | 18 +--
.../unittest/test_meta_schedule_tune_relay.py | 57 ++++---
27 files changed, 511 insertions(+), 764 deletions(-)
diff --git a/include/tvm/meta_schedule/apply_history_best.h b/include/tvm/meta_schedule/apply_history_best.h
deleted file mode 100644
index 44a34b3ee4..0000000000
--- a/include/tvm/meta_schedule/apply_history_best.h
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#ifndef TVM_META_SCHEDULE_APPLY_HISTORY_BEST_H_
-#define TVM_META_SCHEDULE_APPLY_HISTORY_BEST_H_
-
-#include <tvm/ir/module.h>
-#include <tvm/meta_schedule/database.h>
-#include <tvm/node/reflection.h>
-#include <tvm/runtime/container/array.h>
-#include <tvm/runtime/container/optional.h>
-#include <tvm/runtime/container/string.h>
-#include <tvm/runtime/object.h>
-#include <tvm/runtime/packed_func.h>
-#include <tvm/target/target.h>
-#include <tvm/te/tensor.h>
-
-namespace tvm {
-namespace meta_schedule {
-
-/*!
- * \brief An integration context that allows application of historically best records from a
- * database
- */
-class ApplyHistoryBestNode : public runtime::Object {
- public:
- /*! \brief A callback function that filters TE compute */
- using FTEFilterFunc = runtime::TypedPackedFunc<Optional<tir::PrimFunc>(
- const Array<te::Tensor, void>&, const Array<runtime::NDArray>&)>;
- /*! \brief A callback function that takes a tuning record and does something with it */
- using FTakeTuningRecord = runtime::TypedPackedFunc<void(const TuningRecord&)>;
- using FDirectDispatch = runtime::TypedPackedFunc<Optional<IRModule>(const IRModule&)>;
-
- /*! \brief The database to be queried from */
- Database database{nullptr};
- /*! \brief The filtering function for TE computation */
- FTEFilterFunc te_filter_func{nullptr};
- /*! \brief The logging function to be used */
- PackedFunc logging_func;
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("database", &database);
- // `te_filter_func` is not visited
- // `logging_func` is not visited
- }
- /*!
- * \brief Query the best entry from the database
- * \param task_name The name of the task to be queried
- * \param mod The module to be queried
- * \param target The target to be queried
- * \param dispatched The IRs after dispatch
- * \param f_take_tuning_record A callback function that takes a tuning record and does something
- * with it.
- * \param f_direct_dispatch A function that directly dispatches an IRModule to the given workload
- * as result if available, skipping the database query.
- */
- Optional<IRModule> Query(runtime::String task_name, IRModule mod, Target target,
- Optional<Array<IRModule>> dispatched,
- FTakeTuningRecord f_take_tuning_record,
- FDirectDispatch f_direct_dispatch = nullptr);
-
- static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest";
- TVM_DECLARE_FINAL_OBJECT_INFO(ApplyHistoryBestNode, runtime::Object);
-};
-
-/*!
- * \brief Managed reference to ApplyHistoryBestNode
- * \sa ApplyHistoryBestNode
- */
-class ApplyHistoryBest : public runtime::ObjectRef {
- public:
- /*!
- * \brief Constructor
- * \param database The database to be queried from
- * \param te_filter_func The filtering function for TE computation
- * \param logging_func The logging function to use
- */
- explicit ApplyHistoryBest(Database database, ApplyHistoryBestNode::FTEFilterFunc te_filter_func,
- PackedFunc logging_func);
- /*!
- * \brief The current ApplyHistoryBest in the context
- * \return The ApplyHistoryBest in the current scope.
- */
- static Optional<ApplyHistoryBest> Current();
-
- TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ApplyHistoryBest, runtime::ObjectRef,
- ApplyHistoryBestNode);
-
- protected:
- friend class ApplyHistoryBestInternal;
- /*! \brief Entering the scope of the context manager */
- void EnterWithScope();
- /*! \brief Exiting the scope of the context manager */
- void ExitWithScope();
-};
-
-} // namespace meta_schedule
-} // namespace tvm
-
-#endif // TVM_META_SCHEDULE_APPLY_HISTORY_BEST_H_
diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h
index 1c260d9d74..0e7f45d393 100644
--- a/include/tvm/meta_schedule/database.h
+++ b/include/tvm/meta_schedule/database.h
@@ -203,6 +203,27 @@ class DatabaseNode : public runtime::Object {
* \return The size of the database.
*/
virtual int64_t Size() = 0;
+ /*!
+ * \brief Query the best record of the given workload from the database.
+ * \param mod The IRModule to be searched for.
+ * \param target The target to be searched for.
+ * \return The best record of the given workload; NullOpt if not found.
+ */
+ virtual Optional<TuningRecord> QueryTuningRecord(IRModule mod, Target target);
+ /*!
+ * \brief Query the best schedule of the given workload from the database.
+ * \param mod The IRModule to be searched for.
+ * \param target The target to be searched for.
+ * \return The schedule in the best schedule of the given workload; NullOpt if not found.
+ */
+ virtual Optional<tir::Schedule> QuerySchedule(IRModule mod, Target target);
+ /*!
+ * \brief Query the best IRModule of the given workload from the database.
+ * \param mod The IRModule to be searched for.
+ * \param target The target to be searched for.
+ * \return The IRModule in the best IRModule of the given workload; NullOpt if not found.
+ */
+ virtual Optional<IRModule> QueryIRModule(IRModule mod, Target target);
static constexpr const char* _type_key = "meta_schedule.Database";
TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object);
@@ -339,6 +360,13 @@ class Database : public runtime::ObjectRef {
PyDatabaseNode::FGetTopK f_get_top_k,
PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
PyDatabaseNode::FSize f_size);
+ /*! \return The current Database in the scope. */
+ static Optional<Database> Current();
+ /*! \brief Entering the scope of the context manager */
+ void EnterWithScope();
+ /*! \brief Exiting the scope of the context manager */
+ void ExitWithScope();
+
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode);
};
diff --git a/include/tvm/meta_schedule/extracted_task.h b/include/tvm/meta_schedule/extracted_task.h
index bce40e6b95..239bf0dc57 100644
--- a/include/tvm/meta_schedule/extracted_task.h
+++ b/include/tvm/meta_schedule/extracted_task.h
@@ -76,26 +76,6 @@ class ExtractedTask : public runtime::ObjectRef {
ExtractedTaskNode);
};
-/*!
- * \brief The default TE task filter
- * \param args The input/output arguments of the TE compute graph
- * \param constants Raw data for constant tensors in args. If the size of this array is N, the last
- * N tensors in args will be treated as constant tensors.
- * \return NullOpt if the task is filtered out, otherwise the task in PrimFunc
- */
-Optional<tvm::tir::PrimFunc> DefaultTaskFilter(const Array<tvm::te::Tensor, void>& args,
- const Array<runtime::NDArray>& constants);
-
-/*!
- * \brief The default TE task filter, with `te.extern` allowed
- * \param args The input/output arguments of the TE compute graph
- * \param constants Raw data for constant tensors in args. If the size of this array is N, the last
- * N tensors in args will be treated as constant tensors.
- * \return NullOpt if the task is filtered out, otherwise the task in PrimFunc
- */
-Optional<tir::PrimFunc> DefaultTaskFilterAllowExtern(const Array<tvm::te::Tensor, void>& args,
- const Array<runtime::NDArray>& constants);
-
} // namespace meta_schedule
} // namespace tvm
diff --git a/python/tvm/auto_scheduler/testing/tune_relay.py b/python/tvm/auto_scheduler/testing/tune_relay.py
index fe747af797..2d84389f9d 100644
--- a/python/tvm/auto_scheduler/testing/tune_relay.py
+++ b/python/tvm/auto_scheduler/testing/tune_relay.py
@@ -15,10 +15,10 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
-from distutils.util import strtobool
import argparse
import json
import os
+from distutils.util import strtobool
import tvm
from tvm import auto_scheduler
@@ -26,7 +26,7 @@ from tvm import meta_schedule as ms
from tvm import relay
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.relay_workload import get_network
-from tvm.meta_schedule.testing.tune_utils import generate_input_data, create_timer
+from tvm.meta_schedule.testing.tune_utils import create_timer, generate_input_data
from tvm.meta_schedule.utils import cpu_count
from tvm.support import describe
@@ -170,53 +170,62 @@ def main():
ARGS.input_shape,
cache_dir=ARGS.cache_dir,
)
- input_info = {input_name: input_shape}
+ input_info = [
+ {
+ "name": input_name,
+ "shape": input_shape,
+ "dtype": input_dtype,
+ },
+ ]
input_data = {
- item["name"]: generate_input_data(item["shape"], item["dtype"]) for item in ARGS.input_shape
+ item["name"]: generate_input_data(item["shape"], item["dtype"]) for item in input_info
}
- for input_name, input_shape in input_info.items():
- print(f" input_name : {input_name}")
- print(f" input_shape: {input_shape}")
- print(f" input_dtype: {input_dtype}")
+ for item in input_info:
+ print(f" input_name : {item['name']}")
+ print(f" input_shape: {item['shape']}")
+ print(f" input_dtype: {item['dtype']}")
with ms.Profiler() as profiler:
- tasks, task_weights = auto_scheduler.extract_tasks(
- mod["main"],
- params,
- target=ARGS.target,
- hardware_params=hardware_params,
- )
- for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
- print(
- f"==== Task {idx}: {task.desc} "
- f"(weight {task_weight} key: {task.workload_key}) ====="
- )
- print(task.compute_dag)
-
- if ARGS.num_trials > 0:
- tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
- tuner.tune(
- auto_scheduler.TuningOptions(
- num_measure_trials=ARGS.num_trials,
- runner=runner,
- measure_callbacks=[
- auto_scheduler.RecordToFile(log_file),
- ],
- ),
- adaptive_training=ARGS.adaptive_training,
+ with ms.Profiler.timeit("TaskExtraction"):
+ tasks, task_weights = auto_scheduler.extract_tasks(
+ mod["main"],
+ params,
+ target=ARGS.target,
+ hardware_params=hardware_params,
)
+ for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
+ print(
+ f"==== Task {idx}: {task.desc} "
+ f"(weight {task_weight} key: {task.workload_key}) ====="
+ )
+ print(task.compute_dag)
+
+ with ms.Profiler.timeit("Tuning"):
+ if ARGS.num_trials > 0:
+ tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
+ tuner.tune(
+ auto_scheduler.TuningOptions(
+ num_measure_trials=ARGS.num_trials,
+ runner=runner,
+ measure_callbacks=[
+ auto_scheduler.RecordToFile(log_file),
+ ],
+ ),
+ adaptive_training=ARGS.adaptive_training,
+ )
relay_build = {"graph": relay.build, "vm": relay.vm.compile}[ARGS.backend]
- with auto_scheduler.ApplyHistoryBest(log_file):
- with tvm.transform.PassContext(
- opt_level=3,
- config={"relay.backend.use_auto_scheduler": True},
- ):
- lib = relay_build(
- mod,
- target=ARGS.target,
- params=params,
- )
+ with ms.Profiler.timeit("PostTuningCompilation"):
+ with auto_scheduler.ApplyHistoryBest(log_file):
+ with tvm.transform.PassContext(
+ opt_level=3,
+ config={"relay.backend.use_auto_scheduler": True},
+ ):
+ lib = relay_build(
+ mod,
+ target=ARGS.target,
+ params=params,
+ )
print("Tuning Time:")
print(profiler.table())
diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py
index f60d0a5490..cf348d49f4 100644
--- a/python/tvm/meta_schedule/__init__.py
+++ b/python/tvm/meta_schedule/__init__.py
@@ -30,7 +30,6 @@ from . import (
search_strategy,
space_generator,
)
-from .apply_history_best import ApplyHistoryBest
from .extracted_task import ExtractedTask
from .profiler import Profiler
from .relay_integration import (
diff --git a/python/tvm/meta_schedule/apply_history_best.py b/python/tvm/meta_schedule/apply_history_best.py
deleted file mode 100644
index a7b9b20bf2..0000000000
--- a/python/tvm/meta_schedule/apply_history_best.py
+++ /dev/null
@@ -1,130 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""A context manager that injects the best tuning record in the database into compilation"""
-import logging
-from typing import Callable, List, Optional, Union
-
-from tvm._ffi import get_global_func, register_object
-from tvm.ir import IRModule
-from tvm.runtime import Object
-from tvm.target import Target
-from tvm.te import Tensor
-from tvm.tir import PrimFunc
-
-from . import _ffi_api
-from .database import Database, TuningRecord
-from .utils import make_logging_func
-
-logger = logging.getLogger(__name__) # pylint: disable=invalid-name
-
-
-@register_object("meta_schedule.ApplyHistoryBest")
-class ApplyHistoryBest(Object):
- """An integration context that allows application of historically best records from a database
-
- Parameters
- ----------
- database : Database
- The database to be queried from
- te_filter_func : Union[str, None, Callable[[List[Tensor], List[NDArray]], PrimFunc]] = None
- The filtering function for TE computation
- If it's a string, it's the name of the filtering function. Built in functions are
- - "meta_schedule.DefaultTaskFilter"
- - "meta_schedule.DefaultTaskFilterAllowExtern"
- If it's None, it's the default filtering function
- If it's a callable, it's the filtering function
- """
-
- database: Database
-
- def __init__(
- self,
- database: Database,
- te_filter_func: Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None,
- ) -> None:
- if isinstance(te_filter_func, str):
- te_filter_func = get_global_func(te_filter_func)
- self.__init_handle_by_constructor__(
- _ffi_api.ApplyHistoryBest, # type: ignore # pylint: disable=no-member
- database,
- te_filter_func,
- make_logging_func(logger),
- )
-
- def query(
- self,
- task_name: str,
- mod: IRModule,
- target: Target,
- dispatched: Optional[List[IRModule]],
- f_take_tuning_record: Optional[Callable[[TuningRecord], None]] = None,
- f_direct_dispatch: Optional[Callable[[IRModule], Optional[IRModule]]] = None,
- ) -> Union[IRModule, None]:
- """The entry point of the integration
-
- Parameters
- ----------
- task_name : str
- The name of the task extracted
- mod : IRModule
- The high-level IR
- target: Target
- Target Info
- dispatched : Optional[List[IRModule]]
- A list of low-level IRs that the high-level IR could potentially dispatch to
- f_take_tuning_record : Optional[Callable[[TuningRecord], None]] = None
- A callback function that takes a tuning record and does something with it
- f_direct_dispatch : Optional[Callable[[IRModule], Optional[IRModule]]] = None
- A function that directly dispatches an IRModule to the given workload as result if
- available, skipping the database query.
-
- Returns
- -------
- result : IRModule or None
- Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for
- more general future use. None is returned if there is no feedback hint.
- """
- return _ffi_api.ApplyHistoryBestQuery( # type: ignore # pylint: disable=no-member
- self,
- task_name,
- mod,
- target,
- dispatched,
- f_take_tuning_record,
- f_direct_dispatch,
- )
-
- @staticmethod
- def current() -> Optional["ApplyHistoryBest"]:
- """The context manager in the current scope
-
- Returns
- -------
- ctx : Optional[ApplyHistoryBest]
- The ApplyHistoryBest context manager in the current scope.
- None if it's currently not under any ApplyHistoryBest context.
- """
- return _ffi_api.ApplyHistoryBestCurrent() # type: ignore # pylint: disable=no-member
-
- def __enter__(self) -> "ApplyHistoryBest":
- """Entering the scope of the context manager"""
- _ffi_api.ApplyHistoryBestEnterScope(self) # type: ignore # pylint: disable=no-member
- return self
-
- def __exit__(self, ptype, value, trace) -> None:
- """Exiting the scope of the context manager"""
- _ffi_api.ApplyHistoryBestExitScope(self) # type: ignore # pylint: disable=no-member
diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py
index 0c11f77591..68283b4554 100644
--- a/python/tvm/meta_schedule/database/database.py
+++ b/python/tvm/meta_schedule/database/database.py
@@ -15,13 +15,14 @@
# specific language governing permissions and limitations
# under the License.
"""TuningRecord database"""
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Union
from tvm._ffi import register_object
from tvm.ir.module import IRModule
from tvm.runtime import Object
from tvm.target import Target
-from tvm.tir.schedule import Trace
+from tvm.tir.schedule import Schedule, Trace
+from typing_extensions import Literal # pylint: disable=wrong-import-order
from .. import _ffi_api
from ..arg_info import ArgInfo
@@ -234,6 +235,105 @@ class Database(Object):
"""
return _ffi_api.DatabaseSize(self) # type: ignore # pylint: disable=no-member
+ def query_tuning_record(self, mod: IRModule, target: Target) -> Optional[TuningRecord]:
+ """Query the best record of the given workload from the database.
+
+ Parameters
+ ----------
+ mod : IRModule
+ The IRModule to be searched for.
+ target : Target
+ The target to be searched for.
+
+ Returns
+ -------
+ tuning_record : Optional[TuningRecord]
+ The best record of the given workload; None if not found.
+ """
+ return _ffi_api.DatabaseQueryTuningRecord(self, mod, target) # type: ignore # pylint: disable=no-member
+
+ def query_schedule(self, mod: IRModule, target: Target) -> Optional[Schedule]:
+ """Query the best schedule of the given workload from the database.
+
+ Parameters
+ ----------
+ mod : IRModule
+ The IRModule to be searched for.
+ target : Target
+ The target to be searched for.
+
+ Returns
+ -------
+ schedule : Optional[Schedule]
+ The best schedule of the given workload; None if not found.
+ """
+ return _ffi_api.DatabaseQuerySchedule(self, mod, target) # type: ignore # pylint: disable=no-member
+
+ def query_ir_module(self, mod: IRModule, target: Target) -> Optional[IRModule]:
+ """Query the best IRModule of the given workload from the database.
+
+ Parameters
+ ----------
+ mod : IRModule
+ The IRModule to be searched for.
+ target : Target
+ The target to be searched for.
+
+ Returns
+ -------
+ ir_module : Optional[IRModule]
+ The best IRModule of the given workload; None if not found.
+ """
+ return _ffi_api.DatabaseQueryIRModule(self, mod, target) # type: ignore # pylint: disable=no-member
+
+ def query(
+ self,
+ mod: IRModule,
+ target: Target,
+ kind: Union[
+ Literal["schedule"],
+ Literal["record"],
+ Literal["ir_module"],
+ ] = "schedule",
+ ) -> Union[Schedule, IRModule, TuningRecord]:
+ """Query the database to retrieve the best optimization outcome of the given workload.
+
+ Parameters
+ ----------
+ mod : IRModule
+ The IRModule to be searched for.
+ target : Target
+ The target to be searched for.
+ kind : str = "schedule" | "record" | "ir_module"
+ The kind of the optimization outcome to be returned.
+
+ Returns
+ -------
+ result : Union[Schedule, IRModule, TuningRecord]
+ The best optimization outcome of the given workload.
+ """
+ if kind == "schedule":
+ return self.query_schedule(mod, target)
+ if kind == "record":
+ return self.query_tuning_record(mod, target)
+ if kind == "ir_module":
+ return self.query_ir_module(mod, target)
+ raise ValueError(f'Unknown kind: {kind}. Candidates are: "schedule", "record", "ir_module"')
+
+ def __enter__(self) -> "Database":
+ """Entering the scope of the context manager"""
+ _ffi_api.DatabaseEnterWithScope(self) # type: ignore # pylint: disable=no-member
+ return self
+
+ def __exit__(self, ptype, value, trace) -> None:
+ """Exiting the scope of the context manager"""
+ _ffi_api.DatabaseExitWithScope(self) # type: ignore # pylint: disable=no-member
+
+ @staticmethod
+ def current() -> Optional["Database"]:
+ """Get the current database under scope."""
+ return _ffi_api.DatabaseCurrent() # type: ignore # pylint: disable=no-member
+
@register_object("meta_schedule.PyDatabase")
class _PyDatabase(Database):
diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py
index 97cbfc58a6..652f09261b 100644
--- a/python/tvm/meta_schedule/default_config.py
+++ b/python/tvm/meta_schedule/default_config.py
@@ -20,7 +20,6 @@ import logging
from os import path as osp
from typing import Callable, Dict, List, Optional, Union
-from tvm._ffi.registry import register_func
from tvm.ir import IRModule
from tvm.target import Target
from tvm.tir import PrimFunc
@@ -44,7 +43,6 @@ FnPostproc = Callable[[], List[Postproc]]
FnMutatorProb = Callable[[], Dict[Mutator, float]]
-@register_func("tvm.meta_schedule.tune.parse_mod") # for use in ApplyHistoryBest
def mod(mod: Union[PrimFunc, IRModule]) -> IRModule: # pylint: disable=redefined-outer-name
"""Normalize the input to an IRModule"""
if isinstance(mod, PrimFunc):
@@ -53,8 +51,6 @@ def mod(mod: Union[PrimFunc, IRModule]) -> IRModule: # pylint: disable=redefine
mod = IRModule({"main": mod})
if not isinstance(mod, IRModule):
raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}")
- # in order to make sure the mod can be found in ApplyHistoryBest
- # different func name can cause structural unequal
func_names = mod.get_global_vars()
(func_name,) = func_names
if len(func_names) == 1 and func_name != "main":
diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py
index d3b3ea7965..24009ab07f 100644
--- a/python/tvm/meta_schedule/relay_integration.py
+++ b/python/tvm/meta_schedule/relay_integration.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""MetaSchedule-Relay integration"""
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional
import numpy as np # type: ignore
from tvm import nd
@@ -23,8 +23,6 @@ from tvm._ffi import get_global_func
from tvm.ir import IRModule, transform
from tvm.runtime import NDArray
from tvm.target import Target
-from tvm.te import Tensor
-from tvm.tir import PrimFunc
from .extracted_task import ExtractedTask
from .utils import autotvm_silencer
@@ -38,7 +36,7 @@ def extract_task_from_relay(
opt_level: int = 3,
pass_config: Optional[Dict[str, Any]] = None,
disabled_pass: Optional[List[str]] = None,
- te_filter_func: Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None,
+ tir_converter: str = "default",
) -> List[ExtractedTask]:
"""Extract tuning tasks from a relay program.
@@ -56,13 +54,13 @@ def extract_task_from_relay(
The pass config of the compiler
disabled_pass : Optional[List[str]]
The list of disabled passes of the compiler
- te_filter_func : Callable[[List[tvm.te.Tensor], List[NDArray]], bool]
- The filter function to filter out the extracted tasks
- If it's a string, it's the name of the filtering function. Built in functions are
- - "meta_schedule.DefaultTaskFilter"
- - "meta_schedule.DefaultTaskFilterAllowExtern"
- If it's None, it's the default filtering function
- If it's a callable, it's the filtering function
+ tir_converter : str
+ The filter function to filter out the extracted tasks. Builtin filters:
+ - "default"
+ - "allow_extern"
+ The converter is a PackedFunc registered as f"relay.backend.tir_converter.{tir_converter}",
+ with the signature below:
+ (args: List[te.Tensor], constants: List[NDArray]) -> Optional[tir.PrimFunc]
Returns
-------
@@ -75,8 +73,6 @@ def extract_task_from_relay(
# pylint: enable=import-outside-toplevel
- if isinstance(te_filter_func, str):
- te_filter_func = get_global_func(te_filter_func)
extract_task_func = get_global_func(
"relay.backend.MetaScheduleExtractTask",
allow_missing=False,
@@ -89,7 +85,10 @@ def extract_task_from_relay(
if disabled_pass is None:
disabled_pass = []
if pass_config is None:
- pass_config = {"relay.backend.use_meta_schedule": True}
+ pass_config = {
+ "relay.backend.use_meta_schedule": True,
+ "relay.backend.tir_converter": tir_converter,
+ }
if params is None:
params = {}
relay_params = {}
@@ -110,7 +109,7 @@ def extract_task_from_relay(
else:
tophub_context = autotvm.utils.EmptyContext()
with tophub_context:
- return list(extract_task_func(mod, target, relay_params, te_filter_func))
+ return list(extract_task_func(mod, target, relay_params))
def is_meta_schedule_enabled() -> bool:
diff --git a/python/tvm/meta_schedule/testing/tune_relay.py b/python/tvm/meta_schedule/testing/tune_relay.py
index 8010e36fd6..596a5a7363 100644
--- a/python/tvm/meta_schedule/testing/tune_relay.py
+++ b/python/tvm/meta_schedule/testing/tune_relay.py
@@ -15,16 +15,18 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
-from distutils.util import strtobool
import argparse
import json
import logging
+from distutils.util import strtobool
+from typing import Dict
+import numpy as np # type: ignore
import tvm
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.relay_workload import get_network
-from tvm.meta_schedule.testing.tune_utils import generate_input_data, create_timer
+from tvm.meta_schedule.testing.tune_utils import create_timer, generate_input_data
from tvm.support import describe
@@ -137,14 +139,24 @@ def main():
ARGS.input_shape,
cache_dir=ARGS.cache_dir,
)
- input_info = {input_name: input_shape}
- input_data = {
- item["name"]: generate_input_data(item["shape"], item["dtype"]) for item in ARGS.input_shape
+ input_info = [
+ {
+ "name": input_name,
+ "shape": input_shape,
+ "dtype": input_dtype,
+ },
+ ]
+ input_data: Dict[str, np.ndarray] = {
+ item["name"]: generate_input_data( # type: ignore
+ item["shape"], # type: ignore
+ item["dtype"], # type: ignore
+ )
+ for item in input_info
}
- for input_name, input_shape in input_info.items():
- print(f" input_name : {input_name}")
- print(f" input_shape: {input_shape}")
- print(f" input_dtype: {input_dtype}")
+ for item in input_info:
+ print(f" input_name : {item['name']}")
+ print(f" input_shape: {item['shape']}")
+ print(f" input_dtype: {item['dtype']}")
runner = ms.runner.RPCRunner(
rpc_config=ARGS.rpc_config,
diff --git a/python/tvm/meta_schedule/testing/utils.py b/python/tvm/meta_schedule/testing/utils.py
index dda492008f..5919fb47c8 100644
--- a/python/tvm/meta_schedule/testing/utils.py
+++ b/python/tvm/meta_schedule/testing/utils.py
@@ -16,12 +16,13 @@
# under the License.
"""Testing utility functions in meta schedule"""
from typing import Callable, Dict, Optional, Union
+
+from tvm import meta_schedule as ms
from tvm.ir import IRModule, transform
from tvm.relay import Function as RelayFunc
from tvm.runtime import NDArray
from tvm.target import Target
from tvm.tir import Schedule
-from tvm import meta_schedule as ms
def apply_fixed_schedules(
@@ -29,10 +30,10 @@ def apply_fixed_schedules(
target: Union[str, Target],
params: Optional[Dict[str, NDArray]],
schedule_fn: Callable[[ms.ExtractedTask, Schedule], bool],
- te_filter_func=None,
+ tir_converter: str = "default",
):
"""Apply fixed schedules (manually written, without any tunable knobs) as specified by
- schedule_fn to extracted tasks, and return a database that can be passed to ApplyHistoryBest.
+ schedule_fn to extracted tasks, and return a database that can be passed to compilation.
Parameters
----------
@@ -45,13 +46,13 @@ def apply_fixed_schedules(
schedule_fn : Callable[[ExtractedTask, Schedule], bool]
A callable that is applied for each extracted task and the corresponding default schedule.
Returns True if the given schedule should be committed to the database, False otherwise.
- te_filter_func : Union[str, None, Callable[[List[Tensor], List[NDArray]], PrimFunc]] = None
- The filtering function for TE computation
- If it's a string, it's the name of the filtering function. Built in functions are
- - "meta_schedule.DefaultTaskFilter"
- - "meta_schedule.DefaultTaskFilterAllowExtern"
- If it's None, it's the default filtering function
- If it's a callable, it's the filtering function
+ tir_converter : str
+ The filter function to filter out the extracted tasks. Builtin filters:
+ - "default"
+ - "allow_extern"
+ The converter is a PackedFunc registered as f"relay.backend.tir_converter.{tir_converter}",
+ with the signature below:
+ (args: List[te.Tensor], constants: List[NDArray]) -> Optional[tir.PrimFunc]
Returns
-------
@@ -64,7 +65,10 @@ def apply_fixed_schedules(
config[k] = v
extracted_tasks = ms.extract_task_from_relay(
- relay_mod, target, params, te_filter_func=te_filter_func, pass_config=config
+ relay_mod,
+ target,
+ params,
+ tir_converter=tir_converter,
)
database = ms.database.MemoryDatabase()
for task in extracted_tasks:
diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py
index 447fb56637..20eccc30a1 100644
--- a/python/tvm/meta_schedule/tune.py
+++ b/python/tvm/meta_schedule/tune.py
@@ -24,14 +24,12 @@ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union
from tvm.ir import IRModule
from tvm.ir.transform import PassContext
-from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
from tvm.runtime import Module, NDArray, vm
from tvm.target import Target
from tvm.te import Tensor, create_prim_func
from tvm.tir import PrimFunc, Schedule
from . import default_config
-from .apply_history_best import ApplyHistoryBest
from .builder import Builder
from .cost_model import CostModel
from .database import Database, TuningRecord
@@ -43,7 +41,7 @@ from .profiler import Profiler
from .runner import Runner
from .schedule_rule import ScheduleRule
from .search_strategy import EvolutionarySearch, ReplayFunc, ReplayTrace
-from .space_generator import SpaceGenerator
+from .space_generator import PostOrderApply, SpaceGenerator
from .task_scheduler import GradientBased, RoundRobin
from .tune_context import TuneContext
from .utils import autotvm_silencer, batch_parameterize_config
@@ -461,7 +459,7 @@ def tune_tir(
mutator_probs=mutator_probs,
num_threads=num_threads,
)
- with Profiler.timeit("ApplyHistoryBest"):
+ with Profiler.timeit("PostTuningCompilation"):
bests: List[TuningRecord] = database.get_top_k(database.commit_workload(mod), top_k=1)
if not bests:
return None
@@ -591,6 +589,7 @@ def tune_relay(
"""
# pylint: disable=import-outside-toplevel
from tvm import relay
+
from .relay_integration import extract_task_from_relay
# pylint: disable=protected-access, enable=import-outside-toplevel
@@ -615,13 +614,14 @@ def tune_relay(
num_threads=num_threads,
)
relay_build = {"graph": relay.build, "vm": relay.vm.compile}[backend]
- with Profiler.timeit("ApplyHistoryBest"):
- with target, autotvm_silencer(), ApplyHistoryBest(database):
+ with Profiler.timeit("PostTuningCompilation"):
+ with target, autotvm_silencer(), database:
with PassContext(
opt_level=3,
config={
"relay.backend.use_meta_schedule": True,
"relay.backend.use_meta_schedule_dispatch": target.kind.name != "cuda",
+ "relay.backend.tir_converter": "default",
},
):
return relay_build(mod, target=target, params=params)
diff --git a/src/meta_schedule/apply_history_best.cc b/src/meta_schedule/apply_history_best.cc
deleted file mode 100644
index 62db293067..0000000000
--- a/src/meta_schedule/apply_history_best.cc
+++ /dev/null
@@ -1,165 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#include <tvm/te/tensor.h>
-
-#include "./utils.h"
-
-namespace tvm {
-namespace meta_schedule {
-
-/**************** Utility functions ****************/
-
-template <class FunctionType, class RetType, class Callback>
-Optional<RetType> GetOnlyOneFunctionCommon(const IRModule& mod, Callback on_found) {
- if (mod->functions.size() != 1) {
- return NullOpt;
- }
- for (const auto& kv : mod->functions) {
- const BaseFunc& func = kv.second;
- if (!func->IsInstance<typename FunctionType::ContainerType>()) {
- return NullOpt;
- } else {
- return on_found(kv);
- }
- }
- return NullOpt;
-}
-
-template <class FunctionType>
-Optional<GlobalVar> GetOnlyOneFunctionKey(const IRModule& mod) {
- return GetOnlyOneFunctionCommon<FunctionType, GlobalVar>(mod, [](auto kv) { return kv.first; });
-}
-
-template <class FunctionType>
-Optional<FunctionType> GetOnlyOneFunction(const IRModule& mod) {
- return GetOnlyOneFunctionCommon<FunctionType, FunctionType>(
- mod, [](auto kv) { return Downcast<FunctionType>(kv.second); });
-}
-
-template <class FunctionType>
-bool HasOnlyOneFunction(const IRModule& mod) {
- return GetOnlyOneFunction<FunctionType>(mod).defined();
-}
-
-/**************** Context Manager ****************/
-
-class ApplyHistoryBestInternal {
- public:
- static void EnterScope(ApplyHistoryBest ctx) { ctx.EnterWithScope(); }
- static void ExitScope(ApplyHistoryBest ctx) { ctx.ExitWithScope(); }
-};
-
-struct ApplyHistoryBestThreadLocalEntry {
- Optional<ApplyHistoryBest> ctx;
-};
-
-using ApplyHistoryBestThreadLocalStore = dmlc::ThreadLocalStore<ApplyHistoryBestThreadLocalEntry>;
-
-Optional<ApplyHistoryBest> ApplyHistoryBest::Current() {
- return ApplyHistoryBestThreadLocalStore::Get()->ctx;
-}
-
-void ApplyHistoryBest::EnterWithScope() {
- Optional<ApplyHistoryBest>& ctx = ApplyHistoryBestThreadLocalStore::Get()->ctx;
- CHECK(!ctx.defined()) << "ValueError: Nested ApplyHistoryBest context managers are not allowed";
- ctx = *this;
-}
-
-void ApplyHistoryBest::ExitWithScope() {
- Optional<ApplyHistoryBest>& ctx = ApplyHistoryBestThreadLocalStore::Get()->ctx;
- ICHECK(ctx.defined());
- ctx = NullOpt;
-}
-
-/**************** ApplyHistoryBest ****************/
-
-ApplyHistoryBest::ApplyHistoryBest(Database database,
- ApplyHistoryBestNode::FTEFilterFunc te_filter_func,
- PackedFunc logging_func) {
- ObjectPtr<ApplyHistoryBestNode> n = make_object<ApplyHistoryBestNode>();
- n->database = database;
- n->te_filter_func = te_filter_func;
- n->logging_func = logging_func;
- if (te_filter_func == nullptr) {
- n->te_filter_func = DefaultTaskFilter;
- }
- data_ = n;
-}
-
-Optional<IRModule> ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod,
- Target target, Optional<Array<IRModule>> dispatched,
- FTakeTuningRecord f_take_tuning_record,
- FDirectDispatch f_direct_dispatch) {
- ICHECK(dispatched.defined());
- ICHECK_EQ(dispatched.value().size(), 1);
- ICHECK(HasOnlyOneFunction<relay::Function>(mod)) << mod;
- IRModule prim_mod = dispatched.value()[0];
- ICHECK(HasOnlyOneFunction<tir::PrimFunc>(prim_mod)) << prim_mod;
-
- // Keep the original func name to be returned later.
- GlobalVar gv = GetOnlyOneFunctionKey<tir::PrimFunc>(prim_mod).value();
-
- // Unify func name to make sure it can be found in database
- const auto* parse_mod_func = runtime::Registry::Get("tvm.meta_schedule.tune.parse_mod");
- ICHECK(parse_mod_func) << "Parse mod function not defined!";
- prim_mod = (*parse_mod_func)(prim_mod);
-
- if (f_direct_dispatch != nullptr) {
- Optional<IRModule> mod = f_direct_dispatch(prim_mod);
- if (mod.defined()) {
- TVM_PY_LOG(INFO, logging_func) << "Direct dispatch applied for workload: " << task_name;
- return mod.value();
- }
- }
- if (database->HasWorkload(prim_mod)) {
- Array<TuningRecord> records = database->GetTopK(database->CommitWorkload(prim_mod), 1);
- if (records.size() == 1) {
- if (f_take_tuning_record != nullptr) {
- f_take_tuning_record(records[0]);
- }
- tir::Schedule sch =
- tir::Schedule::Traced(records[0]->workload->mod, /*seed=*/-1, /*debug_mask=*/0,
- /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
- records[0]->trace->ApplyToSchedule(sch, false);
- tir::PrimFunc func = GetOnlyOneFunction<tir::PrimFunc>(sch->mod()).value();
- // Make sure we return the updated PrimFunc paired with the original func name.
- return IRModule({{gv, func}});
- }
- }
- TVM_PY_LOG(WARNING, logging_func) << "Cannot find workload: " << task_name;
- return NullOpt;
-}
-
-TVM_REGISTER_NODE_TYPE(ApplyHistoryBestNode);
-TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBest")
- .set_body_typed([](Database database, ApplyHistoryBestNode::FTEFilterFunc te_filter_func,
- PackedFunc logging_func) -> ApplyHistoryBest {
- return ApplyHistoryBest(database, te_filter_func, logging_func);
- });
-TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestEnterScope")
- .set_body_typed(ApplyHistoryBestInternal::EnterScope);
-TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestExitScope")
- .set_body_typed(ApplyHistoryBestInternal::ExitScope);
-TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestCurrent")
- .set_body_typed(ApplyHistoryBest::Current);
-TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestQuery")
- .set_body_method<ApplyHistoryBest>(&ApplyHistoryBestNode::Query);
-
-} // namespace meta_schedule
-} // namespace tvm
diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc
index 4e180c4fab..fedd2aa352 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -154,6 +154,59 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w
return TuningRecord(trace, workload, run_secs, target, args_info);
}
+/******** Database ********/
+
+Optional<TuningRecord> DatabaseNode::QueryTuningRecord(IRModule mod, Target target) {
+ if (!this->HasWorkload(mod)) {
+ return NullOpt;
+ }
+ Array<TuningRecord> records = this->GetTopK(this->CommitWorkload(mod), 1);
+ if (records.empty()) {
+ return NullOpt;
+ }
+ ICHECK_EQ(records.size(), 1);
+ return records[0];
+}
+
+Optional<tir::Schedule> DatabaseNode::QuerySchedule(IRModule mod, Target target) {
+ if (Optional<TuningRecord> opt_record = this->QueryTuningRecord(mod, target)) {
+ TuningRecord record = opt_record.value();
+ tir::Schedule sch =
+ tir::Schedule::Traced(record->workload->mod, /*seed=*/-1, /*debug_mask=*/0,
+ /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);
+ record->trace->ApplyToSchedule(sch, false);
+ return sch;
+ } else {
+ return NullOpt;
+ }
+}
+
+Optional<IRModule> DatabaseNode::QueryIRModule(IRModule mod, Target target) {
+ if (Optional<tir::Schedule> opt_sch = this->QuerySchedule(mod, target)) {
+ return opt_sch.value()->mod();
+ } else {
+ return NullOpt;
+ }
+}
+
+std::vector<Database>* ThreadLocalDatabases() {
+ static thread_local std::vector<Database> tls;
+ return &tls;
+}
+
+void Database::EnterWithScope() { ThreadLocalDatabases()->push_back(*this); }
+
+void Database::ExitWithScope() { ThreadLocalDatabases()->pop_back(); }
+
+Optional<Database> Database::Current() {
+ std::vector<Database>* tls = ThreadLocalDatabases();
+ if (tls->empty()) {
+ return NullOpt;
+ } else {
+ return tls->back();
+ }
+}
+
/******** PyDatabase ********/
Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
@@ -194,6 +247,11 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsMeasureCandidate")
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON")
.set_body_method<TuningRecord>(&TuningRecordNode::AsJSON);
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON);
+TVM_REGISTER_GLOBAL("meta_schedule.DatabaseEnterWithScope")
+ .set_body_method(&Database::EnterWithScope);
+TVM_REGISTER_GLOBAL("meta_schedule.DatabaseExitWithScope")
+ .set_body_method(&Database::ExitWithScope);
+TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCurrent").set_body_typed(Database::Current);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseHasWorkload")
.set_body_method<Database>(&DatabaseNode::HasWorkload);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitWorkload")
@@ -205,6 +263,12 @@ TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetTopK")
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetAllTuningRecords")
.set_body_method<Database>(&DatabaseNode::GetAllTuningRecords);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseSize").set_body_method<Database>(&DatabaseNode::Size);
+TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQueryTuningRecord")
+ .set_body_method<Database>(&DatabaseNode::QueryTuningRecord);
+TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQuerySchedule")
+ .set_body_method<Database>(&DatabaseNode::QuerySchedule);
+TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQueryIRModule")
+ .set_body_method<Database>(&DatabaseNode::QueryIRModule);
TVM_REGISTER_GLOBAL("meta_schedule.DatabasePyDatabase").set_body_typed(Database::PyDatabase);
} // namespace meta_schedule
diff --git a/src/meta_schedule/extracted_task.cc b/src/meta_schedule/extracted_task.cc
index 3406f82eb1..ec04361f51 100644
--- a/src/meta_schedule/extracted_task.cc
+++ b/src/meta_schedule/extracted_task.cc
@@ -38,67 +38,6 @@ ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target,
data_ = n;
}
-Optional<tir::PrimFunc> DefaultTaskFilterImpl(const Array<te::Tensor>& args,
- const Array<runtime::NDArray>& constants,
- bool allow_extern_op) {
- using namespace ::tvm::te;
- std::vector<Tensor> stack;
- std::unordered_set<const TensorNode*> visited;
- for (const Tensor& v : args) {
- for (const PrimExpr& e : v->shape) {
- // Dynamic shape is not supported for now
- if (!e->IsInstance<IntImmNode>()) {
- return NullOpt;
- }
- }
- if (!visited.count(v.get())) {
- visited.insert(v.get());
- stack.push_back(v);
- }
- }
- while (!stack.empty()) {
- Tensor tensor = stack.back();
- stack.pop_back();
- if (tensor->op->IsInstance<PlaceholderOpNode>()) {
- // do nothing
- } else if (tensor->op->IsInstance<ComputeOpNode>() ||
- (allow_extern_op && tensor->op->IsInstance<ExternOpNode>())) {
- Array<Tensor> inputs = tensor->op->InputTensors();
- for (const Tensor& v : inputs) {
- if (!visited.count(v.get())) {
- visited.insert(v.get());
- stack.push_back(v);
- }
- }
- } else {
- return NullOpt;
- }
- }
- PrimFunc func = te::CreatePrimFuncWithConstants(args, constants);
- bool dynamic_loop_extent = false;
- PostOrderVisit(func->body, [&dynamic_loop_extent](const ObjectRef& obj) -> void {
- if (const auto* loop = obj.as<tir::ForNode>()) {
- if (!loop->extent->IsInstance<IntImmNode>()) {
- dynamic_loop_extent = true;
- }
- }
- });
- if (dynamic_loop_extent) {
- return NullOpt;
- }
- return func;
-}
-
-Optional<tir::PrimFunc> DefaultTaskFilter(const Array<te::Tensor>& args,
- const Array<runtime::NDArray>& constants) {
- return DefaultTaskFilterImpl(args, constants, false);
-}
-
-Optional<tir::PrimFunc> DefaultTaskFilterAllowExtern(const Array<te::Tensor>& args,
- const Array<runtime::NDArray>& constants) {
- return DefaultTaskFilterImpl(args, constants, true);
-}
-
TVM_REGISTER_NODE_TYPE(ExtractedTaskNode);
TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask")
.set_body_typed([](String task_name, IRModule mod, Target target, Array<IRModule> dispatched,
@@ -106,14 +45,5 @@ TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask")
return ExtractedTask(task_name, mod, target, dispatched, weight);
});
-TVM_REGISTER_GLOBAL("meta_schedule.DefaultTaskFilter")
- .set_body_typed([](const Array<te::Tensor>& args, const Array<runtime::NDArray>& constants) {
- return DefaultTaskFilter(args, constants);
- });
-
-TVM_REGISTER_GLOBAL("meta_schedule.DefaultTaskFilterAllowExtern")
- .set_body_typed([](const Array<te::Tensor>& args, const Array<runtime::NDArray>& constants) {
- return DefaultTaskFilterAllowExtern(args, constants);
- });
} // namespace meta_schedule
} // namespace tvm
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 664a6a609e..db37935ec2 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -21,7 +21,6 @@
#include <dmlc/memory_io.h>
#include <tvm/arith/analyzer.h>
-#include <tvm/meta_schedule/apply_history_best.h>
#include <tvm/meta_schedule/arg_info.h>
#include <tvm/meta_schedule/builder.h>
#include <tvm/meta_schedule/cost_model.h>
diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc
index 4f83b6eeed..213841c621 100644
--- a/src/relay/backend/task_extraction.cc
+++ b/src/relay/backend/task_extraction.cc
@@ -16,8 +16,6 @@
* specific language governing permissions and limitations
* under the License.
*/
-
-#include <tvm/meta_schedule/apply_history_best.h>
#include <tvm/meta_schedule/extracted_task.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
@@ -32,13 +30,10 @@ namespace tvm {
namespace relay {
namespace backend {
-Array<meta_schedule::ExtractedTask> ExtractTask(
- IRModule mod, Target target, Map<String, runtime::NDArray> params,
- meta_schedule::ApplyHistoryBestNode::FTEFilterFunc filter_func) {
+Array<meta_schedule::ExtractedTask> ExtractTask(IRModule mod, Target target,
+ Map<String, runtime::NDArray> params) {
using meta_schedule::ExtractedTask;
- if (filter_func == nullptr) {
- filter_func = tvm::meta_schedule::DefaultTaskFilter;
- }
+ backend::FTECompilerTIRConverter tir_converter = backend::GetTIRConverter();
backend::BindParamsInModule(mod, params);
// is_vm=true for backward compatibility
Array<Pass> pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true);
@@ -48,7 +43,7 @@ Array<meta_schedule::ExtractedTask> ExtractTask(
std::vector<ExtractedTask> tasks;
std::unordered_map<tec::CCacheKey, ExtractedTask> cache;
- PostOrderVisit(mod->Lookup("main"), [&target, &tasks, &cache, &filter_func](const Expr& exp) {
+ PostOrderVisit(mod->Lookup("main"), [&target, &tasks, &cache, &tir_converter](const Expr& exp) {
if (exp->IsInstance<FunctionNode>()) {
Function relay_func = Downcast<Function>(exp);
if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) {
@@ -62,13 +57,11 @@ Array<meta_schedule::ExtractedTask> ExtractTask(
}
auto [inputs_outputs, constants, fused_name] =
tec::LowerTECompute(relay_func, target, /*return_inputs=*/true);
- if (Optional<tir::PrimFunc> prim_func = filter_func(inputs_outputs, constants)) {
- GlobalVar prim_fn_var(fused_name);
- IRModule relay_mod({{prim_fn_var, relay_func}});
- IRModule tir_mod({{prim_fn_var, prim_func.value()}});
- ExtractedTask extracted_task(fused_name, relay_mod, target, {tir_mod}, 1);
- tasks.push_back(extracted_task);
- cache.emplace(cache_key, extracted_task);
+ if (Optional<tir::PrimFunc> f = tir_converter(inputs_outputs, constants)) {
+ IRModule relay_mod({{GlobalVar(fused_name), relay_func}});
+ ExtractedTask task(fused_name, relay_mod, target, {PrimFuncToIRModule(f.value())}, 1);
+ tasks.push_back(task);
+ cache.emplace(cache_key, task);
}
}
});
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 5c79ed2070..8fa8610c0f 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -548,6 +548,7 @@ TECompiler& TECompiler::Global() {
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule_dispatch", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.tir_converter", String);
TVM_REGISTER_GLOBAL("relay.backend._TECompilerGlobal").set_body_typed([]() {
return TECompiler::Global();
diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index 92cc6f8cfa..0e2a3e2702 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -21,7 +21,7 @@
#include <tvm/driver/driver_api.h>
#include <tvm/ir/type_functor.h>
-#include <tvm/meta_schedule/apply_history_best.h>
+#include <tvm/meta_schedule/database.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
@@ -37,6 +37,7 @@
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/function.h>
#include <tvm/tir/index_map.h>
+#include <tvm/tir/schedule/schedule.h>
#include <tvm/tir/transform.h>
#include <tvm/topi/tags.h>
@@ -61,16 +62,6 @@ TVM_REGISTER_NODE_TYPE(CachedFuncNode);
TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
TVM_REGISTER_NODE_TYPE(CCacheValueNode);
-void ExtractTransformLayout(const meta_schedule::TuningRecord& record) {
- static tir::InstructionKind kind_transform_layout = tir::InstructionKind::Get("TransformLayout");
- for (const tir::Instruction& inst : record->trace->insts) {
- if (inst->kind.same_as(kind_transform_layout)) {
- ICHECK_EQ(inst->attrs.size(), 3);
- relay::MetaScheduleLayoutRewriter::LayoutQueuePush(Downcast<tir::IndexMap>(inst->attrs[2]));
- }
- }
-}
-
LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl) {
auto n = make_object<LoweredOutputNode>();
n->outputs = std::move(outputs);
@@ -317,11 +308,11 @@ class ScheduleBuilder : public ExprVisitor {
// Whether to use auto_scheduler schedule.
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
if (backend::IsMetaScheduleEnabled()) {
- meta_schedule_ctx_ = meta_schedule::ApplyHistoryBest::Current();
- CHECK(meta_schedule_ctx_.defined()) << "ValueError: `use_meta_schedule` is enabled in Relay "
- "build, but no ApplyHistoryBest context is provided. ";
+ database_ = meta_schedule::Database::Current();
+ CHECK(database_.defined()) << "ValueError: `use_meta_schedule` is enabled in Relay "
+ "build, but no `meta_schedule.Database` context is provided. ";
} else {
- meta_schedule_ctx_ = NullOpt;
+ database_ = NullOpt;
}
}
@@ -359,32 +350,43 @@ class ScheduleBuilder : public ExprVisitor {
schedule = Downcast<te::Schedule>(obj);
}
}
- if (meta_schedule_ctx_) {
+ if (database_) {
+ using tvm::meta_schedule::TuningRecord;
+ using tvm::tir::IndexMap;
+ using tvm::tir::Instruction;
+ using tvm::tir::InstructionKind;
+ using tvm::tir::PrimFunc;
+ using tvm::tir::Schedule;
+ backend::FTECompilerTIRConverter tir_converter = backend::GetTIRConverter();
Array<te::Tensor> te_args = Concat(fn_inputs, tensor_outs);
Array<runtime::NDArray> constants;
for (auto [const_node, te_tensor] : lower_te_compute.constant_tensors_) {
te_args.push_back(te_tensor);
constants.push_back(const_node->data);
}
-
- if (Optional<tir::PrimFunc> tir_func =
- meta_schedule_ctx_.value()->te_filter_func(te_args, constants)) {
- IRModule relay_mod({{prim_fn_var, relay_func}});
- IRModule tir_mod({{prim_fn_var, tir_func.value()}});
- if (Optional<IRModule> opt_scheduled_mod = meta_schedule_ctx_.value()->Query(
- /*task_name=*/prim_fn_var->name_hint, //
- /*mod=*/relay_mod, //
- /*target=*/target_, //
- /*dispatched=*/Array<IRModule>{tir_mod}, //
- /*f_take_tuning_record=*/ExtractTransformLayout)) {
- IRModule scheduled_mod =
- tir::transform::RemoveWeightLayoutRewriteBlock()(opt_scheduled_mod.value());
- ICHECK_EQ(scheduled_mod->functions.count(prim_fn_var), 1);
- prim_func = Downcast<tir::PrimFunc>(scheduled_mod->functions[prim_fn_var]);
+ if (Optional<PrimFunc> f = tir_converter(te_args, constants)) {
+ if (Optional<TuningRecord> opt_record = database_.value()->QueryTuningRecord(
+ /*mod=*/backend::PrimFuncToIRModule(f.value()),
+ /*target=*/target_)) {
+ static InstructionKind kind_transform_layout = InstructionKind::Get("TransformLayout");
+ TuningRecord record = opt_record.value();
+ for (const Instruction& inst : record->trace->insts) {
+ if (inst->kind.same_as(kind_transform_layout)) {
+ ICHECK_EQ(inst->attrs.size(), 3);
+ MetaScheduleLayoutRewriter::LayoutQueuePush(Downcast<IndexMap>(inst->attrs[2]));
+ }
+ }
+ Schedule sch = Schedule::Traced(record->workload->mod, /*seed=*/-1, /*debug_mask=*/0,
+ tir::ScheduleErrorRenderLevel::kDetail);
+ record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false);
+ IRModule mod = sch->mod();
+ ICHECK_EQ(mod->functions.size(), 1);
+ mod = tir::transform::RemoveWeightLayoutRewriteBlock()(std::move(mod));
+ prim_func = Downcast<PrimFunc>(mod->Lookup("main"));
}
}
}
- // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule.
+ // Use TOPI schedule if user specified, or the function has no auto_scheduler schedule.
if (!schedule.defined() && !prim_func.defined()) {
if (anchor_op_.defined()) {
auto anchor_impl = lower_te_compute.op_implementations_.find(anchor_op_.operator->());
@@ -422,7 +424,7 @@ class ScheduleBuilder : public ExprVisitor {
}
int op_pattern = fpattern[op];
- if (!use_auto_scheduler_ && !meta_schedule_ctx_.defined() && op_pattern >= kCommReduce) {
+ if (!use_auto_scheduler_ && !database_.defined() && op_pattern >= kCommReduce) {
ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
<< "Cannot apply TOPI schedule to a primitive function with two complicated ops"
<< " anchor=" << anchor_op_ << " current=" << op;
@@ -440,7 +442,7 @@ class ScheduleBuilder : public ExprVisitor {
Attrs anchor_attrs_;
int anchor_op_pattern_{0};
bool use_auto_scheduler_;
- Optional<meta_schedule::ApplyHistoryBest> meta_schedule_ctx_;
+ Optional<meta_schedule::Database> database_;
};
/*!
diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc
index 340986770e..5cf7a5563d 100644
--- a/src/relay/backend/utils.cc
+++ b/src/relay/backend/utils.cc
@@ -28,6 +28,9 @@
#include <tvm/parser/parser.h>
#include <tvm/relay/qnn/transform.h>
#include <tvm/runtime/ndarray.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../te/operation/create_primfunc.h"
namespace tvm {
namespace relay {
@@ -368,6 +371,76 @@ void BindParamsInModule(IRModule mod, Map<String, runtime::NDArray> params) {
BindParamsInModule(mod, params_tmp);
}
+/*!
+ * \brief A default TE compute to TIR compute.
+ * \param args The inputs/outputs of the TE compute graph.
+ * \param constants The constants bound to TIR
+ * \param allow_extern_op Whether to allow extern operation in TE.
+ * \return The TIR converted; NullOpt if not supported (dynamic shape)
+ */
+Optional<tir::PrimFunc> DefaultTIRConverterImpl(const Array<te::Tensor>& args,
+ const Array<runtime::NDArray>& constants,
+ bool allow_extern_op) {
+ using namespace ::tvm::te;
+ std::vector<Tensor> stack;
+ std::unordered_set<const TensorNode*> visited;
+ for (const Tensor& v : args) {
+ for (const PrimExpr& e : v->shape) {
+ // Dynamic shape is not supported for now
+ if (!e->IsInstance<IntImmNode>()) {
+ return NullOpt;
+ }
+ }
+ if (!visited.count(v.get())) {
+ visited.insert(v.get());
+ stack.push_back(v);
+ }
+ }
+ while (!stack.empty()) {
+ Tensor tensor = stack.back();
+ stack.pop_back();
+ if (tensor->op->IsInstance<PlaceholderOpNode>()) {
+ // do nothing
+ } else if (tensor->op->IsInstance<ComputeOpNode>() ||
+ (allow_extern_op && tensor->op->IsInstance<ExternOpNode>())) {
+ Array<Tensor> inputs = tensor->op->InputTensors();
+ for (const Tensor& v : inputs) {
+ if (!visited.count(v.get())) {
+ visited.insert(v.get());
+ stack.push_back(v);
+ }
+ }
+ } else {
+ return NullOpt;
+ }
+ }
+ PrimFunc func = te::CreatePrimFuncWithConstants(args, constants);
+ bool dynamic_loop_extent = false;
+ tir::PostOrderVisit(func->body, [&dynamic_loop_extent](const ObjectRef& obj) -> void {
+ if (const auto* loop = obj.as<tir::ForNode>()) {
+ if (!loop->extent->IsInstance<IntImmNode>()) {
+ dynamic_loop_extent = true;
+ }
+ }
+ });
+ if (dynamic_loop_extent) {
+ return NullOpt;
+ }
+ return func;
+}
+
+TVM_REGISTER_GLOBAL("relay.backend.tir_converter.default")
+ .set_body_typed([](const Array<te::Tensor>& args,
+ const Array<runtime::NDArray>& constants) -> Optional<tir::PrimFunc> {
+ return DefaultTIRConverterImpl(args, constants, false);
+ });
+
+TVM_REGISTER_GLOBAL("relay.backend.tir_converter.allow_extern")
+ .set_body_typed([](const Array<te::Tensor>& args,
+ const Array<runtime::NDArray>& constants) -> Optional<tir::PrimFunc> {
+ return DefaultTIRConverterImpl(args, constants, true);
+ });
+
} // namespace backend
} // namespace relay
} // namespace tvm
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index 57c0661311..37ae9d803a 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -558,6 +558,37 @@ inline bool IsMetaScheduleEnabled() {
.value();
}
+/*!
+ * \brief Method in TECompiler to convert TE compute to scheduleable TIR
+ * \param args The arguments of the TE compute
+ * \param constants The constants used in AllocateConst
+ * \return NullOpt if conversion fails; Otherwise the converted TIR
+ * \note This method could be further used as a task filtering mechanism in task extraction
+ */
+using FTECompilerTIRConverter = runtime::TypedPackedFunc< //
+ Optional<tir::PrimFunc>( //
+ const Array<te::Tensor>& args, //
+ const Array<runtime::NDArray>& constants)>;
+
+/*! \brief Return a task filter for AutoTIR according to `relay.backend.tir_converter` */
+inline FTECompilerTIRConverter GetTIRConverter() {
+ String name = transform::PassContext::Current()
+ ->GetConfig<String>("relay.backend.tir_converter", "default")
+ .value();
+ const PackedFunc* f = runtime::Registry::Get("relay.backend.tir_converter." + name);
+ ICHECK(f != nullptr) << "IndexError: Cannot find TIR converter: " << name;
+ return FTECompilerTIRConverter(*f);
+}
+
+/*! \brief Converts a PrimFunc to IRModule. */
+inline IRModule PrimFuncToIRModule(tir::PrimFunc f) {
+ f = WithAttrs(f, Map<String, ObjectRef>{
+ {tvm::attr::kGlobalSymbol, String("main")},
+ {tvm::tir::attr::kNoAlias, Bool(1)},
+ });
+ return IRModule({{GlobalVar("main"), f}});
+}
+
/*!
* \brief Get the sequence of Relay optimization passes based on backend type.
* The prefix of the Relay passes almost overlaps between the vm and graph backend, with some slight
diff --git a/tests/python/integration/test_meta_schedule_auto_tensorize.py b/tests/python/integration/test_meta_schedule_auto_tensorize.py
index 3397eaabbe..7227ef0c7b 100644
--- a/tests/python/integration/test_meta_schedule_auto_tensorize.py
+++ b/tests/python/integration/test_meta_schedule_auto_tensorize.py
@@ -19,13 +19,12 @@ import tempfile
import numpy as np
import pytest
-
import tvm
import tvm.testing
import tvm.topi.testing
from tvm import meta_schedule as ms
from tvm import relay
-from tvm.meta_schedule import ApplyHistoryBest, postproc, schedule_rule
+from tvm.meta_schedule import postproc, schedule_rule
from tvm.meta_schedule.relay_integration import extract_task_from_relay
from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base
from tvm.meta_schedule.tune import tune_extracted_tasks
@@ -176,12 +175,11 @@ def tune_and_test(relay_mod, data_np, weight_np, op_name, target, sch_rules, pos
postprocs=lambda: postprocs,
)
- with ApplyHistoryBest(database):
- with tvm.transform.PassContext(
- opt_level=3,
- config={"relay.backend.use_meta_schedule": True},
- ):
- lib = relay.build(relay_mod, target=target, params=params)
+ with database, tvm.transform.PassContext(
+ opt_level=3,
+ config={"relay.backend.use_meta_schedule": True},
+ ):
+ lib = relay.build(relay_mod, target=target, params=params)
if "cascadelake" in target:
asm = lib.lib.get_source("asm")
@@ -267,12 +265,11 @@ def _test_bert_int8(target, sch_rules, postprocs):
postprocs=lambda: postprocs,
)
- with ApplyHistoryBest(database):
- with tvm.transform.PassContext(
- opt_level=3,
- config={"relay.backend.use_meta_schedule": True},
- ):
- lib = relay.build(relay_mod, target=target, params=params)
+ with database, tvm.transform.PassContext(
+ opt_level=3,
+ config={"relay.backend.use_meta_schedule": True},
+ ):
+ lib = relay.build(relay_mod, target=target, params=params)
dev = tvm.device("cuda" if "nvidia" in target else target, 0)
runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
diff --git a/tests/python/unittest/test_link_params.py b/tests/python/unittest/test_link_params.py
index 8e299dc935..c741ecb59a 100644
--- a/tests/python/unittest/test_link_params.py
+++ b/tests/python/unittest/test_link_params.py
@@ -19,20 +19,18 @@ import ctypes
import json
import os
import re
-from io import StringIO
from contextlib import redirect_stderr
+from io import StringIO
import numpy as np
-
import tvm
import tvm.relay
import tvm.testing
from tvm import meta_schedule as ms
from tvm import relay
-from tvm.relay.backend import Executor, Runtime
from tvm.contrib import utils
from tvm.meta_schedule.testing.utils import apply_fixed_schedules
-
+from tvm.relay.backend import Executor, Runtime
INPUT_SHAPE = (1, 3, 16, 16)
@@ -421,13 +419,12 @@ def test_tir_link_params():
database = apply_fixed_schedules(relay_mod, target, params, schedule_fn)
with StringIO() as stderr_buf, redirect_stderr(stderr_buf):
- with ms.ApplyHistoryBest(database):
- with tvm.transform.PassContext(
- opt_level=3,
- config={"relay.backend.use_meta_schedule": True},
- ):
- executor = Executor("graph", {"link-params": link_params})
- lib = relay.build(relay_mod, target=target, executor=executor)
+ with database, tvm.transform.PassContext(
+ opt_level=3,
+ config={"relay.backend.use_meta_schedule": True},
+ ):
+ executor = Executor("graph", {"link-params": link_params})
+ lib = relay.build(relay_mod, target=target, executor=executor)
# Workload look up should succeed. This does not work when the test is invoked from pytest.
assert not "Cannot find workload" in stderr_buf.getvalue()
diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py
index afce19a590..69522831ee 100644
--- a/tests/python/unittest/test_meta_schedule_integration.py
+++ b/tests/python/unittest/test_meta_schedule_integration.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""Integration test for MetaSchedule"""
-from typing import Optional
import numpy as np
import pytest
import tvm
@@ -23,11 +22,10 @@ import tvm.testing
from tvm import IRModule
from tvm import meta_schedule as ms
from tvm import relay, te, tir
+from tvm._ffi import register_func
from tvm.meta_schedule.testing.relay_workload import get_network
from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base
from tvm.script import tir as T
-from tvm.target import Target
-from tvm.tir import Schedule
# pylint: disable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument,missing-docstring,invalid-name
@@ -58,10 +56,6 @@ def _has_torch():
requires_torch = pytest.mark.skipif(not _has_torch(), reason="torch is not installed")
-def test_meta_schedule_apply_history_best_no_current():
- assert ms.ApplyHistoryBest.current() is None
-
-
def test_meta_schedule_dynamic_loop_extent():
a = relay.var("a", shape=(1, 8, 8, 512), dtype="float32")
b = relay.nn.adaptive_avg_pool2d(a, (7, 7), "NHWC")
@@ -125,7 +119,7 @@ def test_meta_schedule_integration_extract_from_bert_base():
12,
[[64, 768], [3072, 768], [64, 3072]],
),
- "fused_subtract_add_sqrt_divide_multiply_add": (
+ "fused_subtract_add_rsqrt_multiply_multiply_add": (
25,
[[1, 64, 768], [1, 64, 1], [1, 64, 1], [768], [768], [1, 64, 768]],
),
@@ -206,7 +200,8 @@ def test_meta_schedule_integration_extract_from_bert_base():
@requires_torch
def test_meta_schedule_integration_extract_from_resnet_with_filter_func():
- def filter_func(args) -> bool:
+ @register_func("relay.backend.tir_converter.remove_purely_spatial", override=True)
+ def filter_func(args, _) -> bool:
from tvm.te import create_prim_func # pylint: disable=import-outside-toplevel
has_complex_op = False
@@ -236,7 +231,7 @@ def test_meta_schedule_integration_extract_from_resnet_with_filter_func():
mod,
target="llvm",
params=params,
- te_filter_func=filter_func,
+ tir_converter="remove_purely_spatial",
)
expected_task_names = [
"fused_" + s
@@ -267,53 +262,6 @@ def test_meta_schedule_integration_extract_from_resnet_with_filter_func():
assert t.task_name in expected_task_names, t.task_name
-@requires_torch
-def test_meta_schedule_integration_apply_history_best():
- mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
- database = ms.database.MemoryDatabase()
- env = ms.ApplyHistoryBest(database)
- target = Target("llvm")
- workload = database.commit_workload(MockModule)
- database.commit_tuning_record(
- ms.database.TuningRecord(
- trace=Schedule(MockModule).trace,
- workload=workload,
- run_secs=[1.0],
- target=target,
- args_info=[],
- )
- )
- mod = env.query(
- task_name="mock-task",
- mod=mod,
- target=target,
- dispatched=[MockModule],
- )
- assert tvm.ir.structural_equal(mod, workload.mod)
-
-
-@requires_torch
-def test_meta_schedule_integration_apply_history_best_direct_dispatch():
- def direct_dispatch(mod: IRModule) -> Optional[IRModule]:
- if tvm.ir.structural_equal(mod, MockModule):
- return MockModule
- return None
-
- mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
- database = ms.database.MemoryDatabase()
- env = ms.ApplyHistoryBest(database)
- target = Target("llvm")
- workload = database.commit_workload(MockModule)
- mod = env.query(
- task_name="mock-task-direct-dispatch",
- mod=mod,
- target=target,
- dispatched=[MockModule],
- f_direct_dispatch=direct_dispatch,
- )
- assert tvm.ir.structural_equal(mod, workload.mod)
-
-
@pytest.mark.skip("Too slow on CI")
def extract_task_qbert():
mod, params, _ = load_quantized_bert_base(batch_size=1, seq_len=128)
diff --git a/tests/python/unittest/test_meta_schedule_multi_anchor.py b/tests/python/unittest/test_meta_schedule_multi_anchor.py
index b7d012ca04..1770017811 100644
--- a/tests/python/unittest/test_meta_schedule_multi_anchor.py
+++ b/tests/python/unittest/test_meta_schedule_multi_anchor.py
@@ -70,7 +70,7 @@ def test_dense_dense():
return False
database = apply_fixed_schedules(relay_mod, target, params, schedule_fn)
- with ms.ApplyHistoryBest(database):
+ with database:
with tvm.transform.PassContext(
opt_level=3,
config={"relay.backend.use_meta_schedule": True},
diff --git a/tests/python/unittest/test_meta_schedule_relay_tir_compute.py b/tests/python/unittest/test_meta_schedule_relay_tir_compute.py
index 058012cb64..939851a657 100644
--- a/tests/python/unittest/test_meta_schedule_relay_tir_compute.py
+++ b/tests/python/unittest/test_meta_schedule_relay_tir_compute.py
@@ -19,7 +19,6 @@ import tvm
import tvm.testing
import tvm.topi.testing
from tvm import autotvm, relay, te
-from tvm.meta_schedule import ApplyHistoryBest
from tvm.meta_schedule.testing.utils import apply_fixed_schedules
from tvm.relay.testing.temp_op_attr import TempOpAttr
from tvm.script import tir as T
@@ -152,17 +151,16 @@ def test_conv2d():
target,
params,
schedule_fn,
- te_filter_func="meta_schedule.DefaultTaskFilterAllowExtern",
+ tir_converter="allow_extern",
)
- with ApplyHistoryBest(
- database,
- te_filter_func="meta_schedule.DefaultTaskFilterAllowExtern",
+ with database, tvm.transform.PassContext(
+ opt_level=3,
+ config={
+ "relay.backend.use_meta_schedule": True,
+ "relay.backend.tir_converter": "allow_extern",
+ },
):
- with tvm.transform.PassContext(
- opt_level=3,
- config={"relay.backend.use_meta_schedule": True},
- ):
- lib = relay.build(relay_mod, target=target, params=params)
+ lib = relay.build(relay_mod, target=target, params=params)
dev = tvm.device(target, 0)
diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py
index 7d85b8757a..bc37fed7d6 100644
--- a/tests/python/unittest/test_meta_schedule_tune_relay.py
+++ b/tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -245,12 +245,11 @@ def test_meta_schedule_te2primfunc_argument_order():
database.commit_workload(tvmgen_default_fused_layout_transform_1)
database.commit_workload(tvmgen_default_fused_nn_contrib_conv2d_NCHWc)
- with ms.ApplyHistoryBest(database):
- with tvm.transform.PassContext(
- opt_level=3,
- config={"relay.backend.use_meta_schedule": True},
- ):
- rt_mod1 = relay.build(mod, target=target, params=params)
+ with database, tvm.transform.PassContext(
+ opt_level=3,
+ config={"relay.backend.use_meta_schedule": True},
+ ):
+ rt_mod1 = relay.build(mod, target=target, params=params)
# Compile without meta-schedule for correctness check
with tvm.transform.PassContext(opt_level=0):
@@ -307,12 +306,11 @@ def test_meta_schedule_relay_lowering():
args_info=[],
)
)
- with ms.ApplyHistoryBest(database):
- with tvm.transform.PassContext(
- opt_level=3,
- config={"relay.backend.use_meta_schedule": True},
- ):
- rt_mod1 = relay.build(mod, target=target, params=params)
+ with database, tvm.transform.PassContext(
+ opt_level=3,
+ config={"relay.backend.use_meta_schedule": True},
+ ):
+ rt_mod1 = relay.build(mod, target=target, params=params)
# Compile without meta-schedule for correctness check
with tvm.transform.PassContext(opt_level=0):
@@ -472,24 +470,23 @@ def manual_tir_common(do_tune=False):
database = apply_fixed_schedules(relay_mod, target, params, schedule_fn)
- with ms.ApplyHistoryBest(database):
- with tvm.transform.PassContext(
- opt_level=3,
- config={"relay.backend.use_meta_schedule": True},
- ):
- # pylint: disable=W0105
- """
- The log should say
- Warning: Cannot find workload: tvmgen_default_fused_expand_dims
- Warning: Cannot find workload: tvmgen_default_fused_cast
- Warning: Cannot find workload: tvmgen_default_fused_cast_1
- Warning: Cannot find workload: tvmgen_default_fused_nn_batch_matmul
-
- This means batch matmul and others are scheduled by TE, and dense (the one not warned)
- is found in the meta schedule tuning database during ApplyHistoryBest
- """
- # pylint: enable=W0105
- lib = relay.build(relay_mod, target=target, params=params)
+ with database, tvm.transform.PassContext(
+ opt_level=3,
+ config={"relay.backend.use_meta_schedule": True},
+ ):
+ # pylint: disable=W0105
+ """
+ The log should say
+ Warning: Cannot find workload: tvmgen_default_fused_expand_dims
+ Warning: Cannot find workload: tvmgen_default_fused_cast
+ Warning: Cannot find workload: tvmgen_default_fused_cast_1
+ Warning: Cannot find workload: tvmgen_default_fused_nn_batch_matmul
+
+ This means batch matmul and others are scheduled by TE, and dense (the one not warned)
+ is found in the meta schedule tuning database during compilation
+ """
+ # pylint: enable=W0105
+ lib = relay.build(relay_mod, target=target, params=params)
runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))