Posted to commits@tvm.apache.org by sy...@apache.org on 2022/01/30 05:51:14 UTC

[tvm] branch main updated: [MetaSchedule][M4a] User-API: Tune-TE/TIR/Relay (#10079)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 779dc51  [MetaSchedule][M4a] User-API: Tune-TE/TIR/Relay (#10079)
779dc51 is described below

commit 779dc51e1332f417fa4c304b595ce76891dfc33a
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Sat Jan 29 21:50:24 2022 -0800

    [MetaSchedule][M4a] User-API: Tune-TE/TIR/Relay (#10079)
    
    * Add tuning scripts for tir, te & relay.
    
    Co-authored-by: Junru Shao <ju...@gmail.com>
    Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
    Co-authored-by: Ruihang Lai <la...@qq.com>
    Co-authored-by: Hongyi Jin <32...@qq.com>
    Co-authored-by: Wuwei Lin <wu...@apache.org>
    Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
    
    Minor fix.
    
    Nits.
    
    Add back tests.
    
    * slightly improve tune.py
    
    Co-authored-by: Junru Shao <ju...@gmail.com>
---
 python/tvm/meta_schedule/__init__.py               |  13 +-
 python/tvm/meta_schedule/integration.py            |   2 +-
 python/tvm/meta_schedule/testing/__init__.py       |   4 +-
 python/tvm/meta_schedule/testing/relay_workload.py |  80 +++
 python/tvm/meta_schedule/tune.py                   | 719 +++++++++++++++++++++
 python/tvm/meta_schedule/utils.py                  |  28 +
 src/meta_schedule/integration.cc                   |   4 +-
 src/meta_schedule/task_scheduler/task_scheduler.cc |   4 +-
 src/meta_schedule/utils.h                          |  16 -
 src/tir/schedule/primitive/for_kind.cc             |   3 +-
 .../unittest/test_meta_schedule_integration.py     |   2 +-
 .../unittest/test_meta_schedule_tune_relay.py      | 151 +++++
 .../python/unittest/test_meta_schedule_tune_te.py  |  52 ++
 .../python/unittest/test_meta_schedule_tune_tir.py | 218 +++++++
 14 files changed, 1270 insertions(+), 26 deletions(-)
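
For orientation, the new user-facing entry points (tune_tir, tune_te, tune_relay in
python/tvm/meta_schedule/tune.py) are exercised by the unit tests added below. A minimal
sketch of the TIR path, mirroring test_meta_schedule_tune_tir.py (the `matmul` PrimFunc and
the trial counts are taken from that test and are illustrative, not prescriptive):

    import tempfile

    from tvm.meta_schedule import ReplayTraceConfig, tune_tir
    from tvm.target import Target

    # `matmul` is the 128x128x128 TVMScript PrimFunc defined in the test file below.
    with tempfile.TemporaryDirectory() as work_dir:
        sch = tune_tir(
            mod=matmul,
            target=Target("llvm --num-cores=16"),
            config=ReplayTraceConfig(num_trials_per_iter=32, num_trials_total=32),
            work_dir=work_dir,
        )
        if sch is None:
            print("No valid schedule found!")
        else:
            print(sch.mod.script())  # tuned TensorIR
            print(sch.trace)         # trace of the schedule primitives applied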

diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py
index e41e5b3..2a69d3c 100644
--- a/python/tvm/meta_schedule/__init__.py
+++ b/python/tvm/meta_schedule/__init__.py
@@ -19,10 +19,19 @@ from . import arg_info
 from . import database
 from . import builder
 from . import runner
+from . import mutator
+from . import postproc
+from . import schedule_rule
 from . import space_generator
 from . import search_strategy
-from . import schedule_rule
 from . import integration
 from . import feature_extractor
+from . import cost_model
+from .search_strategy import (
+    EvolutionarySearchConfig,
+    MeasureCandidate,
+    ReplayFuncConfig,
+    ReplayTraceConfig,
+)
+from .tune import tune_te, tune_tir, tune_relay
 from .tune_context import TuneContext
-from .search_strategy import MeasureCandidate
diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py
index 794591c..727c7fe 100644
--- a/python/tvm/meta_schedule/integration.py
+++ b/python/tvm/meta_schedule/integration.py
@@ -184,7 +184,7 @@ class ApplyHistoryBest(MetaScheduleContext):
         self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database)  # type: ignore # pylint: disable=no-member
 
 
-def extract_task(
+def extract_task_from_relay(
     mod: Union[IRModule, RelayFunc],
     target: Target,
     params: Optional[Dict[str, NDArray]] = None,
diff --git a/python/tvm/meta_schedule/testing/__init__.py b/python/tvm/meta_schedule/testing/__init__.py
index a5291f7..85b48b3 100644
--- a/python/tvm/meta_schedule/testing/__init__.py
+++ b/python/tvm/meta_schedule/testing/__init__.py
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """Testing utilities in meta schedule"""
-from .local_rpc import LocalRPC
-from .relay_workload import get_network
 from .byoc_trt import relay_build_with_tensorrt
+from .local_rpc import LocalRPC
+from .relay_workload import MODEL_TYPE, MODEL_TYPES, get_network, get_torch_model
diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py
index 1eb9950..2f1ffdd 100644
--- a/python/tvm/meta_schedule/testing/relay_workload.py
+++ b/python/tvm/meta_schedule/testing/relay_workload.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Workloads in Relay IR"""
+from enum import Enum
 from typing import Dict, Tuple
 
 import tvm.relay.testing  # pylint: disable=unused-import
@@ -22,6 +23,85 @@ from tvm import relay
 from tvm.ir import IRModule
 from tvm.runtime import NDArray
 
+# Model types supported in Torchvision
+class MODEL_TYPE(Enum):  # pylint: disable=invalid-name
+    IMAGE_CLASSIFICATION = (1,)
+    VIDEO_CLASSIFICATION = (2,)
+    SEGMENTATION = (3,)
+    OBJECT_DETECTION = (4,)
+    TEXT_CLASSIFICATION = (5,)
+
+
+# Specify the type of each model
+MODEL_TYPES = {
+    "resnet18": MODEL_TYPE.IMAGE_CLASSIFICATION,
+    "mobilenet_v2": MODEL_TYPE.IMAGE_CLASSIFICATION,
+    "bert_base": MODEL_TYPE.TEXT_CLASSIFICATION,
+}
+
+
+def get_torch_model(
+    model_name: str,
+    input_shape: Tuple[int, ...],
+    output_shape: Tuple[int, int],  # pylint: disable=unused-argument
+    dtype: str = "float32",
+) -> Tuple[IRModule, Dict[str, NDArray]]:
+    """Load model from torch model zoo
+    Parameters
+    ----------
+    model_name : str
+        The name of the model to load
+    input_shape: Tuple[int, ...]
+        Tuple for input shape
+    output_shape: Tuple[int, int]
+        Tuple for output shape
+    dtype: str
+        Tensor data type
+    """
+
+    assert dtype == "float32"
+
+    import torch  # type: ignore # pylint: disable=import-error,import-outside-toplevel
+    from torchvision import models  # type: ignore # pylint: disable=import-error,import-outside-toplevel
+    import transformers  # type: ignore # pylint: disable=import-error,import-outside-toplevel
+    import os  # type: ignore # pylint: disable=import-error,import-outside-toplevel
+
+    def do_trace(model, inp):
+        model.eval()
+        model_trace = torch.jit.trace(model, inp)
+        model_trace.eval()
+        return model_trace
+
+    # Load model from torchvision
+    if MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+        model = transformers.BertModel(
+            transformers.BertConfig(
+                num_hidden_layers=12,
+                hidden_size=768,
+                intermediate_size=3072,
+                num_attention_heads=12,
+                return_dict=False,
+            )
+        )
+        model.eval()
+        input_data = torch.randint(10000, input_shape)
+        shape_list = [("input_ids", input_shape)]
+        scripted_model = torch.jit.trace(model, [input_data], strict=False)
+    elif MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION:
+        model = getattr(models, model_name)()
+        # Setup input
+        input_data = torch.randn(input_shape).type(torch.float32)
+        shape_list = [("input0", input_shape)]
+        # Get trace. Depending on the model type, wrapper may be necessary.
+        scripted_model = do_trace(model, input_data)
+    else:
+        raise ValueError("Unsupported model in Torch model zoo.")
+
+    # Convert torch model to relay module
+    mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
+    return mod, params
+
 
 def get_network(
     name: str,
diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py
new file mode 100644
index 0000000..faf61f5
--- /dev/null
+++ b/python/tvm/meta_schedule/tune.py
@@ -0,0 +1,719 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""User-facing Tuning API"""
+
+import logging
+import os.path
+from typing import Callable, Dict, List, Optional, Union
+
+import tvm
+from tvm import relay
+from tvm._ffi import register_func
+from tvm.ir import IRModule, structural_equal, structural_hash
+from tvm.relay import Function as RelayFunc
+from tvm.runtime import Module, NDArray
+from tvm.target import Target
+from tvm.te import Tensor, create_prim_func
+from tvm.tir import PrimFunc, Schedule
+
+from .builder import Builder, LocalBuilder
+from .cost_model import CostModel, XGBModel
+from .database import Database, JSONDatabase, TuningRecord
+from .feature_extractor import PerStoreFeature
+from .integration import ApplyHistoryBest, extract_task_from_relay
+from .measure_callback import MeasureCallback
+from .mutator import Mutator
+from .postproc import Postproc
+from .runner import LocalRunner, Runner
+from .schedule_rule import ScheduleRule
+from .search_strategy import (
+    EvolutionarySearchConfig,
+    ReplayFuncConfig,
+    ReplayTraceConfig,
+)
+from .space_generator import PostOrderApply, SpaceGenerator
+from .task_scheduler import RoundRobin, TaskScheduler
+from .tune_context import TuneContext
+
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+SearchStrategyConfig = Union[
+    ReplayFuncConfig,
+    ReplayTraceConfig,
+    EvolutionarySearchConfig,
+]
+FnSpaceGenerator = Callable[[], SpaceGenerator]
+FnScheduleRule = Callable[[], List[ScheduleRule]]
+FnPostproc = Callable[[], List[Postproc]]
+FnMutatorProb = Callable[[], Dict[Mutator, float]]
+FnTaskScheduler = Callable[
+    [
+        List[TuneContext],
+        Builder,
+        Runner,
+        Database,
+        CostModel,
+        List[MeasureCallback],
+    ],
+    TaskScheduler,
+]
+
+
+class DefaultLLVM:
+    """Default tuning configuration for LLVM."""
+
+    @staticmethod
+    def _sch_rules() -> List[ScheduleRule]:
+        from tvm.meta_schedule import (  # pylint: disable=import-outside-toplevel
+            schedule_rule as M,
+        )
+
+        return [
+            M.AutoInline(
+                into_producer=False,
+                into_consumer=True,
+                inline_const_tensor=True,
+                disallow_if_then_else=True,
+                require_injective=True,
+                require_ordered=True,
+                disallow_op=["tir.exp"],
+            ),
+            M.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64),
+            M.MultiLevelTiling(
+                structure="SSRSRS",
+                tile_binds=None,
+                max_innermost_factor=64,
+                vector_load_lens=None,
+                reuse_read=None,
+                reuse_write=M.ReuseType(
+                    req="may",
+                    levels=[1, 2],
+                    scope="global",
+                ),
+            ),
+            M.ParallelizeVectorizeUnroll(
+                max_jobs_per_core=16,
+                max_vectorize_extent=64,
+                unroll_max_steps=[0, 16, 64, 512],
+                unroll_explicit=True,
+            ),
+            M.RandomComputeLocation(),
+        ]
+
+    @staticmethod
+    def _postproc() -> List[Postproc]:
+        from tvm.meta_schedule import (  # pylint: disable=import-outside-toplevel
+            postproc as M,
+        )
+
+        return [
+            M.DisallowDynamicLoop(),
+            M.RewriteParallelVectorizeUnroll(),
+            M.RewriteReductionBlock(),
+        ]
+
+    @staticmethod
+    def _mutator_probs() -> Dict[Mutator, float]:
+        from tvm.meta_schedule import (  # pylint: disable=import-outside-toplevel
+            mutator as M,
+        )
+
+        return {
+            M.MutateTileSize(): 0.9,
+            M.MutateComputeLocation(): 0.05,
+            M.MutateUnroll(): 0.03,
+            M.MutateParallel(max_jobs_per_core=16): 0.02,
+        }
+
+
+class DefaultCUDA:
+    """Default tuning configuration for CUDA."""
+
+    @staticmethod
+    def _sch_rules() -> List[ScheduleRule]:
+        from tvm.meta_schedule import (  # pylint: disable=import-outside-toplevel
+            schedule_rule as M,
+        )
+
+        return [
+            M.MultiLevelTiling(
+                structure="SSSRRSRS",
+                tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
+                max_innermost_factor=64,
+                vector_load_lens=[1, 2, 3, 4],
+                reuse_read=M.ReuseType(
+                    req="must",
+                    levels=[4],
+                    scope="shared",
+                ),
+                reuse_write=M.ReuseType(
+                    req="must",
+                    levels=[3],
+                    scope="local",
+                ),
+            ),
+            M.AutoInline(
+                into_producer=True,
+                into_consumer=True,
+                # into_cache_only=False,
+                inline_const_tensor=True,
+                disallow_if_then_else=False,
+                require_injective=False,
+                require_ordered=False,
+                disallow_op=None,
+            ),
+            M.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]),
+            M.ParallelizeVectorizeUnroll(
+                max_jobs_per_core=-1,  # disable parallelize
+                max_vectorize_extent=-1,  # disable vectorize
+                unroll_max_steps=[0, 16, 64, 512, 1024],
+                unroll_explicit=True,
+            ),
+        ]
+
+    @staticmethod
+    def _postproc() -> List[Postproc]:
+        from tvm.meta_schedule import (  # pylint: disable=import-outside-toplevel
+            postproc as M,
+        )
+
+        return [
+            M.DisallowDynamicLoop(),
+            M.RewriteCooperativeFetch(),
+            M.RewriteUnboundBlock(),
+            M.RewriteParallelVectorizeUnroll(),
+            M.RewriteReductionBlock(),
+            M.VerifyGPUCode(),
+        ]
+
+    @staticmethod
+    def _mutator_probs() -> Dict[Mutator, float]:
+        from tvm.meta_schedule import (  # pylint: disable=import-outside-toplevel
+            mutator as M,
+        )
+
+        return {
+            # M.MutateTileSize(): 0.9,
+            M.MutateUnroll(): 0.1,
+        }
+
+
+class Parse:
+    """Parse tuning configuration from user inputs."""
+
+    @staticmethod
+    @register_func("tvm.meta_schedule.tune.parse_mod")  # for use in ApplyHistoryBest
+    def _mod(mod: Union[PrimFunc, IRModule]) -> IRModule:
+        if isinstance(mod, PrimFunc):
+            mod = mod.with_attr("global_symbol", "main")
+            mod = mod.with_attr("tir.noalias", True)
+            mod = IRModule({"main": mod})
+        if not isinstance(mod, IRModule):
+            raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}")
+        # in order to make sure the mod can be found in ApplyHistoryBest
+        # different func name can cause structural unequal
+        func_names = mod.get_global_vars()
+        (func_name,) = func_names
+        if len(func_names) == 1 and func_name != "main":
+            mod = IRModule({"main": mod[func_name]})
+        return mod
+
+    @staticmethod
+    def _target(target: Union[str, Target]) -> Target:
+        if isinstance(target, str):
+            target = Target(target)
+        if not isinstance(target, Target):
+            raise TypeError(f"Expected `target` to be str or Target, but gets: {target}")
+        return target
+
+    @staticmethod
+    def _builder(builder: Optional[Builder]) -> Builder:
+        if builder is None:
+            builder = LocalBuilder()
+        if not isinstance(builder, Builder):
+            raise TypeError(f"Expected `builder` to be Builder, but gets: {builder}")
+        return builder
+
+    @staticmethod
+    def _runner(runner: Optional[Runner]) -> Runner:
+        if runner is None:
+            runner = LocalRunner()
+        if not isinstance(runner, Runner):
+            raise TypeError(f"Expected `runner` to be Runner, but gets: {runner}")
+        return runner
+
+    @staticmethod
+    def _database(database: Union[None, Database], task_name: str, path: str) -> Database:
+        if database is None:
+            path_workload = os.path.join(path, f"{task_name}_database_workload.json")
+            path_tuning_record = os.path.join(path, f"{task_name}_database_tuning_record.json")
+            logger.info(
+                "Creating JSONDatabase. Workload at: %s. Tuning records at: %s",
+                path_workload,
+                path_tuning_record,
+            )
+            database = JSONDatabase(
+                path_workload=path_workload,
+                path_tuning_record=path_tuning_record,
+            )
+        if not isinstance(database, Database):
+            raise TypeError(f"Expected `database` to be Database, but gets: {database}")
+        return database
+
+    @staticmethod
+    def _callbacks(
+        measure_callbacks: Optional[List[MeasureCallback]],
+    ) -> List[MeasureCallback]:
+        if measure_callbacks is None:
+            from tvm.meta_schedule import (  # pylint: disable=import-outside-toplevel
+                measure_callback as M,
+            )
+
+            return [
+                M.AddToDatabase(),
+                M.RemoveBuildArtifact(),
+                M.EchoStatistics(),
+                M.UpdateCostModel(),
+            ]
+        if not isinstance(measure_callbacks, (list, tuple)):
+            raise TypeError(
+                f"Expected `measure_callbacks` to be List[MeasureCallback], "
+                f"but gets: {measure_callbacks}"
+            )
+        measure_callbacks = list(measure_callbacks)
+        for i, callback in enumerate(measure_callbacks):
+            if not isinstance(callback, MeasureCallback):
+                raise TypeError(
+                    f"Expected `measure_callbacks` to be List[MeasureCallback], "
+                    f"but measure_callbacks[{i}] is: {callback}"
+                )
+        return measure_callbacks
+
+    @staticmethod
+    def _cost_model(cost_model: Optional[CostModel]) -> CostModel:
+        if cost_model is None:
+            return XGBModel(extractor=PerStoreFeature())
+        if not isinstance(cost_model, CostModel):
+            raise TypeError(f"Expected `cost_model` to be CostModel, but gets: {cost_model}")
+        return cost_model
+
+    @staticmethod
+    def _space_generator(space_generator: Optional[FnSpaceGenerator]) -> SpaceGenerator:
+        if space_generator is None:
+            return PostOrderApply()
+        if callable(space_generator):
+            space_generator = space_generator()
+        if not isinstance(space_generator, SpaceGenerator):
+            raise TypeError(
+                f"Expected `space_generator` to return SpaceGenerator, "
+                f"but gets: {space_generator}"
+            )
+        return space_generator
+
+    @staticmethod
+    def _sch_rules(sch_rules: Optional[FnScheduleRule], target: Target) -> List[ScheduleRule]:
+        if callable(sch_rules):
+            return sch_rules()
+        if sch_rules is not None:
+            raise TypeError(f"Expected `sch_rules` to be None or callable, but gets: {sch_rules}")
+        # pylint: disable=protected-access
+        if target.kind.name == "llvm":
+            return DefaultLLVM._sch_rules()
+        if target.kind.name == "cuda":
+            return DefaultCUDA._sch_rules()
+        # pylint: enable=protected-access
+        raise ValueError(f"Unsupported target: {target}")
+
+    @staticmethod
+    def _postproc(postproc: Optional[FnPostproc], target: Target) -> List[Postproc]:
+        if callable(postproc):
+            return postproc()
+        if postproc is not None:
+            raise TypeError(f"Expected `postproc` to be None or callable, but gets: {postproc}")
+        # pylint: disable=protected-access
+        if target.kind.name == "llvm":
+            return DefaultLLVM._postproc()
+        if target.kind.name == "cuda":
+            return DefaultCUDA._postproc()
+        # pylint: enable=protected-access
+        raise ValueError(f"Unsupported target: {target}")
+
+    @staticmethod
+    def _mutator_probs(
+        mutator_probs: Optional[FnMutatorProb],
+        target: Target,
+    ) -> Dict[Mutator, float]:
+        if callable(mutator_probs):
+            return mutator_probs()
+        if mutator_probs is not None:
+            raise TypeError(
+                f"Expected `mutator_probs` to be None or callable, but gets: {mutator_probs}"
+            )
+        # pylint: disable=protected-access
+        if target.kind.name == "llvm":
+            return DefaultLLVM._mutator_probs()
+        if target.kind.name == "cuda":
+            return DefaultCUDA._mutator_probs()
+        # pylint: enable=protected-access
+        raise ValueError(f"Unsupported target: {target}")
+
+    @staticmethod
+    def _tune_context(
+        tune_context: Optional[TuneContext],
+        mod: IRModule,
+        target: Target,
+        config: SearchStrategyConfig,
+        task_name: str,
+        space_generator: Optional[FnSpaceGenerator],
+        sch_rules: Optional[FnScheduleRule],
+        postprocs: Optional[FnPostproc],
+        mutator_probs: Optional[FnMutatorProb],
+        num_threads: Optional[int],
+    ) -> TuneContext:
+        if tune_context is None:
+            return TuneContext(
+                mod=mod,
+                target=target,
+                # pylint: disable=protected-access
+                space_generator=Parse._space_generator(space_generator),
+                search_strategy=config.create_strategy(),
+                sch_rules=Parse._sch_rules(sch_rules, target),
+                postprocs=Parse._postproc(postprocs, target),
+                mutator_probs=Parse._mutator_probs(mutator_probs, target),
+                # pylint: enable=protected-access
+                task_name=task_name,
+                rand_state=-1,
+                num_threads=num_threads,
+            )
+        if not isinstance(tune_context, TuneContext):
+            raise TypeError(f"Expected `tune_context` to be TuneContext, but gets: {tune_context}")
+        return tune_context
+
+    @staticmethod
+    def _task_scheduler(
+        task_scheduler: Union[None, TaskScheduler, FnTaskScheduler],
+        tasks: List[TuneContext],
+        builder: Builder,
+        runner: Runner,
+        database: Database,
+        cost_model: CostModel,
+        measure_callbacks: List[MeasureCallback],
+    ):
+        if task_scheduler is None:
+            return RoundRobin(
+                tasks=tasks,
+                builder=builder,
+                runner=runner,
+                database=database,
+                cost_model=cost_model,
+                measure_callbacks=measure_callbacks,
+            )
+        if callable(task_scheduler):
+            return task_scheduler(
+                tasks,
+                builder,
+                runner,
+                database,
+                cost_model,
+                measure_callbacks,
+            )
+        if not isinstance(task_scheduler, TaskScheduler):
+            raise TypeError(
+                f"Expected `task_scheduler` to be TaskScheduler, but gets: {task_scheduler}"
+            )
+        return task_scheduler
+
+
+def tune_tir(
+    mod: Union[IRModule, PrimFunc],
+    target: Union[str, Target],
+    config: SearchStrategyConfig,
+    work_dir: str,
+    *,
+    task_name: str = "main",
+    builder: Optional[Builder] = None,
+    runner: Optional[Runner] = None,
+    database: Optional[Database] = None,
+    cost_model: Optional[CostModel] = None,
+    measure_callbacks: Optional[List[MeasureCallback]] = None,
+    task_scheduler: Optional[TaskScheduler] = None,
+    space: Optional[FnSpaceGenerator] = None,
+    sch_rules: Optional[FnScheduleRule] = None,
+    postprocs: Optional[FnPostproc] = None,
+    mutator_probs: Optional[FnMutatorProb] = None,
+    num_threads: Optional[int] = None,
+) -> Optional[Schedule]:
+    """Tune a TIR IRModule with a given target.
+
+    Parameters
+    ----------
+    mod : Union[IRModule, PrimFunc]
+        The module to tune.
+    target : Union[str, Target]
+        The target to tune for.
+    config : SearchStrategyConfig
+        The search strategy config.
+    task_name : str
+        The name of the task.
+    work_dir : Optional[str]
+        The working directory to save intermediate results.
+    builder : Optional[Builder]
+        The builder to use.
+    runner : Optional[Runner]
+        The runner to use.
+    database : Optional[Database]
+        The database to use.
+    cost_model : Optional[CostModel]
+        The cost model to use.
+    measure_callbacks : Optional[List[MeasureCallback]]
+        The callbacks used during tuning.
+    task_scheduler : Optional[TaskScheduler]
+        The task scheduler to use.
+    num_threads : Optional[int]
+        The number of threads to use.
+
+    Returns
+    -------
+    sch : Optional[Schedule]
+        The tuned schedule.
+    """
+
+    logger.info("Working directory: %s", work_dir)
+    # pylint: disable=protected-access
+    mod = Parse._mod(mod)
+    database = Parse._database(database, task_name, work_dir)
+    tune_context = Parse._tune_context(
+        tune_context=None,
+        mod=mod,
+        target=Parse._target(target),
+        config=config,
+        task_name=task_name,
+        space_generator=space,
+        sch_rules=sch_rules,
+        postprocs=postprocs,
+        mutator_probs=mutator_probs,
+        num_threads=num_threads,
+    )
+    task_scheduler = Parse._task_scheduler(
+        task_scheduler,
+        [tune_context],
+        builder=Parse._builder(builder),
+        runner=Parse._runner(runner),
+        database=database,
+        cost_model=Parse._cost_model(cost_model),
+        measure_callbacks=Parse._callbacks(measure_callbacks),
+    )
+    # pylint: enable=protected-access
+    task_scheduler.tune()
+    bests: List[TuningRecord] = database.get_top_k(
+        database.commit_workload(mod),
+        top_k=1,
+    )
+    if not bests:
+        return None
+    assert len(bests) == 1
+    sch = Schedule(mod)
+    bests[0].trace.apply_to_schedule(sch, remove_postproc=False)
+    task_scheduler.cost_model.save(os.path.join(work_dir, f"{task_name}.xgb"))
+    return sch
+
+
+def tune_te(
+    tensors: List[Tensor],
+    target: Union[str, Target],
+    config: SearchStrategyConfig,
+    work_dir: str,
+    *,
+    task_name: str = "main",
+    builder: Optional[Builder] = None,
+    runner: Optional[Runner] = None,
+    database: Optional[Database] = None,
+    cost_model: Optional[CostModel] = None,
+    measure_callbacks: Optional[List[MeasureCallback]] = None,
+    task_scheduler: Optional[TaskScheduler] = None,
+    space: Optional[FnSpaceGenerator] = None,
+    sch_rules: Optional[FnScheduleRule] = None,
+    postprocs: Optional[FnPostproc] = None,
+    mutator_probs: Optional[FnMutatorProb] = None,
+    num_threads: Optional[int] = None,
+) -> Optional[Schedule]:
+    """Tune a TE compute DAG with a given target.
+
+    Parameters
+    ----------
+    tensors : List[Tensor]
+        The list of input/output tensors of the TE compute DAG.
+    target : Union[str, Target]
+        The target to tune for.
+    config : SearchStrategyConfig
+        The search strategy config.
+    task_name : str
+        The name of the task.
+    work_dir : Optional[str]
+        The working directory to save intermediate results.
+    builder : Optional[Builder]
+        The builder to use.
+    runner : Optional[Runner]
+        The runner to use.
+    database : Optional[Database]
+        The database to use.
+    measure_callbacks : Optional[List[MeasureCallback]]
+        The callbacks used during tuning.
+    task_scheduler : Optional[TaskScheduler]
+        The task scheduler to use.
+    num_threads : Optional[int]
+        The number of threads to use.
+
+    Returns
+    -------
+    sch : Optional[Schedule]
+        The tuned schedule.
+    """
+    return tune_tir(
+        mod=create_prim_func(tensors),
+        target=target,
+        config=config,
+        work_dir=work_dir,
+        task_name=task_name,
+        builder=builder,
+        runner=runner,
+        database=database,
+        cost_model=cost_model,
+        measure_callbacks=measure_callbacks,
+        task_scheduler=task_scheduler,
+        space=space,
+        sch_rules=sch_rules,
+        postprocs=postprocs,
+        mutator_probs=mutator_probs,
+        num_threads=num_threads,
+    )
+
+
+def tune_relay(
+    mod: Union[RelayFunc, IRModule],
+    target: Union[str, Target],
+    config: SearchStrategyConfig,
+    work_dir: str,
+    *,
+    params: Optional[Dict[str, NDArray]] = None,
+    task_name: str = "main",
+    builder: Optional[Builder] = None,
+    runner: Optional[Runner] = None,
+    database: Optional[Database] = None,
+    cost_model: Optional[CostModel] = None,
+    measure_callbacks: Optional[List[MeasureCallback]] = None,
+    task_scheduler: Optional[TaskScheduler] = None,
+    space: Optional[FnSpaceGenerator] = None,
+    sch_rules: Optional[FnScheduleRule] = None,
+    postprocs: Optional[FnPostproc] = None,
+    mutator_probs: Optional[FnMutatorProb] = None,
+    num_threads: Optional[int] = None,
+) -> Module:
+    """Tune a TIR IRModule with a given target.
+
+    Parameters
+    ----------
+    mod : Union[RelayFunc, IRModule]
+        The module to tune.
+    target : Union[str, Target]
+        The target to tune for.
+    config : SearchStrategyConfig
+        The search strategy config.
+    params : Optional[Dict[str, tvm.runtime.NDArray]]
+        The associated parameters of the program
+    task_name : str
+        The name of the task.
+    work_dir : Optional[str]
+        The working directory to save intermediate results.
+    builder : Optional[Builder]
+        The builder to use.
+    runner : Optional[Runner]
+        The runner to use.
+    database : Optional[Database]
+        The database to use.
+    measure_callbacks : Optional[List[MeasureCallback]]
+        The callbacks used during tuning.
+    task_scheduler : Optional[TaskScheduler]
+        The task scheduler to use.
+    num_threads : Optional[int]
+        The number of threads to use.
+
+    Returns
+    -------
+    lib : Module
+        The built runtime module for the given relay workload.
+    """
+
+    logger.info("Working directory: %s", work_dir)
+    extracted_tasks = extract_task_from_relay(mod, target, params)
+    # pylint: disable=protected-access
+    tune_contexts = []
+    target = Parse._target(target)
+    database = Parse._database(database, task_name, work_dir)
+    # parse the tuning contexts
+    for task in extracted_tasks:
+        assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now"
+        tune_contexts.append(
+            Parse._tune_context(
+                tune_context=None,
+                mod=Parse._mod(task.dispatched[0]),
+                target=target,
+                config=config,
+                task_name=task.task_name,
+                space_generator=space,
+                sch_rules=sch_rules,
+                postprocs=postprocs,
+                mutator_probs=mutator_probs,
+                num_threads=num_threads,
+            )
+        )
+    # deduplication
+    logger.info("Before task deduplication: %d tasks", len(tune_contexts))
+    tasks: List[TuneContext] = []
+    hashs: List[int] = []
+    for i, task in enumerate(tune_contexts):
+        struct_hash: int = structural_hash(task.mod)
+        flag: bool = False
+        if struct_hash in hashs:
+            for other_task in tune_contexts[i + 1 :]:
+                if structural_equal(task.mod, other_task.mod):
+                    flag = True
+                    break
+        if not flag:
+            tasks.append(task)
+            hashs.append(struct_hash)
+    logger.info("After task deduplication: %d tasks", len(tasks))
+
+    # parse the task scheduler
+    task_scheduler = Parse._task_scheduler(
+        task_scheduler,
+        tasks,
+        builder=Parse._builder(builder),
+        runner=Parse._runner(runner),
+        database=database,
+        cost_model=Parse._cost_model(cost_model),
+        measure_callbacks=Parse._callbacks(measure_callbacks),
+    )
+    # pylint: enable=protected-access
+    task_scheduler.tune()
+    with ApplyHistoryBest(database):
+        with tvm.transform.PassContext(
+            opt_level=3,
+            config={"relay.backend.use_meta_schedule": True},
+        ):
+            return relay.build(mod, target=target, params=params)
diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py
index ceb5f72..b6fe348 100644
--- a/python/tvm/meta_schedule/utils.py
+++ b/python/tvm/meta_schedule/utils.py
@@ -33,9 +33,37 @@ from tvm.tir import FloatImm, IntImm
 
 @register_func("meta_schedule.cpu_count")
 def _cpu_count_impl(logical: bool = True) -> int:
+    """Return the number of logical or physical CPUs in the system
+    Parameters
+    ----------
+    logical : bool = True
+        If True, return the number of logical CPUs, otherwise return the number of physical CPUs
+    Returns
+    -------
+    cpu_count : int
+        The number of logical or physical CPUs in the system
+    Note
+    ----
+    The meta schedule search infra intentionally does not adopt the following convention in TVM:
+    - C++ API `tvm::runtime::threading::MaxConcurrency()`
+    - Environment variable `TVM_NUM_THREADS` or
+    - Environment variable `OMP_NUM_THREADS`
+    This is because these variables are dedicated to controlling
+    the runtime behavior of generated kernels, instead of the host-side search.
+Setting these variables may interfere with the host-side search and the profiling of
+generated kernels when measuring locally.
+    """
     return psutil.cpu_count(logical=logical) or 1
 
 
+@register_func("meta_schedule._process_error_message")
+def _process_error_message(error_msg: str) -> str:
+    error_msg_lines = str(error_msg).splitlines()
+    if len(error_msg_lines) >= 50:
+        return "\n".join(error_msg_lines[:25] + ["..."] + error_msg_lines[-25:])
+    return error_msg
+
+
 def cpu_count(logical: bool = True) -> int:
     """Return the number of logical or physical CPUs in the system
 
diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc
index e9d3012..1ae2e02 100644
--- a/src/meta_schedule/integration.cc
+++ b/src/meta_schedule/integration.cc
@@ -120,7 +120,9 @@ Optional<ObjectRef> ApplyHistoryBestNode::Query(runtime::String task_name, IRMod
   IRModule prim_mod = dispatched.value()[0];
   ICHECK(HasOnlyOneFunction<tir::PrimFunc>(prim_mod)) << prim_mod;
   // Unify func name to make sure it can be found in database
-  prim_mod = UnifyFuncName(prim_mod);
+  const auto* parse_mod_func = runtime::Registry::Get("tvm.meta_schedule.tune.parse_mod");
+  ICHECK(parse_mod_func) << "Parse mod function not defined!";
+  prim_mod = (*parse_mod_func)(prim_mod);
   if (database->HasWorkload(prim_mod)) {
     Array<TuningRecord> records = database->GetTopK(database->CommitWorkload(prim_mod), 1);
     if (records.size() == 1) {
diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc
index 1f3943d..28f95b2 100644
--- a/src/meta_schedule/task_scheduler/task_scheduler.cc
+++ b/src/meta_schedule/task_scheduler/task_scheduler.cc
@@ -124,7 +124,7 @@ void TaskSchedulerNode::Tune() {
 
   int running_tasks = tasks.size();
   for (int task_id; (task_id = NextTaskId()) != -1;) {
-    LOG(INFO) << "Scheduler picks Task #" << task_id << ": " << tasks[task_id]->task_name;
+    LOG(INFO) << "Scheduler picks Task #" << task_id + 1 << ": " << tasks[task_id]->task_name;
     TuneContext task = tasks[task_id];
     ICHECK(!task->is_stopped);
     ICHECK(!task->runner_futures.defined());
@@ -138,7 +138,7 @@ void TaskSchedulerNode::Tune() {
     } else {
       SetTaskStopped(task_id);
       --running_tasks;
-      LOG(INFO) << "Task #" << task_id << " has finished. Remaining task(s): " << running_tasks;
+      LOG(INFO) << "Task #" << task_id + 1 << " has finished. Remaining task(s): " << running_tasks;
     }
   }
   ICHECK_EQ(running_tasks, 0) << "Not all tasks are finished";
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index afeb159..bd76ca7 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -351,22 +351,6 @@ inline int GetTargetNumCores(const Target& target) {
   return num_cores;
 }
 
-/*!
- * \brief Unify the function name in workload to "main".
- * \param mod The workload.
- * \return The new workload with unified function name.
- * \note If the name is not unified, the workload may not be found in database.
- */
-inline IRModule UnifyFuncName(const IRModule& mod) {
-  if (!mod->ContainGlobalVar("main") && mod->GetGlobalTypeVars().size() == 1) {
-    IRModule new_mod = IRModule(
-        Map<GlobalVar, BaseFunc>({{GlobalVar("main"), mod->functions[mod->GetGlobalVars()[0]]}}));
-    return new_mod;
-  } else {
-    return mod;
-  }
-}
-
 }  // namespace meta_schedule
 }  // namespace tvm
 
diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc
index 55869e1..bff4293 100644
--- a/src/tir/schedule/primitive/for_kind.cc
+++ b/src/tir/schedule/primitive/for_kind.cc
@@ -83,7 +83,8 @@ void CheckLoopParallelizableInBlock(const ScheduleState& self, ForKind for_kind,
   const Block& block = block_realize->block;
 
   // Cond 1. The block is required to have affine bindings.
-  CheckAffineBinding(self, block);
+  // TODO(@automation): fix the check
+  // CheckAffineBinding(self, block);
 
   // Cond 2. For each block iter whose binding contains `loop_var`, only two cases are allowed.
   ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size());
diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py
index bc1d5f2..76ca52e 100644
--- a/tests/python/unittest/test_meta_schedule_integration.py
+++ b/tests/python/unittest/test_meta_schedule_integration.py
@@ -116,7 +116,7 @@ def test_meta_schedule_integration_extract_from_resnet():
         layout="NHWC",
         dtype="float32",
     )
-    extracted_tasks = ms.integration.extract_task(mod, target="llvm", params=params)
+    extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params)
     assert len(extracted_tasks) == 30
 
 
diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py
new file mode 100644
index 0000000..7e6f89d
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -0,0 +1,151 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import logging
+import tempfile
+import pytest
+import numpy as np
+from typing import Tuple, List
+
+import tvm
+from tvm import relay
+from tvm.ir import IRModule
+from tvm.runtime.ndarray import cpu, cuda
+from tvm.target.target import Target
+from tvm.contrib import graph_executor
+from tvm.meta_schedule import ReplayTraceConfig
+from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord
+from tvm.meta_schedule.testing import MODEL_TYPE, MODEL_TYPES, get_torch_model
+from tvm.meta_schedule.tune import tune_relay
+
+logging.basicConfig()
+logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
+
+
+class DummyDatabase(PyDatabase):
+    def __init__(self):
+        super().__init__()
+        self.records = []
+        self.workload_reg = []
+
+    def has_workload(self, mod: IRModule) -> Workload:
+        for workload in self.workload_reg:
+            if tvm.ir.structural_equal(workload.mod, mod):
+                return True
+        return False
+
+    def commit_tuning_record(self, record: TuningRecord) -> None:
+        self.records.append(record)
+
+    def commit_workload(self, mod: IRModule) -> Workload:
+        for workload in self.workload_reg:
+            if tvm.ir.structural_equal(workload.mod, mod):
+                return workload
+        workload = Workload(mod)
+        self.workload_reg.append(workload)
+        return workload
+
+    def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
+        return list(
+            filter(
+                lambda x: x.workload == workload,
+                sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
+            )
+        )[: int(top_k)]
+
+    def __len__(self) -> int:
+        return len(self.records)
+
+    def print_results(self) -> None:
+        print("\n".join([str(r) for r in self.records]))
+
+
+@pytest.mark.skip("Integration test")
+@pytest.mark.parametrize("model_name", ["resnet18", "mobilenet_v2", "bert_base"])
+@pytest.mark.parametrize("batch_size", [1])
+@pytest.mark.parametrize("target", ["llvm --num-cores=16", "nvidia/geforce-rtx-3070"])
+def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str):
+    if model_name == "inception_v3" and batch_size == 1:
+        pytest.skip("inception_v3 does not handle batch_size of 1")
+
+    input_shape: Tuple[int, ...]
+    input_name = "input0"
+    dev = tvm.cpu() if str(target).startswith("llvm") else cuda()
+    if MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
+        seq_length = 128
+        input_name = "input_ids"
+        input_shape = (batch_size, seq_length)
+        data = tvm.nd.array(np.random.randint(0, 30521, size=input_shape), dev)  # embedding size
+    else:
+        if MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION:
+            input_shape = (batch_size, 3, 299, 299)
+        elif MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION:
+            input_shape = (batch_size, 3, 299, 299)
+        elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION:
+            input_shape = (1, 3, 300, 300)
+        elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION:
+            input_shape = (batch_size, 3, 3, 299, 299)
+        else:
+            raise ValueError("Unsupported model: " + model_name)
+        data = tvm.nd.array(np.random.randn(*input_shape).astype("float32"), dev)
+
+    output_shape: Tuple[int, int] = (batch_size, 1000)
+
+    mod, params = get_torch_model(
+        model_name=model_name,
+        input_shape=input_shape,
+        output_shape=output_shape,
+        dtype="float32",
+    )
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        target = Target(target)
+        database = DummyDatabase()
+        rt_mod: tvm.runtime.Module = tune_relay(
+            mod=mod,
+            params=params,
+            target=target,
+            config=ReplayTraceConfig(
+                num_trials_per_iter=32,
+                num_trials_total=32,
+            ),
+            work_dir=work_dir,
+            database=database,
+        )
+        # Compile without meta-scheduler for correctness check
+        with tvm.transform.PassContext(opt_level=0):
+            rt_mod2 = relay.build(mod, target=target, params=params)
+
+        def get_output(data, lib):
+            module = graph_executor.GraphModule(lib["default"](dev))
+            module.set_input(input_name, data)
+            module.run()
+            return module.get_output(0).numpy()
+
+        # Check correctness
+        actual_output = get_output(data, rt_mod)
+        expected_output = get_output(data, rt_mod2)
+        assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)
+
+
+if __name__ == """__main__""":
+    test_meta_schedule_tune_relay("resnet18", 1, "llvm --num-cores=16")
+    test_meta_schedule_tune_relay("resnet18", 1, "nvidia/geforce-rtx-3070")
+    test_meta_schedule_tune_relay("mobilenet_v2", 1, "llvm --num-cores=16")
+    test_meta_schedule_tune_relay("mobilenet_v2", 1, "nvidia/geforce-rtx-3070")
+    test_meta_schedule_tune_relay("bert_base", 1, "llvm --num-cores=16")
+    test_meta_schedule_tune_relay("bert_base", 1, "nvidia/geforce-rtx-3070")
diff --git a/tests/python/unittest/test_meta_schedule_tune_te.py b/tests/python/unittest/test_meta_schedule_tune_te.py
new file mode 100644
index 0000000..a07bf17
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_tune_te.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import logging
+import tempfile
+
+import pytest
+from tvm.meta_schedule import ReplayTraceConfig, tune_te
+from tvm.meta_schedule.testing import te_workload
+from tvm.target.target import Target
+from tvm.tir import Schedule
+
+
+logging.basicConfig()
+logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
+
+
+@pytest.mark.skip("Integration test")
+def test_tune_matmul():
+    with tempfile.TemporaryDirectory() as work_dir:
+        sch: Schedule = tune_te(
+            tensors=te_workload.batch_matmul_nkkm(B=1, N=128, M=128, K=128),
+            target=Target("llvm --num-cores=16"),
+            config=ReplayTraceConfig(
+                num_trials_per_iter=32,
+                num_trials_total=32,
+            ),
+            work_dir=work_dir,
+        )
+        if sch is None:
+            print("No valid schedule found!")
+        else:
+            print(sch.mod.script())
+            print(sch.trace)
+
+
+if __name__ == """__main__""":
+    test_tune_matmul()
diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py
new file mode 100644
index 0000000..277fa24
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_tune_tir.py
@@ -0,0 +1,218 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import logging
+import tempfile
+
+import tvm
+import pytest
+from tvm.meta_schedule import ReplayTraceConfig, tune_tir
+from tvm.meta_schedule.tune_context import TuneContext
+from tvm.meta_schedule import schedule_rule, postproc
+from tvm.meta_schedule.space_generator import PostOrderApply
+from tvm.script import tir as T
+from tvm.target.target import Target
+from tvm.te.operation import create_prim_func
+from tvm.tir import Schedule
+from tvm.meta_schedule.testing import te_workload
+
+logging.basicConfig()
+logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
+
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+
+@T.prim_func
+def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128])
+    B = T.match_buffer(b, [128, 128])
+    C = T.match_buffer(c, [128, 128])
+    for i, j, k in T.grid(128, 128, 128):
+        with T.block("update"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                C[vi, vj] = 0.0
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+# pylint: enable=no-member,invalid-name,unused-variable
+
+
+@pytest.mark.skip("Integration test")
+def test_tune_matmul_cpu():
+    with tempfile.TemporaryDirectory() as work_dir:
+        sch: Schedule = tune_tir(
+            mod=matmul,
+            target=Target("llvm --num-cores=16"),
+            config=ReplayTraceConfig(
+                num_trials_per_iter=32,
+                num_trials_total=32,
+            ),
+            work_dir=work_dir,
+        )
+        if sch is None:
+            print("No valid schedule found!")
+        else:
+            print(sch.mod.script())
+            print(sch.trace)
+
+
+@pytest.mark.skip("Integration test")
+def test_tune_matmul_cuda():
+    with tempfile.TemporaryDirectory() as work_dir:
+        sch: Schedule = tune_tir(
+            mod=matmul,
+            target=Target("nvidia/geforce-rtx-3070"),
+            config=ReplayTraceConfig(
+                num_trials_per_iter=32,
+                num_trials_total=32,
+            ),
+            work_dir=work_dir,
+        )
+        if sch is None:
+            print("No valid schedule found!")
+        else:
+            print(sch.mod.script())
+            print(sch.trace)
+
+
+@pytest.mark.skip("Integeration test")
+def test_tune_matmul_cuda_tensor_core():
+    n = 512
+    mod = create_prim_func(te_workload.matmul_fp16(n, n, n))
+    target = Target("nvidia/geforce-rtx-3070")
+    config = ReplayTraceConfig(
+        num_trials_per_iter=32,
+        num_trials_total=320,
+    )
+
+    class DefaultTensorCore:
+        @staticmethod
+        def _sch_rules():
+            from tvm.meta_schedule import (  # pylint: disable=import-outside-toplevel
+                schedule_rule as M,
+            )
+
+            return [
+                M.AutoInline(
+                    into_producer=False,
+                    into_consumer=True,
+                    # into_cache_only=False,
+                    inline_const_tensor=True,
+                    disallow_if_then_else=False,
+                    require_injective=False,
+                    require_ordered=False,
+                    disallow_op=None,
+                ),
+                M.MultiLevelTiling(
+                    structure="SSSRRSRS",
+                    tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"],
+                    # use_tensor_core=True,
+                    max_innermost_factor=64,
+                    vector_load_lens=[1, 2, 3, 4],
+                    reuse_read=schedule_rule.ReuseType(
+                        req="must",
+                        levels=[4],
+                        scope="shared",
+                    ),
+                    reuse_write=schedule_rule.ReuseType(
+                        req="no",
+                        levels=[],
+                        scope="",
+                    ),
+                ),
+                M.AutoInline(
+                    into_producer=True,
+                    into_consumer=True,
+                    # into_cache_only=True,
+                    inline_const_tensor=True,
+                    disallow_if_then_else=False,
+                    require_injective=False,
+                    require_ordered=False,
+                    disallow_op=None,
+                ),
+                M.ParallelizeVectorizeUnroll(
+                    max_jobs_per_core=-1,  # disable parallelize
+                    max_vectorize_extent=-1,  # disable vectorize
+                    unroll_max_steps=[0, 16, 64, 512, 1024],
+                    unroll_explicit=True,
+                ),
+            ]
+
+        @staticmethod
+        def _postproc():
+            from tvm.meta_schedule import (  # pylint: disable=import-outside-toplevel
+                postproc as M,
+            )
+
+            return [
+                # M.RewriteCooperativeFetch(),
+                M.RewriteParallelVectorizeUnroll(),
+                M.RewriteReductionBlock(),
+                # M.RewriteTensorCore(),
+                M.VerifyGPUCode(),
+            ]
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        sch: Schedule = tune_tir(
+            mod=mod,
+            target=target,
+            config=config,
+            work_dir=work_dir,
+            space=PostOrderApply(),
+            sch_rules=DefaultTensorCore._sch_rules,
+            postprocs=DefaultTensorCore._postproc,
+            num_threads=None,
+        )
+        if sch is None:
+            print("No valid schedule found!")
+        else:
+            print(sch.mod.script())
+            print(sch.trace)
+
+            from tvm.contrib import nvcc
+            import numpy as np
+
+            ctx = tvm.gpu(0)
+            if nvcc.have_tensorcore(ctx.compute_version):
+                with tvm.transform.PassContext():
+                    func = tvm.build(sch.mod["main"], [], "cuda")
+                    print(sch.mod.script())
+                    print(func.imported_modules[0].get_source())
+                a_np = np.random.uniform(size=(n, n)).astype("float16")
+                b_np = np.random.uniform(size=(n, n)).astype("float16")
+                a = tvm.nd.array(a_np, ctx)
+                b = tvm.nd.array(b_np, ctx)
+                c = tvm.nd.array(np.zeros((n, n), dtype="float32"), ctx)
+                evaluator = func.time_evaluator(
+                    func.entry_name, ctx, number=3, repeat=1, min_repeat_ms=40
+                )
+                print("matmul with tensor core: %f ms" % (evaluator(a, b, c).mean * 1e3))
+
+                np.testing.assert_allclose(
+                    c.asnumpy(),
+                    np.matmul(a_np.astype("float32"), b_np.astype("float32")),
+                    rtol=1e-4,
+                    atol=1e-4,
+                )
+
+
+if __name__ == """__main__""":
+    test_tune_matmul_cpu()
+    test_tune_matmul_cuda()
+    test_tune_matmul_cuda_tensor_core()
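
For the Relay path, the integration test above compiles a whole model end to end; a condensed
sketch (it requires PyTorch/torchvision for get_torch_model, and the shapes follow the
resnet18 case in test_meta_schedule_tune_relay.py):

    import tempfile

    from tvm.meta_schedule import ReplayTraceConfig, tune_relay
    from tvm.meta_schedule.testing import get_torch_model
    from tvm.target import Target

    mod, params = get_torch_model(
        model_name="resnet18",
        input_shape=(1, 3, 299, 299),
        output_shape=(1, 1000),
        dtype="float32",
    )
    with tempfile.TemporaryDirectory() as work_dir:
        # Returns a built runtime module, usable with graph_executor as in the test above.
        rt_mod = tune_relay(
            mod=mod,
            params=params,
            target=Target("llvm --num-cores=16"),
            config=ReplayTraceConfig(num_trials_per_iter=32, num_trials_total=32),
            work_dir=work_dir,
        )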