You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/04/20 01:09:51 UTC

[GitHub] [tvm] comaniac commented on a change in pull request #7823: [TVMC] A simplified TVMC API for python scripting.

comaniac commented on a change in pull request #7823:
URL: https://github.com/apache/tvm/pull/7823#discussion_r616268878



##########
File path: python/tvm/driver/tvmc/autotuner.py
##########
@@ -255,97 +388,113 @@ def drive_tune(args):
     # min_repeat_ms should be:
     # a. the value provided by the user, if any, or
     # b. 0ms in case target is "cpu"; otherwise 1000ms
-    if args.min_repeat_ms is not None:
-        min_repeat_ms = args.min_repeat_ms
-    else:
+    if min_repeat_ms is None:
         min_repeat_ms = 0 if target.keys[0] == "cpu" else 1000
         logger.debug("Default --min-repeat-ms for this target is %s", min_repeat_ms)
 
-    if args.rpc_tracker:
-        runner_ctor = auto_scheduler.RPCRunner if args.enable_autoscheduler else autotvm.RPCRunner
+    if rpc_key:
+        if hostname is None or port is None:
+            raise common.TVMCException(
+                "You must provide a hostname and port to connect to a remote RPC device."
+            )
+        if isinstance(port, str):
+            port = int(port)
+
+        runner_ctor = auto_scheduler.RPCRunner if enable_autoscheduler else autotvm.RPCRunner
         runner = runner_ctor(
-            key=args.rpc_key,
-            host=rpc_hostname,
-            port=rpc_port,
-            number=args.number,
-            repeat=args.repeat,
-            n_parallel=args.parallel,
-            timeout=args.timeout,
+            key=rpc_key,
+            host=hostname,
+            port=port,
+            number=number,
+            repeat=repeat,
+            n_parallel=parallel,
+            timeout=timeout,
             min_repeat_ms=min_repeat_ms,
         )
     else:
         logger.info("starting localhost tuning")
         runner_ctor = (
-            auto_scheduler.LocalRunner if args.enable_autoscheduler else autotvm.LocalRunner
+            auto_scheduler.LocalRPCMeasureContext if enable_autoscheduler else autotvm.LocalRunner
         )
-        runner = runner_ctor(
-            number=args.number,
-            repeat=args.repeat,
-            timeout=args.timeout,
+        local_server = runner_ctor(
+            number=number,
+            repeat=repeat,
+            timeout=timeout,
             min_repeat_ms=min_repeat_ms,
         )
 
-    if args.enable_autoscheduler:
-        # Specify hardware parameters
-        hardware_params = auto_scheduler.HardwareParams(
-            args.num_cores,
-            args.vector_unit_bytes,
-            args.cache_line_bytes,
-            args.max_shared_memory_per_block,
-            args.max_local_memory_per_block,
-            args.max_threads_per_block,
-            args.max_vthread_extent,
-            args.warp_size,
-        )
+        # For autoscheduling on some devices, we need to maintain a LocalRPCMeasureContext object.
+        if enable_autoscheduler:
+            runner = local_server.runner
+        else:
+            runner = local_server
+
+    if enable_autoscheduler:
+
         tasks, weights = autoscheduler_get_tuning_tasks(
             mod=mod,
             params=params,
             target=target,
-            alter_layout=args.desired_layout,
+            alter_layout=desired_layout,
             hardware_params=hardware_params,
-            include_simple_tasks=args.include_simple_tasks,
+            include_simple_tasks=include_simple_tasks,
         )
 
+        # If not specified, choose a number of trials likely to produce good results.
+        if trials is None:
+            trials = 10000
+
         # Create the autoscheduler tuning options
         tuning_options = auto_scheduler.TuningOptions(
-            num_measure_trials=args.trials,
-            measure_callbacks=[auto_scheduler.RecordToFile(args.output)],
+            num_measure_trials=trials,
+            measure_callbacks=[auto_scheduler.RecordToFile(tuning_records)],
             runner=runner,
-            early_stopping=args.early_stopping,
+            early_stopping=early_stopping,
         )
 
         # Schedule the tasks (i.e., produce a schedule for each task)
-        schedule_tasks(
-            tasks, weights, tuning_options, args.tuning_records, args.log_estimated_latency
-        )
+        schedule_tasks(tasks, weights, tuning_options, prior_records, log_estimated_latency)
     else:
         tasks = autotvm_get_tuning_tasks(
             mod=mod,
             params=params,
             target=target,
-            alter_layout=args.desired_layout,
+            alter_layout=desired_layout,
         )
 
+        # If trails isn't specified, default to a number likely to produce good
+        # results without taking too much time.
+        if trials is None:
+            trials = 1000

Review comment:
       Please note that the semantics of `trials` in AutoTVM and AutoScheduler are different. We should use just one in TVMC. For example, if you decide to let `trials` represent the total number of trials for a model (i.e., the AutoScheduler semantics), then here you could do `task_trials = trials // len(tasks)`, and vice versa. This way, we can also have one unified default value for this argument.

##########
File path: python/tvm/driver/tvmc/runner.py
##########
@@ -337,135 +353,73 @@ def run_module(
     times : list of str
         execution times generated by the time evaluator
     """
-
-    with tempfile.TemporaryDirectory() as tmp_dir:
-        logger.debug("extracting module file %s", module_file)
-        t = tarfile.open(module_file)
-        t.extractall(tmp_dir)
-        graph = open(os.path.join(tmp_dir, "mod.json")).read()
-        params = bytearray(open(os.path.join(tmp_dir, "mod.params"), "rb").read())
-
-        if hostname:
-            # Remote RPC
-            if rpc_key:
-                logger.debug("running on remote RPC tracker with key %s", rpc_key)
-                session = request_remote(rpc_key, hostname, port, timeout=1000)
-            else:
-                logger.debug("running on remote RPC with no key")
-                session = rpc.connect(hostname, port)
+    if not isinstance(tvmc_package, TVMCPackage):
+        raise TVMCException(
+            "This model doesn't seem to have been compiled yet. "
+            "Try calling tvmc.compile on the model before running it."
+        )
+
+    if hostname:
+        if isinstance(port, str):
+            port = int(port)
+        # Remote RPC
+        if rpc_key:
+            logger.debug("running on remote RPC tracker with key %s", rpc_key)
+            session = request_remote(rpc_key, hostname, port, timeout=1000)
         else:
-            # Local
-            logger.debug("running a local session")
-            session = rpc.LocalSession()
-
-        session.upload(os.path.join(tmp_dir, "mod.so"))
-        lib = session.load_module("mod.so")
-
-        # TODO expand to other supported devices, as listed in tvm.rpc.client (@leandron)
-        logger.debug("device is %s", device)
-        if device == "gpu":
-            dev = session.gpu()
-        elif device == "cl":
-            dev = session.cl()
-        else:
-            assert device == "cpu"
-            dev = session.cpu()
-
-        if profile:
-            logger.debug("creating runtime with profiling enabled")
-            module = debug_executor.create(graph, lib, dev, dump_root="./prof")
-        else:
-            logger.debug("creating runtime with profiling disabled")
-            module = runtime.create(graph, lib, dev)
-
-        logger.debug("load params into the runtime module")
-        module.load_params(params)
-
-        shape_dict, dtype_dict = get_input_info(graph, params)
-        inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode)
-
-        logger.debug("setting inputs to the module")
-        module.set_input(**inputs_dict)
-
-        # Run must be called explicitly if profiling
-        if profile:
-            logger.debug("running the module with profiling enabled")
-            module.run()
-
-        # create the module time evaluator (returns a function)
-        timer = module.module.time_evaluator("run", dev, 1, repeat=repeat)
-        # call the evaluator function to invoke the module and save execution times
-        prof_result = timer()
-        # collect a list of execution times from the profiling results
-        times = prof_result.results
-
-        logger.debug("collecting the output tensors")
-        num_outputs = module.get_num_outputs()
-        outputs = {}
-        for i in range(num_outputs):
-            output_name = "output_{}".format(i)
-            outputs[output_name] = module.get_output(i).asnumpy()
-
-        return outputs, times
-
-
-def get_top_results(outputs, max_results):
-    """Return the top n results from the output tensor.
-
-    This function is primarily for image classification and will
-    not necessarily generalise.
-
-    Parameters
-    ----------
-    outputs : dict
-        Outputs dictionary - {output_name: np.array}.
-    max_results : int
-        Number of results to return
-
-    Returns
-    -------
-    top_results : np.array
-        Results array of shape (2, n).
-        The first row is the indices and the second is the values.
-
-    """
-    output = np.copy(outputs["output_0"])
-    sorted_labels = output.argsort()[0][-max_results:][::-1]
-    output.sort()
-    sorted_values = output[0][-max_results:][::-1]
-    top_results = np.array([sorted_labels, sorted_values])
-    return top_results
-
-
-def format_times(times):
-    """Format the mean, max, min and std of the execution times.
-
-    This has the effect of producing a small table that looks like:
-
-        Execution time summary:
-        mean (ms)   max (ms)    min (ms)    std (ms)
-        0.14310    0.16161    0.12933    0.01004
-
-    Parameters
-    ----------
-    times : list
-        A list of execution times (in seconds).
-
-    Returns
-    -------
-    str
-        A formatted string containing the statistics.
-    """
-
-    # timestamps
-    mean_ts = np.mean(times) * 1000
-    std_ts = np.std(times) * 1000
-    max_ts = np.max(times) * 1000
-    min_ts = np.min(times) * 1000
-
-    header = "Execution time summary:\n{0:^10} {1:^10} {2:^10} {3:^10}".format(
-        "mean (ms)", "max (ms)", "min (ms)", "std (ms)"
-    )
-    stats = "{0:^10.2f} {1:^10.2f} {2:^10.2f} {3:^10.2f}".format(mean_ts, max_ts, min_ts, std_ts)
+            logger.debug("running on remote RPC with no key")
+            session = rpc.connect(hostname, port)
+    else:
+        # Local
+        logger.debug("running a local session")
+        session = rpc.LocalSession()
+
+    session.upload(tvmc_package.lib_path)
+    lib = session.load_module(tvmc_package.lib_name)
+
+    # TODO expand to other supported devices, as listed in tvm.rpc.client (@leandron)
+    logger.debug("device is %s", device)
+    if device == "gpu":
+        dev = session.gpu()
+    elif device == "cl":
+        dev = session.cl()
+    else:
+        assert device == "cpu"
+        dev = session.cpu()
 
-    return "%s\n%s\n" % (header, stats)
+    if profile:
+        logger.debug("creating runtime with profiling enabled")
+        module = debug_executor.create(tvmc_package.graph, lib, dev, dump_root="./prof")
+    else:
+        logger.debug("creating runtime with profiling disabled")
+        module = runtime.create(tvmc_package.graph, lib, dev)
+
+    logger.debug("load params into the runtime module")
+    module.load_params(tvmc_package.params)
+
+    shape_dict, dtype_dict = get_input_info(tvmc_package.graph, tvmc_package.params)
+    inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode)
+
+    logger.debug("setting inputs to the module")
+    module.set_input(**inputs_dict)
+
+    # Run must be called explicitly if profiling
+    if profile:
+        logger.debug("running the module with profiling enabled")

Review comment:
       Use `info`?

##########
File path: python/tvm/driver/tvmc/autotuner.py
##########
@@ -228,24 +240,137 @@ def drive_tune(args):
     args: argparse.Namespace
         Arguments from command line parser.
     """
-    # extra arguments validation before importing the model, so that obvious errors
-    # are pointed in advance.
-    if args.rpc_tracker:
-        parsed_url = urlparse("//%s" % args.rpc_tracker)
+    tvmc_model = frontends.load_model(args.FILE, args.model_format, shape_dict=args.input_shapes)
+    tvmc_model.tuning_records = args.tuning_records
+    # Specify hardware parameters, although they'll only be used if autoscheduling.
+    hardware_params = auto_scheduler.HardwareParams(
+        args.num_cores,
+        args.vector_unit_bytes,
+        args.cache_line_bytes,
+        args.max_shared_memory_per_block,
+        args.max_local_memory_per_block,
+        args.max_threads_per_block,
+        args.max_vthread_extent,
+        args.warp_size,
+        args.target,
+        args.target_host,
+    )
+
+    tune_model(
+        tvmc_model,
+        args.target,
+        args.output,
+        args.enable_autoscheduler,
+        args.rpc_key,
+        args.rpc_tracker,
+        args.trials,
+        args.target_host,
+        args.tuner,
+        args.min_repeat_ms,
+        args.early_stopping,
+        args.desired_layout,
+        args.timeout,
+        args.number,
+        args.repeat,
+        args.parallel,
+        hardware_params,
+        args.include_simple_tasks,
+        args.log_estimated_latency,
+    )
+
+
+def tune_model(
+    tvmc_model: TVMCModel,
+    target: str,
+    tuning_records: Optional[str] = None,
+    enable_autoscheduler: bool = False,
+    rpc_key: Optional[str] = None,
+    rpc_tracker: Optional[str] = None,
+    trials: Optional[int] = None,
+    target_host: str = "llvm",
+    tuner: str = "xgb",
+    min_repeat_ms: Optional[int] = None,
+    early_stopping: Optional[int] = None,
+    desired_layout: Optional[str] = None,
+    timeout: int = 10,
+    number: int = 10,
+    repeat: int = 1,
+    parallel: int = 4,
+    hardware_params: Optional[HardwareParams] = None,
+    include_simple_tasks: bool = False,
+    log_estimated_latency: bool = False,
+):
+    """Use tuning to automatically optimize the functions in a model.
+
+    Parameters
+    ----------
+    tvmc_model : TVMCModel
+        The model to be optimized.
+    target : str
+        Compilation target as plain string, inline JSON or path to a JSON file.
+    tuning_records: str, optional
+        The path to a file that tuning results will be saved to. If not specified,
+        a temporary file will be used.

Review comment:
       Given that the tuning records are the most important result of running `.tune`, it may even be worthwhile to require users to specify this argument, IMHO.
   
   Meanwhile, it might be better to just eliminate `prior_records`. When `tuning_records` already contains prior records, we could use them directly to hot-start the tuning. Having these two arguments is a bit confusing, especially since both AutoTVM and auto-scheduler support resuming tuning and append new records to the file without overwriting existing records.

##########
File path: python/tvm/driver/tvmc/autotuner.py
##########
@@ -255,97 +380,100 @@ def drive_tune(args):
     # min_repeat_ms should be:
     # a. the value provided by the user, if any, or
     # b. 0ms in case target is "cpu"; otherwise 1000ms
-    if args.min_repeat_ms is not None:
-        min_repeat_ms = args.min_repeat_ms
-    else:
+    if min_repeat_ms is None:
         min_repeat_ms = 0 if target.keys[0] == "cpu" else 1000
         logger.debug("Default --min-repeat-ms for this target is %s", min_repeat_ms)
 
-    if args.rpc_tracker:
-        runner_ctor = auto_scheduler.RPCRunner if args.enable_autoscheduler else autotvm.RPCRunner
+    if rpc_tracker:
+        runner_ctor = auto_scheduler.RPCRunner if enable_autoscheduler else autotvm.RPCRunner
         runner = runner_ctor(
-            key=args.rpc_key,
+            key=rpc_key,
             host=rpc_hostname,
             port=rpc_port,
-            number=args.number,
-            repeat=args.repeat,
-            n_parallel=args.parallel,
-            timeout=args.timeout,
+            number=number,
+            repeat=repeat,
+            n_parallel=parallel,
+            timeout=timeout,
             min_repeat_ms=min_repeat_ms,
         )
     else:
         logger.info("starting localhost tuning")
         runner_ctor = (
-            auto_scheduler.LocalRunner if args.enable_autoscheduler else autotvm.LocalRunner
+            auto_scheduler.LocalRPCMeasureContext if enable_autoscheduler else autotvm.LocalRunner
         )
         runner = runner_ctor(
-            number=args.number,
-            repeat=args.repeat,
-            timeout=args.timeout,
+            number=number,
+            repeat=repeat,
+            timeout=timeout,
             min_repeat_ms=min_repeat_ms,
         )
 
-    if args.enable_autoscheduler:
-        # Specify hardware parameters
-        hardware_params = auto_scheduler.HardwareParams(
-            args.num_cores,
-            args.vector_unit_bytes,
-            args.cache_line_bytes,
-            args.max_shared_memory_per_block,
-            args.max_local_memory_per_block,
-            args.max_threads_per_block,
-            args.max_vthread_extent,
-            args.warp_size,
-        )
+    if enable_autoscheduler:
+
         tasks, weights = autoscheduler_get_tuning_tasks(
             mod=mod,
             params=params,
             target=target,
-            alter_layout=args.desired_layout,
+            alter_layout=desired_layout,
             hardware_params=hardware_params,
-            include_simple_tasks=args.include_simple_tasks,
+            include_simple_tasks=include_simple_tasks,
         )
 
+        # If not specified, choose a number of trials likely to produce good results.
+        if trials is None:
+            trials = 10000

Review comment:
       I agree that we should not hide a constant here. Since `trials` is an argument of this function and could even be propagated from the CLI, we could just set the default value in both the function signature and the CLI argument.
   
   In general, any value we set in this function should have a corresponding message to let users know which value they are using. For example, the logging level for the min_repeat_ms value should be INFO instead of DEBUG IMHO. A more formal approach would be to display a table of all configuration values being used before tuning, so that users can double-check whether they missed anything.

##########
File path: python/tvm/driver/tvmc/runner.py
##########
@@ -325,8 +337,12 @@ def run_module(
         The fill-mode to use when generating data for input tensors.
         Valid options are "zeros", "ones" and "random".
         Defaults to "random".
+    number : int, optional

Review comment:
       ```suggestion
       repeat : int, optional
   ```

##########
File path: python/tvm/driver/tvmc/model.py
##########
@@ -0,0 +1,372 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+This file contains the definition of a set of classes that wrap the outputs
+of TVMC functions to create a simpler and more intuitive API.
+
+There is one class for each required stage of a TVM workflow.
+The TVMCModel represents the result of importing a model into TVM, it
+contains the precompiled graph definition and parameters that define
+what the model does.
+
+Compiling a TVMCModel produces a TVMCPackage, which contains the generated
+artifacts that allow the model to be run on the target hardware.
+
+Running a TVMCPackage produces a TVMCResult, which contains the outputs of
+the model and the measured runtime.
+
+Examples
+--------
+The following code shows a full lifecycle for a model using tvmc, first the
+model is imported from an exterior framework, in this case onnx, then it
+is tuned to find the best schedules on CPU, then compiled into a TVMCPackage,
+and finally run.
+
+.. code-block:: python
+    tvmc_model = tvmc.load("my_model.onnx")
+    tuning_records = tvmc.tune(tvmc_model, target="llvm")
+    tvmc_package = tvmc.compile(tvmc_model, target="llvm", tuning_records=tuning_records)
+    result = tvmc.run(tvmc_package, device="cpu")
+    print(result)
+"""
+import os
+import tarfile
+from typing import Optional, Union, List, Dict, Callable, TextIO
+import numpy as np
+
+import tvm
+import tvm.contrib.cc
+from tvm import relay
+from tvm.contrib import utils
+from tvm.relay.backend.graph_executor_factory import GraphExecutorFactoryModule
+
+from .common import TVMCException
+
+
+class TVMCModel(object):
+    """Initialize a TVMC model from a relay model definition or a saved file.
+
+    Parameters
+    ----------
+    mod : tvm.IRModule, optional
+        The relay module corresponding to this model.
+    params : dict, optional
+        A parameter dictionary for the model.
+    model_path: str, optional
+        An alternative way to load a TVMCModel, the path to a previously
+        saved model.
+    """
+
+    def __init__(
+        self,
+        mod: Optional[tvm.IRModule] = None,
+        params: Optional[Dict[str, tvm.nd.NDArray]] = None,
+        model_path: Optional[str] = None,
+    ):
+        if (mod is None or params is None) and (model_path is None):
+            raise TVMCException(
+                "Either mod and params must be provided "
+                "or a path to a previously saved TVMCModel"
+            )
+        self._tmp_dir = utils.tempdir()
+        if model_path is not None:
+            self.load(model_path)
+        else:
+            self.mod = mod
+            self.params = params if params else {}
+
+    def save(self, model_path: str):
+        """Save the TVMCModel to disk.
+
+        Note that this saves the graph representation,
+        the parameters, and the tuning records if applicable. It will not save any
+        compiled artifacts.
+
+        Parameters
+        ----------
+        model_path : str
+            A full path to save this TVMCModel to including the output file name.
+            The file will be saved as a tar file so using a ".tar" extension is advised.
+        """
+        temp = self._tmp_dir
+
+        # Save relay graph
+        relay_name = "model.json"
+        relay_path = temp.relpath(relay_name)
+        with open(relay_path, "w") as relay_file:
+            relay_file.write(tvm.ir.save_json(self.mod))
+
+        # Save params
+        params_name = "model.params"
+        params_path = temp.relpath(params_name)
+        with open(params_path, "wb") as params_file:
+            params_file.write(relay.save_param_dict(self.params))
+
+        # Create a tar file.
+        with tarfile.open(model_path, "w") as tar:
+            tar.add(relay_path, relay_name)
+            tar.add(params_path, params_name)
+
+    def load(self, model_path: str):
+        """Load a TVMCModel from disk.
+
+        Parameters
+        ----------
+        model_path : str
+            A path to load the TVMCModel from.
+        """
+        temp = self._tmp_dir
+        t = tarfile.open(model_path)
+        t.extractall(temp.relpath("."))
+
+        # Load relay IR.
+        relay_path = temp.relpath("model.json")
+        with open(relay_path, "r") as relay_file:
+            self.mod = tvm.ir.load_json(relay_file.read())
+
+        # Load parameter dictionary.
+        params_path = temp.relpath("model.params")
+        with open(params_path, "rb") as params_file:
+            self.params = relay.load_param_dict(params_file.read())
+
+    def get_temp_path(self, file_name: str):
+        """Get the full path for a filename stored in this model's temp directory.
+
+        Parameters
+        ----------
+        file_name : str
+            The name of the file within this model's temp directory.
+
+        Returns
+        -------
+        temp_path : str
+            A path to a file in this model's temporary directory.
+        """
+        return self._tmp_dir.relpath(file_name)
+
+    def export_package(
+        self,
+        executor_factory: GraphExecutorFactoryModule,
+        package_path: Optional[str] = None,
+        cross: Optional[Union[str, Callable]] = None,
+        lib_format: str = "so",
+    ):
+        """Save this TVMCModel to file.
+        Parameters
+        ----------
+        executor_factory : GraphExecutorFactoryModule
+            The factory containing compiled the compiled artifacts needed to run this model.
+        package_path : str, None
+            Where the model should be saved. Note that it will be packaged as a .tar file.
+            If not provided, the package will be saved to a generically named file in tmp.
+        cross : str or callable object, optional
+            Function that performs the actual compilation.
+        lib_format : str
+            How to export the modules function library. Must be one of "so" or "tar".
+
+        Returns
+        -------
+        package_path : str
+            The path that the package was saved to.
+        """
+        if lib_format not in ["so", "tar"]:
+            raise TVMCException("Only .so and .tar export formats are supported.")
+        lib_name = "mod." + lib_format
+        graph_name = "mod.json"
+        param_name = "mod.params"
+
+        temp = self._tmp_dir
+        if package_path is None:
+            package_path = temp.relpath("model_package.tar")
+        path_lib = temp.relpath(lib_name)
+
+        if not cross:
+            executor_factory.get_lib().export_library(path_lib)
+        else:
+            executor_factory.get_lib().export_library(
+                path_lib, tvm.contrib.cc.cross_compiler(cross)
+            )
+        self.lib_path = path_lib
+
+        with open(temp.relpath(graph_name), "w") as graph_file:
+            graph_file.write(executor_factory.get_json())
+
+        with open(temp.relpath(param_name), "wb") as params_file:
+            params_file.write(relay.save_param_dict(executor_factory.get_params()))
+
+        # Package up all the temp files into a tar file.
+        with tarfile.open(package_path, "w") as tar:
+            tar.add(path_lib, lib_name)
+            tar.add(temp.relpath(graph_name), graph_name)
+            tar.add(temp.relpath(param_name), param_name)
+
+        return package_path
+
+    def summary(self, file: TextIO = None):

Review comment:
       Why is the IR called "summary"?

##########
File path: python/tvm/auto_scheduler/search_task.py
##########
@@ -43,40 +43,74 @@
 
 @tvm._ffi.register_object("auto_scheduler.HardwareParams")
 class HardwareParams(Object):
-    """The parameters of target hardware used to guide the search policy
+    """The parameters of target hardware used to guide the search policy.
+
+    When a parameter isn't provided, it will instead use the
+    current machine's default value if target is specified.
     TODO(jcf94): This is considered to be merged with the new Target specification:
     https://discuss.tvm.apache.org/t/rfc-tvm-target-specification/6844
     Parameters
     ----------
-    num_cores : int
+    num_cores : int, optional
         The number of device cores.
-    vector_unit_bytes : int
+    vector_unit_bytes : int, optional
         The width of vector units in bytes.
-    cache_line_bytes : int
+    cache_line_bytes : int, optional
         The size of cache line in bytes.
-    max_shared_memory_per_block : int
+    max_shared_memory_per_block : int, optional
         The max shared memory per block in bytes.
-    max_local_memory_per_block : int
+    max_local_memory_per_block : int, optional
         The max local memory per block in bytes.
-    max_threads_per_block : int
+    max_threads_per_block : int, optional
         The max number of threads per block.
-    max_vthread_extent : int
+    max_vthread_extent : int, optional
         The max vthread extent.
-    warp_size : int
+    warp_size : int, optional
         The thread numbers of a warp.
+    target : str or Target, optional
+        The compilation target. Used to determine default values if provided.
+    target_host : str or Target, optional
+        The compilation target host. Used to determine default values if provided.
     """
 
     def __init__(
         self,
-        num_cores,
-        vector_unit_bytes,
-        cache_line_bytes,
-        max_shared_memory_per_block,
-        max_local_memory_per_block,
-        max_threads_per_block,
-        max_vthread_extent,
-        warp_size,
+        num_cores=None,
+        vector_unit_bytes=None,
+        cache_line_bytes=None,
+        max_shared_memory_per_block=None,
+        max_local_memory_per_block=None,
+        max_threads_per_block=None,
+        max_vthread_extent=None,
+        warp_size=None,
+        target=None,
+        target_host=None,
     ):
+        # If target is provided, get the default paramters for this machine.
+        if target is not None:

Review comment:
       Does this work when all arguments (including target) are None?

##########
File path: python/tvm/driver/tvmc/compiler.py
##########
@@ -182,25 +188,21 @@ def compile_model(
 
     Returns
     -------
-    graph : str
-        A JSON-serialized TVM execution graph.
-    lib : tvm.module.Module
-        A TVM module containing the compiled functions.
-    params : dict
-        The parameters (weights) for the TVM module.
-    dumps : dict
-        Dictionary containing the dumps specified.
+    compiled_model : TVMCPackage
+        The compiled TVMCModel ready to be run.
 
     """
-    dump_code = [x.strip() for x in dump_code.split(",")] if dump_code else None
+    mod, params = tvmc_model.mod, tvmc_model.params
+
     config = {}
 
-    if alter_layout:
-        mod = common.convert_graph_layout(mod, alter_layout)
+    if desired_layout:
+        mod = common.convert_graph_layout(mod, desired_layout)
 
     tvm_target, extra_targets = common.target_from_cli(target)
     target_host = tvm_target if not target_host else target_host
-    tvm_target, target_host = Target.check_and_update_host_consist(tvm_target, target_host)
+    if target_host is not None:

Review comment:
       It seems this can never be None?

##########
File path: tests/python/driver/tvmc/test_autoscheduler.py
##########
@@ -26,28 +24,30 @@
 
 
 def _get_tasks(model):
-    mod, params = tvmc.frontends.load_model(model)
-    tasks, weights = tvmc.autotuner.autoscheduler_get_tuning_tasks(mod, params, "llvm")
+    tvmc_model = tvmc.frontends.load_model(model)
+    tasks, weights = tvmc.autotuner.autoscheduler_get_tuning_tasks(
+        tvmc_model.mod, tvmc_model.params, "llvm"
+    )
     return (tasks, weights)
 
 
-def _autoscheduler_test_helper(
-    model, tmpdir_name, tasks_weights=None, early_stopping=1, tuning_records=None
-):
-    tasks, weights = tasks_weights if tasks_weights else _get_tasks(model)
+def _autoscheduler_test_helper(model, tmpdir_name, early_stopping=1, prior_records=None):
+    tvmc_model = tvmc.frontends.load_model(model)
     log_file = os.path.join(tmpdir_name, "autoscheduler.json")
 
-    tuning_options = auto_scheduler.TuningOptions(
-        num_measure_trials=1,
-        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
-        runner="local",
-        builder="local",
-        verbose=0,
+    hardware_params = auto_scheduler.HardwareParams(num_cores=4, target="llvm")
+
+    tvmc.tune(
+        tvmc_model,
+        target="llvm",
+        tuning_records=log_file,
+        prior_records=prior_records,
         early_stopping=early_stopping,
+        enable_autoscheduler=True,
+        trials=32,

Review comment:
       Can we avoid actually running 32 tuning trials on CI?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org