You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/01/30 05:51:14 UTC
[tvm] branch main updated: [MetaSchedule][M4a] User-API: Tune-TE/TIR/Relay (#10079)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 779dc51 [MetaSchedule][M4a] User-API: Tune-TE/TIR/Relay (#10079)
779dc51 is described below
commit 779dc51e1332f417fa4c304b595ce76891dfc33a
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Sat Jan 29 21:50:24 2022 -0800
[MetaSchedule][M4a] User-API: Tune-TE/TIR/Relay (#10079)
* Add tuning scripts for tir, te & relay.
Co-authored-by: Junru Shao <ju...@gmail.com>
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
Co-authored-by: Ruihang Lai <la...@qq.com>
Co-authored-by: Hongyi Jin <32...@qq.com>
Co-authored-by: Wuwei Lin <wu...@apache.org>
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
Minor fix.
Nits.
Add back tests.
* slightly improve tune.py
Co-authored-by: Junru Shao <ju...@gmail.com>
---
python/tvm/meta_schedule/__init__.py | 13 +-
python/tvm/meta_schedule/integration.py | 2 +-
python/tvm/meta_schedule/testing/__init__.py | 4 +-
python/tvm/meta_schedule/testing/relay_workload.py | 80 +++
python/tvm/meta_schedule/tune.py | 719 +++++++++++++++++++++
python/tvm/meta_schedule/utils.py | 28 +
src/meta_schedule/integration.cc | 4 +-
src/meta_schedule/task_scheduler/task_scheduler.cc | 4 +-
src/meta_schedule/utils.h | 16 -
src/tir/schedule/primitive/for_kind.cc | 3 +-
.../unittest/test_meta_schedule_integration.py | 2 +-
.../unittest/test_meta_schedule_tune_relay.py | 151 +++++
.../python/unittest/test_meta_schedule_tune_te.py | 52 ++
.../python/unittest/test_meta_schedule_tune_tir.py | 218 +++++++
14 files changed, 1270 insertions(+), 26 deletions(-)
diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py
index e41e5b3..2a69d3c 100644
--- a/python/tvm/meta_schedule/__init__.py
+++ b/python/tvm/meta_schedule/__init__.py
@@ -19,10 +19,19 @@ from . import arg_info
from . import database
from . import builder
from . import runner
+from . import mutator
+from . import postproc
+from . import schedule_rule
from . import space_generator
from . import search_strategy
-from . import schedule_rule
from . import integration
from . import feature_extractor
+from . import cost_model
+from .search_strategy import (
+ EvolutionarySearchConfig,
+ MeasureCandidate,
+ ReplayFuncConfig,
+ ReplayTraceConfig,
+)
+from .tune import tune_te, tune_tir, tune_relay
from .tune_context import TuneContext
-from .search_strategy import MeasureCandidate
diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py
index 794591c..727c7fe 100644
--- a/python/tvm/meta_schedule/integration.py
+++ b/python/tvm/meta_schedule/integration.py
@@ -184,7 +184,7 @@ class ApplyHistoryBest(MetaScheduleContext):
self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database) # type: ignore # pylint: disable=no-member
-def extract_task(
+def extract_task_from_relay(
mod: Union[IRModule, RelayFunc],
target: Target,
params: Optional[Dict[str, NDArray]] = None,
diff --git a/python/tvm/meta_schedule/testing/__init__.py b/python/tvm/meta_schedule/testing/__init__.py
index a5291f7..85b48b3 100644
--- a/python/tvm/meta_schedule/testing/__init__.py
+++ b/python/tvm/meta_schedule/testing/__init__.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""Testing utilities in meta schedule"""
-from .local_rpc import LocalRPC
-from .relay_workload import get_network
from .byoc_trt import relay_build_with_tensorrt
+from .local_rpc import LocalRPC
+from .relay_workload import MODEL_TYPE, MODEL_TYPES, get_network, get_torch_model
diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py
index 1eb9950..2f1ffdd 100644
--- a/python/tvm/meta_schedule/testing/relay_workload.py
+++ b/python/tvm/meta_schedule/testing/relay_workload.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Workloads in Relay IR"""
+from enum import Enum
from typing import Dict, Tuple
import tvm.relay.testing # pylint: disable=unused-import
@@ -22,6 +23,85 @@ from tvm import relay
from tvm.ir import IRModule
from tvm.runtime import NDArray
+# Model types supported in Torchvision
+class MODEL_TYPE(Enum): # pylint: disable=invalid-name
+ IMAGE_CLASSIFICATION = (1,)
+ VIDEO_CLASSIFICATION = (2,)
+ SEGMENTATION = (3,)
+ OBJECT_DETECTION = (4,)
+ TEXT_CLASSIFICATION = (5,)
+
+
+# Specify the type of each model
+MODEL_TYPES = {
+ "resnet18": MODEL_TYPE.IMAGE_CLASSIFICATION,
+ "mobilenet_v2": MODEL_TYPE.IMAGE_CLASSIFICATION,
+ "bert_base": MODEL_TYPE.TEXT_CLASSIFICATION,
+}
+
+
+def get_torch_model(
+ model_name: str,
+ input_shape: Tuple[int, ...],
+ output_shape: Tuple[int, int], # pylint: disable=unused-argument
+ dtype: str = "float32",
+) -> Tuple[IRModule, Dict[str, NDArray]]:
+ """Load model from torch model zoo
+ Parameters
+ ----------
+ model_name : str
+ The name of the model to load
+ input_shape: Tuple[int, ...]
+ Tuple for input shape
+ output_shape: Tuple[int, int]
+ Tuple for output shape
+ dtype: str
+ Tensor data type
+ """
+
+ assert dtype == "float32"
+
+ import torch # type: ignore # pylint: disable=import-error,import-outside-toplevel
+ from torchvision import models # type: ignore # pylint: disable=import-error,import-outside-toplevel
+ import transformers # type: ignore # pylint: disable=import-error,import-outside-toplevel
+ import os # type: ignore # pylint: disable=import-error,import-outside-toplevel
+
+ def do_trace(model, inp):
+ model.eval()
+ model_trace = torch.jit.trace(model, inp)
+ model_trace.eval()
+ return model_trace
+
+ # Load model from torchvision
+ if MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ model = transformers.BertModel(
+ transformers.BertConfig(
+ num_hidden_layers=12,
+ hidden_size=768,
+ intermediate_size=3072,
+ num_attention_heads=12,
+ return_dict=False,
+ )
+ )
+ model.eval()
+ input_data = torch.randint(10000, input_shape)
+ shape_list = [("input_ids", input_shape)]
+ scripted_model = torch.jit.trace(model, [input_data], strict=False)
+ elif MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION:
+ model = getattr(models, model_name)()
+ # Setup input
+ input_data = torch.randn(input_shape).type(torch.float32)
+ shape_list = [("input0", input_shape)]
+ # Get trace. Depending on the model type, wrapper may be necessary.
+ scripted_model = do_trace(model, input_data)
+ else:
+ raise ValueError("Unsupported model in Torch model zoo.")
+
+ # Convert torch model to relay module
+ mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
+ return mod, params
+
def get_network(
name: str,
diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py
new file mode 100644
index 0000000..faf61f5
--- /dev/null
+++ b/python/tvm/meta_schedule/tune.py
@@ -0,0 +1,719 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""User-facing Tuning API"""
+
+import logging
+import os.path
+from typing import Callable, Dict, List, Optional, Union
+
+import tvm
+from tvm import relay
+from tvm._ffi import register_func
+from tvm.ir import IRModule, structural_equal, structural_hash
+from tvm.relay import Function as RelayFunc
+from tvm.runtime import Module, NDArray
+from tvm.target import Target
+from tvm.te import Tensor, create_prim_func
+from tvm.tir import PrimFunc, Schedule
+
+from .builder import Builder, LocalBuilder
+from .cost_model import CostModel, XGBModel
+from .database import Database, JSONDatabase, TuningRecord
+from .feature_extractor import PerStoreFeature
+from .integration import ApplyHistoryBest, extract_task_from_relay
+from .measure_callback import MeasureCallback
+from .mutator import Mutator
+from .postproc import Postproc
+from .runner import LocalRunner, Runner
+from .schedule_rule import ScheduleRule
+from .search_strategy import (
+ EvolutionarySearchConfig,
+ ReplayFuncConfig,
+ ReplayTraceConfig,
+)
+from .space_generator import PostOrderApply, SpaceGenerator
+from .task_scheduler import RoundRobin, TaskScheduler
+from .tune_context import TuneContext
+
+
+logger = logging.getLogger(__name__) # pylint: disable=invalid-name
+
+SearchStrategyConfig = Union[
+ ReplayFuncConfig,
+ ReplayTraceConfig,
+ EvolutionarySearchConfig,
+]
+FnSpaceGenerator = Callable[[], SpaceGenerator]
+FnScheduleRule = Callable[[], List[ScheduleRule]]
+FnPostproc = Callable[[], List[Postproc]]
+FnMutatorProb = Callable[[], Dict[Mutator, float]]
+FnTaskScheduler = Callable[
+ [
+ List[TuneContext],
+ Builder,
+ Runner,
+ Database,
+ CostModel,
+ List[MeasureCallback],
+ ],
+ TaskScheduler,
+]
+
+
+class DefaultLLVM:
+ """Default tuning configuration for LLVM."""
+
+ @staticmethod
+ def _sch_rules() -> List[ScheduleRule]:
+ from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel
+ schedule_rule as M,
+ )
+
+ return [
+ M.AutoInline(
+ into_producer=False,
+ into_consumer=True,
+ inline_const_tensor=True,
+ disallow_if_then_else=True,
+ require_injective=True,
+ require_ordered=True,
+ disallow_op=["tir.exp"],
+ ),
+ M.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64),
+ M.MultiLevelTiling(
+ structure="SSRSRS",
+ tile_binds=None,
+ max_innermost_factor=64,
+ vector_load_lens=None,
+ reuse_read=None,
+ reuse_write=M.ReuseType(
+ req="may",
+ levels=[1, 2],
+ scope="global",
+ ),
+ ),
+ M.ParallelizeVectorizeUnroll(
+ max_jobs_per_core=16,
+ max_vectorize_extent=64,
+ unroll_max_steps=[0, 16, 64, 512],
+ unroll_explicit=True,
+ ),
+ M.RandomComputeLocation(),
+ ]
+
+ @staticmethod
+ def _postproc() -> List[Postproc]:
+ from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel
+ postproc as M,
+ )
+
+ return [
+ M.DisallowDynamicLoop(),
+ M.RewriteParallelVectorizeUnroll(),
+ M.RewriteReductionBlock(),
+ ]
+
+ @staticmethod
+ def _mutator_probs() -> Dict[Mutator, float]:
+ from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel
+ mutator as M,
+ )
+
+ return {
+ M.MutateTileSize(): 0.9,
+ M.MutateComputeLocation(): 0.05,
+ M.MutateUnroll(): 0.03,
+ M.MutateParallel(max_jobs_per_core=16): 0.02,
+ }
+
+
+class DefaultCUDA:
+ """Default tuning configuration for CUDA."""
+
+ @staticmethod
+ def _sch_rules() -> List[ScheduleRule]:
+ from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel
+ schedule_rule as M,
+ )
+
+ return [
+ M.MultiLevelTiling(
+ structure="SSSRRSRS",
+ tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
+ max_innermost_factor=64,
+ vector_load_lens=[1, 2, 3, 4],
+ reuse_read=M.ReuseType(
+ req="must",
+ levels=[4],
+ scope="shared",
+ ),
+ reuse_write=M.ReuseType(
+ req="must",
+ levels=[3],
+ scope="local",
+ ),
+ ),
+ M.AutoInline(
+ into_producer=True,
+ into_consumer=True,
+ # into_cache_only=False,
+ inline_const_tensor=True,
+ disallow_if_then_else=False,
+ require_injective=False,
+ require_ordered=False,
+ disallow_op=None,
+ ),
+ M.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]),
+ M.ParallelizeVectorizeUnroll(
+ max_jobs_per_core=-1, # disable parallelize
+ max_vectorize_extent=-1, # disable vectorize
+ unroll_max_steps=[0, 16, 64, 512, 1024],
+ unroll_explicit=True,
+ ),
+ ]
+
+ @staticmethod
+ def _postproc() -> List[Postproc]:
+ from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel
+ postproc as M,
+ )
+
+ return [
+ M.DisallowDynamicLoop(),
+ M.RewriteCooperativeFetch(),
+ M.RewriteUnboundBlock(),
+ M.RewriteParallelVectorizeUnroll(),
+ M.RewriteReductionBlock(),
+ M.VerifyGPUCode(),
+ ]
+
+ @staticmethod
+ def _mutator_probs() -> Dict[Mutator, float]:
+ from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel
+ mutator as M,
+ )
+
+ return {
+ # M.MutateTileSize(): 0.9,
+ M.MutateUnroll(): 0.1,
+ }
+
+
+class Parse:
+ """Parse tuning configuration from user inputs."""
+
+ @staticmethod
+ @register_func("tvm.meta_schedule.tune.parse_mod") # for use in ApplyHistoryBest
+ def _mod(mod: Union[PrimFunc, IRModule]) -> IRModule:
+ if isinstance(mod, PrimFunc):
+ mod = mod.with_attr("global_symbol", "main")
+ mod = mod.with_attr("tir.noalias", True)
+ mod = IRModule({"main": mod})
+ if not isinstance(mod, IRModule):
+ raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}")
+ # in order to make sure the mod can be found in ApplyHistoryBest
+ # different func name can cause structural unequal
+ func_names = mod.get_global_vars()
+ (func_name,) = func_names
+ if len(func_names) == 1 and func_name != "main":
+ mod = IRModule({"main": mod[func_name]})
+ return mod
+
+ @staticmethod
+ def _target(target: Union[str, Target]) -> Target:
+ if isinstance(target, str):
+ target = Target(target)
+ if not isinstance(target, Target):
+ raise TypeError(f"Expected `target` to be str or Target, but gets: {target}")
+ return target
+
+ @staticmethod
+ def _builder(builder: Optional[Builder]) -> Builder:
+ if builder is None:
+ builder = LocalBuilder()
+ if not isinstance(builder, Builder):
+ raise TypeError(f"Expected `builder` to be Builder, but gets: {builder}")
+ return builder
+
+ @staticmethod
+ def _runner(runner: Optional[Runner]) -> Runner:
+ if runner is None:
+ runner = LocalRunner()
+ if not isinstance(runner, Runner):
+ raise TypeError(f"Expected `runner` to be Runner, but gets: {runner}")
+ return runner
+
+ @staticmethod
+ def _database(database: Union[None, Database], task_name: str, path: str) -> Database:
+ if database is None:
+ path_workload = os.path.join(path, f"{task_name}_database_workload.json")
+ path_tuning_record = os.path.join(path, f"{task_name}_database_tuning_record.json")
+ logger.info(
+ "Creating JSONDatabase. Workload at: %s. Tuning records at: %s",
+ path_workload,
+ path_tuning_record,
+ )
+ database = JSONDatabase(
+ path_workload=path_workload,
+ path_tuning_record=path_tuning_record,
+ )
+ if not isinstance(database, Database):
+ raise TypeError(f"Expected `database` to be Database, but gets: {database}")
+ return database
+
+ @staticmethod
+ def _callbacks(
+ measure_callbacks: Optional[List[MeasureCallback]],
+ ) -> List[MeasureCallback]:
+ if measure_callbacks is None:
+ from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel
+ measure_callback as M,
+ )
+
+ return [
+ M.AddToDatabase(),
+ M.RemoveBuildArtifact(),
+ M.EchoStatistics(),
+ M.UpdateCostModel(),
+ ]
+ if not isinstance(measure_callbacks, (list, tuple)):
+ raise TypeError(
+ f"Expected `measure_callbacks` to be List[MeasureCallback], "
+ f"but gets: {measure_callbacks}"
+ )
+ measure_callbacks = list(measure_callbacks)
+ for i, callback in enumerate(measure_callbacks):
+ if not isinstance(callback, MeasureCallback):
+ raise TypeError(
+ f"Expected `measure_callbacks` to be List[MeasureCallback], "
+ f"but measure_callbacks[{i}] is: {callback}"
+ )
+ return measure_callbacks
+
+ @staticmethod
+ def _cost_model(cost_model: Optional[CostModel]) -> CostModel:
+ if cost_model is None:
+ return XGBModel(extractor=PerStoreFeature())
+ if not isinstance(cost_model, CostModel):
+ raise TypeError(f"Expected `cost_model` to be CostModel, but gets: {cost_model}")
+ return cost_model
+
+ @staticmethod
+ def _space_generator(space_generator: Optional[FnSpaceGenerator]) -> SpaceGenerator:
+ if space_generator is None:
+ return PostOrderApply()
+ if callable(space_generator):
+ space_generator = space_generator()
+ if not isinstance(space_generator, SpaceGenerator):
+ raise TypeError(
+ f"Expected `space_generator` to return SpaceGenerator, "
+ f"but gets: {space_generator}"
+ )
+ return space_generator
+
+ @staticmethod
+ def _sch_rules(sch_rules: Optional[FnScheduleRule], target: Target) -> List[ScheduleRule]:
+ if callable(sch_rules):
+ return sch_rules()
+ if sch_rules is not None:
+ raise TypeError(f"Expected `sch_rules` to be None or callable, but gets: {sch_rules}")
+ # pylint: disable=protected-access
+ if target.kind.name == "llvm":
+ return DefaultLLVM._sch_rules()
+ if target.kind.name == "cuda":
+ return DefaultCUDA._sch_rules()
+ # pylint: enable=protected-access
+ raise ValueError(f"Unsupported target: {target}")
+
+ @staticmethod
+ def _postproc(postproc: Optional[FnPostproc], target: Target) -> List[Postproc]:
+ if callable(postproc):
+ return postproc()
+ if postproc is not None:
+ raise TypeError(f"Expected `postproc` to be None or callable, but gets: {postproc}")
+ # pylint: disable=protected-access
+ if target.kind.name == "llvm":
+ return DefaultLLVM._postproc()
+ if target.kind.name == "cuda":
+ return DefaultCUDA._postproc()
+ # pylint: enable=protected-access
+ raise ValueError(f"Unsupported target: {target}")
+
+ @staticmethod
+ def _mutator_probs(
+ mutator_probs: Optional[FnMutatorProb],
+ target: Target,
+ ) -> Dict[Mutator, float]:
+ if callable(mutator_probs):
+ return mutator_probs()
+ if mutator_probs is not None:
+ raise TypeError(
+ f"Expected `mutator_probs` to be None or callable, but gets: {mutator_probs}"
+ )
+ # pylint: disable=protected-access
+ if target.kind.name == "llvm":
+ return DefaultLLVM._mutator_probs()
+ if target.kind.name == "cuda":
+ return DefaultCUDA._mutator_probs()
+ # pylint: enable=protected-access
+ raise ValueError(f"Unsupported target: {target}")
+
+ @staticmethod
+ def _tune_context(
+ tune_context: Optional[TuneContext],
+ mod: IRModule,
+ target: Target,
+ config: SearchStrategyConfig,
+ task_name: str,
+ space_generator: Optional[FnSpaceGenerator],
+ sch_rules: Optional[FnScheduleRule],
+ postprocs: Optional[FnPostproc],
+ mutator_probs: Optional[FnMutatorProb],
+ num_threads: Optional[int],
+ ) -> TuneContext:
+ if tune_context is None:
+ return TuneContext(
+ mod=mod,
+ target=target,
+ # pylint: disable=protected-access
+ space_generator=Parse._space_generator(space_generator),
+ search_strategy=config.create_strategy(),
+ sch_rules=Parse._sch_rules(sch_rules, target),
+ postprocs=Parse._postproc(postprocs, target),
+ mutator_probs=Parse._mutator_probs(mutator_probs, target),
+ # pylint: enable=protected-access
+ task_name=task_name,
+ rand_state=-1,
+ num_threads=num_threads,
+ )
+ if not isinstance(tune_context, TuneContext):
+ raise TypeError(f"Expected `tune_context` to be TuneContext, but gets: {tune_context}")
+ return tune_context
+
+ @staticmethod
+ def _task_scheduler(
+ task_scheduler: Union[None, TaskScheduler, FnTaskScheduler],
+ tasks: List[TuneContext],
+ builder: Builder,
+ runner: Runner,
+ database: Database,
+ cost_model: CostModel,
+ measure_callbacks: List[MeasureCallback],
+ ):
+ if task_scheduler is None:
+ return RoundRobin(
+ tasks=tasks,
+ builder=builder,
+ runner=runner,
+ database=database,
+ cost_model=cost_model,
+ measure_callbacks=measure_callbacks,
+ )
+ if callable(task_scheduler):
+ return task_scheduler(
+ tasks,
+ builder,
+ runner,
+ database,
+ cost_model,
+ measure_callbacks,
+ )
+ if not isinstance(task_scheduler, TaskScheduler):
+ raise TypeError(
+ f"Expected `task_scheduler` to be TaskScheduler, but gets: {task_scheduler}"
+ )
+ return task_scheduler
+
+
+def tune_tir(
+ mod: Union[IRModule, PrimFunc],
+ target: Union[str, Target],
+ config: SearchStrategyConfig,
+ work_dir: str,
+ *,
+ task_name: str = "main",
+ builder: Optional[Builder] = None,
+ runner: Optional[Runner] = None,
+ database: Optional[Database] = None,
+ cost_model: Optional[CostModel] = None,
+ measure_callbacks: Optional[List[MeasureCallback]] = None,
+ task_scheduler: Optional[TaskScheduler] = None,
+ space: Optional[FnSpaceGenerator] = None,
+ sch_rules: Optional[FnScheduleRule] = None,
+ postprocs: Optional[FnPostproc] = None,
+ mutator_probs: Optional[FnMutatorProb] = None,
+ num_threads: Optional[int] = None,
+) -> Optional[Schedule]:
+ """Tune a TIR IRModule with a given target.
+
+ Parameters
+ ----------
+ mod : Union[IRModule, PrimFunc]
+ The module to tune.
+ target : Union[str, Target]
+ The target to tune for.
+ config : SearchStrategyConfig
+ The search strategy config.
+ task_name : str
+ The name of the task.
+ work_dir : Optional[str]
+ The working directory to save intermediate results.
+ builder : Optional[Builder]
+ The builder to use.
+ runner : Optional[Runner]
+ The runner to use.
+ database : Optional[Database]
+ The database to use.
+ cost_model : Optional[CostModel]
+ The cost model to use.
+ measure_callbacks : Optional[List[MeasureCallback]]
+ The callbacks used during tuning.
+ f_tune_context : Optional[TYPE_F_TUNE_CONTEXT]
+ The function to create TuneContext.
+ f_task_scheduler : Optional[TYPE_F_TASK_SCHEDULER]
+ The function to create TaskScheduler.
+
+ Returns
+ -------
+ sch : Optional[Schedule]
+ The tuned schedule.
+ """
+
+ logger.info("Working directory: %s", work_dir)
+ # pylint: disable=protected-access
+ mod = Parse._mod(mod)
+ database = Parse._database(database, task_name, work_dir)
+ tune_context = Parse._tune_context(
+ tune_context=None,
+ mod=mod,
+ target=Parse._target(target),
+ config=config,
+ task_name=task_name,
+ space_generator=space,
+ sch_rules=sch_rules,
+ postprocs=postprocs,
+ mutator_probs=mutator_probs,
+ num_threads=num_threads,
+ )
+ task_scheduler = Parse._task_scheduler(
+ task_scheduler,
+ [tune_context],
+ builder=Parse._builder(builder),
+ runner=Parse._runner(runner),
+ database=database,
+ cost_model=Parse._cost_model(cost_model),
+ measure_callbacks=Parse._callbacks(measure_callbacks),
+ )
+ # pylint: enable=protected-access
+ task_scheduler.tune()
+ bests: List[TuningRecord] = database.get_top_k(
+ database.commit_workload(mod),
+ top_k=1,
+ )
+ if not bests:
+ return None
+ assert len(bests) == 1
+ sch = Schedule(mod)
+ bests[0].trace.apply_to_schedule(sch, remove_postproc=False)
+ task_scheduler.cost_model.save(os.path.join(work_dir, f"{task_name}.xgb"))
+ return sch
+
+
+def tune_te(
+ tensors: List[Tensor],
+ target: Union[str, Target],
+ config: SearchStrategyConfig,
+ work_dir: str,
+ *,
+ task_name: str = "main",
+ builder: Optional[Builder] = None,
+ runner: Optional[Runner] = None,
+ database: Optional[Database] = None,
+ cost_model: Optional[CostModel] = None,
+ measure_callbacks: Optional[List[MeasureCallback]] = None,
+ task_scheduler: Optional[TaskScheduler] = None,
+ space: Optional[FnSpaceGenerator] = None,
+ sch_rules: Optional[FnScheduleRule] = None,
+ postprocs: Optional[FnPostproc] = None,
+ mutator_probs: Optional[FnMutatorProb] = None,
+ num_threads: Optional[int] = None,
+) -> Optional[Schedule]:
+ """Tune a TE compute DAG with a given target.
+
+ Parameters
+ ----------
+ tensor : List[Tensor]
+ The list of input/output tensors of the TE compute DAG.
+ target : Union[str, Target]
+ The target to tune for.
+ config : SearchStrategyConfig
+ The search strategy config.
+ task_name : str
+ The name of the task.
+ work_dir : Optional[str]
+ The working directory to save intermediate results.
+ builder : Optional[Builder]
+ The builder to use.
+ runner : Optional[Runner]
+ The runner to use.
+ database : Optional[Database]
+ The database to use.
+ measure_callbacks : Optional[List[MeasureCallback]]
+ The callbacks used during tuning.
+ f_tune_context : Optional[TYPE_F_TUNE_CONTEXT]
+ The function to create TuneContext.
+ f_task_scheduler : Optional[TYPE_F_TASK_SCHEDULER]
+ The function to create TaskScheduler.
+
+ Returns
+ -------
+ sch : Optional[Schedule]
+ The tuned schedule.
+ """
+ return tune_tir(
+ mod=create_prim_func(tensors),
+ target=target,
+ config=config,
+ work_dir=work_dir,
+ task_name=task_name,
+ builder=builder,
+ runner=runner,
+ database=database,
+ cost_model=cost_model,
+ measure_callbacks=measure_callbacks,
+ task_scheduler=task_scheduler,
+ space=space,
+ sch_rules=sch_rules,
+ postprocs=postprocs,
+ mutator_probs=mutator_probs,
+ num_threads=num_threads,
+ )
+
+
+def tune_relay(
+ mod: Union[RelayFunc, IRModule],
+ target: Union[str, Target],
+ config: SearchStrategyConfig,
+ work_dir: str,
+ *,
+ params: Optional[Dict[str, NDArray]] = None,
+ task_name: str = "main",
+ builder: Optional[Builder] = None,
+ runner: Optional[Runner] = None,
+ database: Optional[Database] = None,
+ cost_model: Optional[CostModel] = None,
+ measure_callbacks: Optional[List[MeasureCallback]] = None,
+ task_scheduler: Optional[TaskScheduler] = None,
+ space: Optional[FnSpaceGenerator] = None,
+ sch_rules: Optional[FnScheduleRule] = None,
+ postprocs: Optional[FnPostproc] = None,
+ mutator_probs: Optional[FnMutatorProb] = None,
+ num_threads: Optional[int] = None,
+) -> Module:
+ """Tune a TIR IRModule with a given target.
+
+ Parameters
+ ----------
+ mod : Union[RelayFunc, IRModule]
+ The module to tune.
+ target : Union[str, Target]
+ The target to tune for.
+ config : SearchStrategyConfig
+ The search strategy config.
+ params : Optional[Dict[str, tvm.runtime.NDArray]]
+ The associated parameters of the program
+ task_name : str
+ The name of the task.
+ work_dir : Optional[str]
+ The working directory to save intermediate results.
+ builder : Optional[Builder]
+ The builder to use.
+ runner : Optional[Runner]
+ The runner to use.
+ database : Optional[Database]
+ The database to use.
+ measure_callbacks : Optional[List[MeasureCallback]]
+ The callbacks used during tuning.
+ f_tune_context : Optional[TYPE_F_TUNE_CONTEXT]
+ The function to create TuneContext.
+ f_task_scheduler : Optional[TYPE_F_TASK_SCHEDULER]
+ The function to create TaskScheduler.
+
+ Returns
+ -------
+ lib : Module
+ The built runtime module for the given relay workload.
+ """
+
+ logger.info("Working directory: %s", work_dir)
+ extracted_tasks = extract_task_from_relay(mod, target, params)
+ # pylint: disable=protected-access
+ tune_contexts = []
+ target = Parse._target(target)
+ database = Parse._database(database, task_name, work_dir)
+ # parse the tuning contexts
+ for task in extracted_tasks:
+ assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now"
+ tune_contexts.append(
+ Parse._tune_context(
+ tune_context=None,
+ mod=Parse._mod(task.dispatched[0]),
+ target=target,
+ config=config,
+ task_name=task.task_name,
+ space_generator=space,
+ sch_rules=sch_rules,
+ postprocs=postprocs,
+ mutator_probs=mutator_probs,
+ num_threads=num_threads,
+ )
+ )
+ # deduplication
+ logger.info("Before task deduplication: %d tasks", len(tune_contexts))
+ tasks: List[TuneContext] = []
+ hashs: List[int] = []
+ for i, task in enumerate(tune_contexts):
+ struct_hash: int = structural_hash(task.mod)
+ flag: bool = False
+ if struct_hash in hashs:
+ for other_task in tune_contexts[i + 1 :]:
+ if structural_equal(task.mod, other_task.mod):
+ flag = True
+ break
+ if not flag:
+ tasks.append(task)
+ hashs.append(struct_hash)
+ logger.info("After task deduplication: %d tasks", len(tasks))
+
+ # parse the task scheduler
+ task_scheduler = Parse._task_scheduler(
+ task_scheduler,
+ tasks,
+ builder=Parse._builder(builder),
+ runner=Parse._runner(runner),
+ database=database,
+ cost_model=Parse._cost_model(cost_model),
+ measure_callbacks=Parse._callbacks(measure_callbacks),
+ )
+ # pylint: enable=protected-access
+ task_scheduler.tune()
+ with ApplyHistoryBest(database):
+ with tvm.transform.PassContext(
+ opt_level=3,
+ config={"relay.backend.use_meta_schedule": True},
+ ):
+ return relay.build(mod, target=target, params=params)
diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py
index ceb5f72..b6fe348 100644
--- a/python/tvm/meta_schedule/utils.py
+++ b/python/tvm/meta_schedule/utils.py
@@ -33,9 +33,37 @@ from tvm.tir import FloatImm, IntImm
@register_func("meta_schedule.cpu_count")
def _cpu_count_impl(logical: bool = True) -> int:
+ """Return the number of logical or physical CPUs in the system
+ Parameters
+ ----------
+ logical : bool = True
+ If True, return the number of logical CPUs, otherwise return the number of physical CPUs
+ Returns
+ -------
+ cpu_count : int
+ The number of logical or physical CPUs in the system
+ Note
+ ----
+ The meta schedule search infra intentionally does not adopt the following convention in TVM:
+ - C++ API `tvm::runtime::threading::MaxConcurrency()`
+ - Environment variable `TVM_NUM_THREADS` or
+ - Environment variable `OMP_NUM_THREADS`
+ This is because these variables are dedicated to controlling
+ the runtime behavior of generated kernels, instead of the host-side search.
+ Setting these variables may interfere the host-side search with profiling of generated kernels
+ when measuring locally.
+ """
return psutil.cpu_count(logical=logical) or 1
+@register_func("meta_schedule._process_error_message")
+def _process_error_message(error_msg: str) -> str:
+ error_msg_lines = str(error_msg).splitlines()
+ if len(error_msg_lines) >= 50:
+ return "\n".join(error_msg_lines[:25] + ["..."] + error_msg_lines[-25:])
+ return error_msg
+
+
def cpu_count(logical: bool = True) -> int:
"""Return the number of logical or physical CPUs in the system
diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc
index e9d3012..1ae2e02 100644
--- a/src/meta_schedule/integration.cc
+++ b/src/meta_schedule/integration.cc
@@ -120,7 +120,9 @@ Optional<ObjectRef> ApplyHistoryBestNode::Query(runtime::String task_name, IRMod
IRModule prim_mod = dispatched.value()[0];
ICHECK(HasOnlyOneFunction<tir::PrimFunc>(prim_mod)) << prim_mod;
// Unify func name to make sure it can be found in database
- prim_mod = UnifyFuncName(prim_mod);
+ const auto* parse_mod_func = runtime::Registry::Get("tvm.meta_schedule.tune.parse_mod");
+ ICHECK(parse_mod_func) << "Parse mod function not defined!";
+ prim_mod = (*parse_mod_func)(prim_mod);
if (database->HasWorkload(prim_mod)) {
Array<TuningRecord> records = database->GetTopK(database->CommitWorkload(prim_mod), 1);
if (records.size() == 1) {
diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc
index 1f3943d..28f95b2 100644
--- a/src/meta_schedule/task_scheduler/task_scheduler.cc
+++ b/src/meta_schedule/task_scheduler/task_scheduler.cc
@@ -124,7 +124,7 @@ void TaskSchedulerNode::Tune() {
int running_tasks = tasks.size();
for (int task_id; (task_id = NextTaskId()) != -1;) {
- LOG(INFO) << "Scheduler picks Task #" << task_id << ": " << tasks[task_id]->task_name;
+ LOG(INFO) << "Scheduler picks Task #" << task_id + 1 << ": " << tasks[task_id]->task_name;
TuneContext task = tasks[task_id];
ICHECK(!task->is_stopped);
ICHECK(!task->runner_futures.defined());
@@ -138,7 +138,7 @@ void TaskSchedulerNode::Tune() {
} else {
SetTaskStopped(task_id);
--running_tasks;
- LOG(INFO) << "Task #" << task_id << " has finished. Remaining task(s): " << running_tasks;
+ LOG(INFO) << "Task #" << task_id + 1 << " has finished. Remaining task(s): " << running_tasks;
}
}
ICHECK_EQ(running_tasks, 0) << "Not all tasks are finished";
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index afeb159..bd76ca7 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -351,22 +351,6 @@ inline int GetTargetNumCores(const Target& target) {
return num_cores;
}
-/*!
- * \brief Unify the function name in workload to "main".
- * \param mod The workload.
- * \return The new workload with unified function name.
- * \note If the name is not unified, the workload may not be found in database.
- */
-inline IRModule UnifyFuncName(const IRModule& mod) {
- if (!mod->ContainGlobalVar("main") && mod->GetGlobalTypeVars().size() == 1) {
- IRModule new_mod = IRModule(
- Map<GlobalVar, BaseFunc>({{GlobalVar("main"), mod->functions[mod->GetGlobalVars()[0]]}}));
- return new_mod;
- } else {
- return mod;
- }
-}
-
} // namespace meta_schedule
} // namespace tvm
diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc
index 55869e1..bff4293 100644
--- a/src/tir/schedule/primitive/for_kind.cc
+++ b/src/tir/schedule/primitive/for_kind.cc
@@ -83,7 +83,8 @@ void CheckLoopParallelizableInBlock(const ScheduleState& self, ForKind for_kind,
const Block& block = block_realize->block;
// Cond 1. The block is required to have affine bindings.
- CheckAffineBinding(self, block);
+ // TODO(@automation): fix the check
+ // CheckAffineBinding(self, block);
// Cond 2. For each block iter whose binding contains `loop_var`, only two cases are allowed.
ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size());
diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py
index bc1d5f2..76ca52e 100644
--- a/tests/python/unittest/test_meta_schedule_integration.py
+++ b/tests/python/unittest/test_meta_schedule_integration.py
@@ -116,7 +116,7 @@ def test_meta_schedule_integration_extract_from_resnet():
layout="NHWC",
dtype="float32",
)
- extracted_tasks = ms.integration.extract_task(mod, target="llvm", params=params)
+ extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params)
assert len(extracted_tasks) == 30
diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py
new file mode 100644
index 0000000..7e6f89d
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -0,0 +1,151 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import logging
+import tempfile
+import pytest
+import numpy as np
+from typing import Tuple, List
+
+import tvm
+from tvm import relay
+from tvm.ir import IRModule
+from tvm.runtime.ndarray import cpu, cuda
+from tvm.target.target import Target
+from tvm.contrib import graph_executor
+from tvm.meta_schedule import ReplayTraceConfig
+from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord
+from tvm.meta_schedule.testing import MODEL_TYPE, MODEL_TYPES, get_torch_model
+from tvm.meta_schedule.tune import tune_relay
+
+logging.basicConfig()
+logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
+
+
+class DummyDatabase(PyDatabase):
+ def __init__(self):
+ super().__init__()
+ self.records = []
+ self.workload_reg = []
+
+ def has_workload(self, mod: IRModule) -> Workload:
+ for workload in self.workload_reg:
+ if tvm.ir.structural_equal(workload.mod, mod):
+ return True
+ return False
+
+ def commit_tuning_record(self, record: TuningRecord) -> None:
+ self.records.append(record)
+
+ def commit_workload(self, mod: IRModule) -> Workload:
+ for workload in self.workload_reg:
+ if tvm.ir.structural_equal(workload.mod, mod):
+ return workload
+ workload = Workload(mod)
+ self.workload_reg.append(workload)
+ return workload
+
+ def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
+ return list(
+ filter(
+ lambda x: x.workload == workload,
+ sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
+ )
+ )[: int(top_k)]
+
+ def __len__(self) -> int:
+ return len(self.records)
+
+ def print_results(self) -> None:
+ print("\n".join([str(r) for r in self.records]))
+
+
+@pytest.mark.skip("Integration test")
+@pytest.mark.parametrize("model_name", ["resnet18", "mobilenet_v2", "bert_base"])
+@pytest.mark.parametrize("batch_size", [1])
+@pytest.mark.parametrize("target", ["llvm --num-cores=16", "nvidia/geforce-rtx-3070"])
+def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str):
+ if model_name == "inception_v3" and batch_size == 1:
+ pytest.skip("inception_v3 does not handle batch_size of 1")
+
+ input_shape: Tuple[int, ...]
+ input_name = "input0"
+ dev = tvm.cpu() if str(target).startswith("llvm") else cuda()
+ if MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
+ seq_length = 128
+ input_name = "input_ids"
+ input_shape = (batch_size, seq_length)
+ data = tvm.nd.array(np.random.randint(0, 30521, size=input_shape), dev) # embedding size
+ else:
+ if MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION:
+ input_shape = (batch_size, 3, 299, 299)
+ elif MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION:
+ input_shape = (batch_size, 3, 299, 299)
+ elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION:
+ input_shape = (1, 3, 300, 300)
+ elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION:
+ input_shape = (batch_size, 3, 3, 299, 299)
+ else:
+ raise ValueError("Unsupported model: " + model_name)
+ data = tvm.nd.array(np.random.randn(*input_shape).astype("float32"), dev)
+
+ output_shape: Tuple[int, int] = (batch_size, 1000)
+
+ mod, params = get_torch_model(
+ model_name=model_name,
+ input_shape=input_shape,
+ output_shape=output_shape,
+ dtype="float32",
+ )
+
+ with tempfile.TemporaryDirectory() as work_dir:
+ target = Target(target)
+ database = DummyDatabase()
+ rt_mod: tvm.module = tune_relay(
+ mod=mod,
+ params=params,
+ target=target,
+ config=ReplayTraceConfig(
+ num_trials_per_iter=32,
+ num_trials_total=32,
+ ),
+ work_dir=work_dir,
+ database=database,
+ )
+ # Compile without meta-scheduler for correctness check
+ with tvm.transform.PassContext(opt_level=0):
+ rt_mod2 = relay.build(mod, target=target, params=params)
+
+ def get_output(data, lib):
+ module = graph_executor.GraphModule(lib["default"](dev))
+ module.set_input(input_name, data)
+ module.run()
+ return module.get_output(0).numpy()
+
+ # Check correctness
+ actual_output = get_output(data, rt_mod)
+ expected_output = get_output(data, rt_mod2)
+ assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)
+
+
+if __name__ == """__main__""":
+ test_meta_schedule_tune_relay("resnet18", 1, "llvm --num-cores=16")
+ test_meta_schedule_tune_relay("resnet18", 1, "nvidia/geforce-rtx-3070")
+ test_meta_schedule_tune_relay("mobilenet_v2", 1, "llvm --num-cores=16")
+ test_meta_schedule_tune_relay("mobilenet_v2", 1, "nvidia/geforce-rtx-3070")
+ test_meta_schedule_tune_relay("bert_base", 1, "llvm --num-cores=16")
+ test_meta_schedule_tune_relay("bert_base", 1, "nvidia/geforce-rtx-3070")
diff --git a/tests/python/unittest/test_meta_schedule_tune_te.py b/tests/python/unittest/test_meta_schedule_tune_te.py
new file mode 100644
index 0000000..a07bf17
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_tune_te.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import logging
+import tempfile
+
+import pytest
+from tvm.meta_schedule import ReplayTraceConfig, tune_te
+from tvm.meta_schedule.testing import te_workload
+from tvm.target.target import Target
+from tvm.tir import Schedule
+
+
+logging.basicConfig()
+logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
+
+
+@pytest.mark.skip("Integration test")
+def test_tune_matmul():
+ with tempfile.TemporaryDirectory() as work_dir:
+ sch: Schedule = tune_te(
+ tensors=te_workload.batch_matmul_nkkm(B=1, N=128, M=128, K=128),
+ target=Target("llvm --num-cores=16"),
+ config=ReplayTraceConfig(
+ num_trials_per_iter=32,
+ num_trials_total=32,
+ ),
+ work_dir=work_dir,
+ )
+ if sch is None:
+ print("No valid schedule found!")
+ else:
+ print(sch.mod.script())
+ print(sch.trace)
+
+
+if __name__ == """__main__""":
+ test_tune_matmul()
diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py
new file mode 100644
index 0000000..277fa24
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_tune_tir.py
@@ -0,0 +1,218 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import logging
+import tempfile
+
+import tvm
+import pytest
+from tvm.meta_schedule import ReplayTraceConfig, tune_tir
+from tvm.meta_schedule.tune_context import TuneContext
+from tvm.meta_schedule import schedule_rule, postproc
+from tvm.meta_schedule.space_generator import PostOrderApply
+from tvm.script import tir as T
+from tvm.target.target import Target
+from tvm.te.operation import create_prim_func
+from tvm.tir import Schedule
+from tvm.meta_schedule.testing import te_workload
+
+logging.basicConfig()
+logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
+
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+
+@T.prim_func
+def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, [128, 128])
+ B = T.match_buffer(b, [128, 128])
+ C = T.match_buffer(c, [128, 128])
+ for i, j, k in T.grid(128, 128, 128):
+ with T.block("update"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ C[vi, vj] = 0.0
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+# pylint: enable=no-member,invalid-name,unused-variable
+
+
+@pytest.mark.skip("Integration test")
+def test_tune_matmul_cpu():
+ with tempfile.TemporaryDirectory() as work_dir:
+ sch: Schedule = tune_tir(
+ mod=matmul,
+ target=Target("llvm --num-cores=16"),
+ config=ReplayTraceConfig(
+ num_trials_per_iter=32,
+ num_trials_total=32,
+ ),
+ work_dir=work_dir,
+ )
+ if sch is None:
+ print("No valid schedule found!")
+ else:
+ print(sch.mod.script())
+ print(sch.trace)
+
+
+@pytest.mark.skip("Integration test")
+def test_tune_matmul_cuda():
+ with tempfile.TemporaryDirectory() as work_dir:
+ sch: Schedule = tune_tir(
+ mod=matmul,
+ target=Target("nvidia/geforce-rtx-3070"),
+ config=ReplayTraceConfig(
+ num_trials_per_iter=32,
+ num_trials_total=32,
+ ),
+ work_dir=work_dir,
+ )
+ if sch is None:
+ print("No valid schedule found!")
+ else:
+ print(sch.mod.script())
+ print(sch.trace)
+
+
+@pytest.mark.skip("Integeration test")
+def test_tune_matmul_cuda_tensor_core():
+ n = 512
+ mod = create_prim_func(te_workload.matmul_fp16(n, n, n))
+ target = Target("nvidia/geforce-rtx-3070")
+ config = ReplayTraceConfig(
+ num_trials_per_iter=32,
+ num_trials_total=320,
+ )
+
+ class DefaultTensorCore:
+ @staticmethod
+ def _sch_rules():
+ from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel
+ schedule_rule as M,
+ )
+
+ return [
+ M.AutoInline(
+ into_producer=False,
+ into_consumer=True,
+ # into_cache_only=False,
+ inline_const_tensor=True,
+ disallow_if_then_else=False,
+ require_injective=False,
+ require_ordered=False,
+ disallow_op=None,
+ ),
+ M.MultiLevelTiling(
+ structure="SSSRRSRS",
+ tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"],
+ # use_tensor_core=True,
+ max_innermost_factor=64,
+ vector_load_lens=[1, 2, 3, 4],
+ reuse_read=schedule_rule.ReuseType(
+ req="must",
+ levels=[4],
+ scope="shared",
+ ),
+ reuse_write=schedule_rule.ReuseType(
+ req="no",
+ levels=[],
+ scope="",
+ ),
+ ),
+ M.AutoInline(
+ into_producer=True,
+ into_consumer=True,
+ # into_cache_only=True,
+ inline_const_tensor=True,
+ disallow_if_then_else=False,
+ require_injective=False,
+ require_ordered=False,
+ disallow_op=None,
+ ),
+ M.ParallelizeVectorizeUnroll(
+ max_jobs_per_core=-1, # disable parallelize
+ max_vectorize_extent=-1, # disable vectorize
+ unroll_max_steps=[0, 16, 64, 512, 1024],
+ unroll_explicit=True,
+ ),
+ ]
+
+ @staticmethod
+ def _postproc():
+ from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel
+ postproc as M,
+ )
+
+ return [
+ # M.RewriteCooperativeFetch(),
+ M.RewriteParallelVectorizeUnroll(),
+ M.RewriteReductionBlock(),
+ # M.RewriteTensorCore(),
+ M.VerifyGPUCode(),
+ ]
+
+ with tempfile.TemporaryDirectory() as work_dir:
+ sch: Schedule = tune_tir(
+ mod=mod,
+ target=target,
+ config=config,
+ work_dir=work_dir,
+ space=PostOrderApply(),
+ sch_rules=DefaultTensorCore._sch_rules,
+ postprocs=DefaultTensorCore._postproc,
+ num_threads=None,
+ )
+ if sch is None:
+ print("No valid schedule found!")
+ else:
+ print(sch.mod.script())
+ print(sch.trace)
+
+ from tvm.contrib import nvcc
+ import numpy as np
+
+ ctx = tvm.gpu(0)
+ if nvcc.have_tensorcore(ctx.compute_version):
+ with tvm.transform.PassContext():
+ func = tvm.build(sch.mod["main"], [], "cuda")
+ print(sch.mod.script())
+ print(func.imported_modules[0].get_source())
+ a_np = np.random.uniform(size=(n, n)).astype("float16")
+ b_np = np.random.uniform(size=(n, n)).astype("float16")
+ a = tvm.nd.array(a_np, ctx)
+ b = tvm.nd.array(b_np, ctx)
+ c = tvm.nd.array(np.zeros((n, n), dtype="float32"), ctx)
+ evaluator = func.time_evaluator(
+ func.entry_name, ctx, number=3, repeat=1, min_repeat_ms=40
+ )
+ print("matmul with tensor core: %f ms" % (evaluator(a, b, c).mean * 1e3))
+
+ np.testing.assert_allclose(
+ c.asnumpy(),
+ np.matmul(a_np.astype("float32"), b_np.astype("float32")),
+ rtol=1e-4,
+ atol=1e-4,
+ )
+
+
+if __name__ == """__main__""":
+ test_tune_matmul_cpu()
+ test_tune_matmul_cuda()
+ test_tune_matmul_cuda_tensor_core()