You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/29 17:35:28 UTC
[tvm] branch main updated: [MetaSchedule] Add Script for TorchBench Model Tuning & Benchmarking (#12914)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2379917985 [MetaSchedule] Add Script for TorchBench Model Tuning & Benchmarking (#12914)
2379917985 is described below
commit 2379917985919ed3918dc12cad47f469f245be7a
Author: Lite Ye <ye...@gmail.com>
AuthorDate: Thu Sep 29 13:35:21 2022 -0400
[MetaSchedule] Add Script for TorchBench Model Tuning & Benchmarking (#12914)
This PR adds a script to tune and benchmark TorchBench models, using torchdynamo and the pytorch importer in TVM.
---
.../meta_schedule/testing/torchbench/__init__.py | 16 +
python/tvm/meta_schedule/testing/torchbench/run.py | 609 +++++++++++++++++++++
.../tvm/meta_schedule/testing/torchbench/utils.py | 103 ++++
3 files changed, 728 insertions(+)
diff --git a/python/tvm/meta_schedule/testing/torchbench/__init__.py b/python/tvm/meta_schedule/testing/torchbench/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/torchbench/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/python/tvm/meta_schedule/testing/torchbench/run.py b/python/tvm/meta_schedule/testing/torchbench/run.py
new file mode 100644
index 0000000000..f6984d1c9d
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/torchbench/run.py
@@ -0,0 +1,609 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+This script is for benchmarking TVM performance on models from TorchBench.
+It uses TorchDynamo as the frontend to ingest models into TVM, and it also
+leverages the benchmark util from TorchDynamo.
+
+TorchDynamo (https://github.com/pytorch/torchdynamo) and TorchBench
+(https://github.com/pytorch/benchmark) need to be in the parent directory of TVM.
+We need a local clone of these repos because torchbench and the benchmark runner
+in TorchDynamo aren't designed to be used as a Python package.
+
+To setup the environment, run the following commands in the parent directory of TVM and with
+the appropriate Python environment:
+```bash
+# torchdynamo requires nightly pytorch. If it fails to find the specified version, try
+# installing the latest nightly pytorch.
+pip3 install --pre \
+ --extra-index-url https://download.pytorch.org/whl/nightly/cu116 \
+ torch==1.13.0.dev20220926 \
+ torchvision==0.14.0.dev20220926 \
+ torchtext==0.14.0.dev20220926
+
+git clone https://github.com/pytorch/torchdynamo
+pushd torchdynamo
+git checkout c537639f9712621dc04ca09908796dbbe86c354b
+pip install -e .
+popd
+
+sudo apt install git-lfs # git lfs is used for TorchBench
+git clone https://github.com/pytorch/benchmark
+pushd benchmark
+python install.py --continue_on_fail # fambench_xlmr might fail to install
+popd
+```
+
+To run a benchmark, the script can be run under 'tune' mode by
+```bash
+python python/tvm/meta_schedule/testing/torchbench/run.py \
+ --mode tune \
+ --model resnet50 \
+ --target "nvidia/geforce-rtx-3070" \
+ --work-dir ../workdir \
+ --num-trials 20000 \
+ --rpc-host <rpc tracker host for tuning> \
+ --rpc-port <rpc tracker port for tuning> \
+ --rpc-key <rpc key> \
+```
+
+All available target tags (like nvidia/geforce-rtx-3070) can be found at
+https://github.com/apache/tvm/blob/main/src/target/tag.cc
+
+Then the script can be run under 'eval' mode to actually benchmark the performance,
+using the tuning database under the work directory. This can be executed on a different
+machine than the one that executes tuning (the database json files need to be inside
+of the work directory).
+```bash
+python python/tvm/meta_schedule/testing/torchbench/run.py \
+ --mode eval \
+ --model resnet50 \
+ --target "nvidia/geforce-rtx-3070" \
+ --work-dir ../workdir \
+ --num-trials 0
+```
+
+Alternatively, both tuning and evaluation can be done in a single run on the same machine,
+by
+```bash
+python python/tvm/meta_schedule/testing/torchbench/run.py \
+ --mode all \
+ --model resnet50 \
+ --target "llvm -num-cores 6" \
+ --work-dir ../workdir \
+ --num-trials 0
+```
+"""
+
+# pylint: disable=logging-format-interpolation
+
+import argparse
+import functools
+import logging
+import warnings
+from enum import Enum
+from typing import Callable, List, Tuple
+
+import numpy as np # type: ignore
+import torch # type: ignore
+from scipy.stats import ttest_ind # type: ignore
+
+import tvm
+import tvm.relay
+from tvm import meta_schedule as ms
+from tvm.contrib.graph_executor import GraphModule
+from tvm.meta_schedule.testing.torchbench.utils import (
+ load_torchdynamo_benchmark_runner,
+ same,
+ timed,
+)
+from tvm.runtime.vm import VirtualMachine
+from tvm.support import describe
+
+# Needs to be imported after the .utils is executed
+import torchdynamo # type: ignore # isort: skip, pylint: disable=wrong-import-order
+
+
+class RunMode(Enum):
+ """
+ The running mode of this script. Available values are:
+ - tune: Only tune the model and create the tuning database.
+ - eval: Only benchmark model using pre-existing tuning database.
+ - all: Run both tuning and benchmark
+ """
+
+ ALL = "all"
+ TUNE = "tune"
+ EVAL = "eval"
+
+ @property
+ def should_tune(self):
+ """
+ Returns whether it should tune the model.
+ """
+ return self != RunMode.EVAL
+
+ @property
+ def should_eval(self):
+ """
+ Returns whether it should actually benchmark the model.
+ """
+ return self != RunMode.TUNE
+
+
+class ResultComparisonMetric(Enum):
+ """
+    This changes how it compares the result with the expected value during
+ accuracy check.
+ - cosine: Use the cosine similarity. It should be greater than 0.99.
+ - allclose-1e-4: Use the max element-wise absolute difference. It should be less than 1e-4.
+ """
+
+ COSINE = "cosine"
+ ALLCLOSE = "allclose-1e-4"
+
+
+def parse_args():
+ """
+ Parse arguments
+ """
+ args = argparse.ArgumentParser()
+
+ args.add_argument(
+ "--mode",
+ type=RunMode,
+ default=RunMode.ALL,
+ help=RunMode.__doc__,
+ )
+ args.add_argument(
+ "--batch-size",
+ type=int,
+ default=None,
+ help="The batch size of model input. Use TorchBench's default value if not specified.",
+ )
+ args.add_argument(
+ "--result-metric",
+ type=ResultComparisonMetric,
+ default=ResultComparisonMetric.ALLCLOSE,
+ help=ResultComparisonMetric.__doc__,
+ )
+ args.add_argument(
+ "--benchmark-repeat",
+ type=int,
+ default=10,
+ help="The number of times to repeat the benchmark measurement.",
+ )
+ args.add_argument(
+ "--benchmark-warmup-rounds",
+ type=int,
+ default=5,
+ help="The number of rounds to warmup before starting to measure the performance.",
+ )
+
+ # Model selection
+ args.add_argument(
+ "--model",
+ type=str,
+ required=True,
+ help="""
+        The name of model to run. It should be a directory name under
+ https://github.com/pytorch/benchmark/tree/main/torchbenchmark/models.
+ """,
+ )
+
+ # Tuning-related config
+ args.add_argument(
+ "--target",
+ type=tvm.target.Target,
+ required=True,
+ help="The target to tune and run benchmark for.",
+ )
+ args.add_argument(
+ "--work-dir",
+ type=str,
+ required=True,
+ help="""
+ The working directory to save intermediate results and store databases for compilation.
+ """,
+ )
+ args.add_argument(
+ "--cache-dir",
+ type=str,
+ default=None,
+ help="""
+ The directory to cache the generated network.
+ If not specified, the cache will be disabled.
+ """,
+ )
+ args.add_argument(
+ "--num-trials",
+ type=int,
+ required=True,
+ help="The max number of trials to run MetaSchedule.",
+ )
+ args.add_argument(
+ "--max-trials-per-task",
+ type=int,
+ default=None,
+ help="""
+ The max number of trials to run per task extracted in MetaSchedule.
+ By default it's the same as --num-trials.
+ """,
+ )
+ args.add_argument(
+ "--backend",
+ type=str,
+ choices=["graph", "vm"],
+ default="graph",
+ help="The backend to use for relay compilation(graph / vm).",
+ )
+ # TODO(@yelite): Add a layout arg to transform the network after
+ # ingesting into Relay and before feeding into MetaSchedule.
+
+ # Evaluator-related config
+ args.add_argument(
+ "--number",
+ type=int,
+ default=3,
+ help="The number of times to run the model for taking average in a single measurement.",
+ )
+ args.add_argument(
+ "--repeat",
+ type=int,
+ default=1,
+ help="The number of times to repeat the measurement.",
+ )
+ args.add_argument(
+ "--min-repeat-ms",
+ type=int,
+ default=100,
+ help="""
+ Minimum repeat time in ms. The number of runs will be increased if the actual
+        repeat time is lower than this.
+ """,
+ )
+ args.add_argument(
+ "--adaptive-training",
+ action="store_true",
+        help="Whether to use adaptive training for cost model.",
+ )
+ args.add_argument(
+ "--cpu-flush",
+ action="store_true",
+ help="Whether to perform CPU cache flush.",
+ )
+
+ # RPC-related args
+ args.add_argument(
+ "--rpc-host",
+ type=str,
+ help="Host of the RPC Tracker for tuning. Use LocalRunner if not provided",
+ )
+ args.add_argument(
+ "--rpc-port",
+ type=int,
+ help="Port of the RPC Tracker for tuning",
+ )
+ args.add_argument(
+ "--rpc-key",
+ type=str,
+ help="Key of the RPC Tracker for tuning",
+ )
+
+ parsed = args.parse_args()
+ return parsed
+
+
+logging.basicConfig(
+ format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
+)
+logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
+ARGS = parse_args()
+IS_CUDA = ARGS.target.kind.name == "cuda"
+
+logger = logging.getLogger(__name__) # pylint: disable=invalid-name
+logger.setLevel(logging.INFO)
+
+
+runner = load_torchdynamo_benchmark_runner( # pylint: disable=invalid-name
+ IS_CUDA, cosine_similarity=ARGS.result_metric == ResultComparisonMetric.COSINE
+)
+
+
+def get_metaschedule_runner() -> ms.runner.PyRunner:
+ """
+ Get the Runner for MetaSchedule.
+
+ It returns RPCRunner if --rpc-host is given, otherwise it returns LocalRunner
+ """
+ if ARGS.rpc_host is not None:
+ assert ARGS.rpc_port is not None, "Missing rpc_port"
+ assert ARGS.rpc_key is not None, "Missing rpc_key"
+ return ms.runner.RPCRunner(
+ rpc_config=ms.runner.RPCConfig(
+ tracker_host=ARGS.rpc_host,
+ tracker_port=ARGS.rpc_port,
+ tracker_key=ARGS.rpc_key,
+ session_timeout_sec=600,
+ ),
+ evaluator_config=ms.runner.EvaluatorConfig(
+ number=ARGS.number,
+ repeat=ARGS.repeat,
+ min_repeat_ms=ARGS.min_repeat_ms,
+ enable_cpu_cache_flush=ARGS.cpu_flush,
+ ),
+ alloc_repeat=1,
+ )
+ else:
+ warnings.warn("Falling back to Metaschedule LocalRunner because --rpc-host isn't provided.")
+ return ms.runner.LocalRunner()
+
+
+def get_tune_config() -> ms.TuneConfig:
+ """
+ Get the TuneConfig.
+ """
+ if ARGS.mode.should_tune:
+ max_trials_per_task = ARGS.max_trials_per_task
+ max_trials_global = ARGS.num_trials
+ else:
+ max_trials_per_task = 0
+ max_trials_global = 0
+
+ if max_trials_per_task is None:
+ max_trials_per_task = max_trials_global
+
+ return ms.TuneConfig(
+ strategy="evolutionary",
+ num_trials_per_iter=64,
+ max_trials_per_task=max_trials_per_task,
+ max_trials_global=max_trials_global,
+ adaptive_training=ARGS.adaptive_training,
+ )
+
+
+def get_graph_executor_forward(mod: GraphModule, device: tvm.runtime.Device) -> Callable:
+ """
+ Get the forward function for graph executor, in order to integrate with TorchDynamo.
+ """
+
+ def forward(*args):
+ if IS_CUDA:
+ torch.cuda.synchronize()
+ args = tuple(arg.contiguous() for arg in args)
+ for idx, arg in enumerate(args, 0):
+ mod.set_input(
+ f"inp_{idx}",
+ tvm.nd.from_dlpack(arg),
+ )
+ mod.run()
+ device.sync()
+ result = [torch.from_dlpack(mod.get_output(i)) for i in range(mod.get_num_outputs())]
+ return result
+
+ return forward
+
+
+def get_vm_forward(virtual_machine: VirtualMachine, device: tvm.runtime.Device) -> Callable:
+ """
+ Get the forward function for VM, in order to integrate with TorchDynamo.
+ """
+
+ def forward(*args):
+ if IS_CUDA:
+ torch.cuda.synchronize()
+ args = tuple(tvm.nd.from_dlpack(arg.contiguous()) for arg in args)
+ result = virtual_machine.invoke("main", *args)
+ device.sync()
+
+ if isinstance(result, tvm.nd.NDArray):
+ result = [result]
+ return [torch.from_dlpack(m) for m in result]
+
+ return forward
+
+
+def create_tvm_task_collection_backend(tasks: List[ms.ExtractedTask]) -> Callable:
+ """
+ This torchdynamo backend only collects the extracted tasks from Metaschedule.
+ It doesn't tune the model.
+ """
+
+ def backend(graph_module, example_inputs):
+ jit_mod = torch.jit.trace(graph_module, example_inputs)
+ shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
+ ir_mod, params = tvm.relay.frontend.from_pytorch(jit_mod, shape_list)
+
+ extracted_tasks = ms.extract_task_from_relay(ir_mod, ARGS.target, params)
+ logger.info("Extracted %d tasks", len(extracted_tasks))
+ tasks.extend(extracted_tasks)
+
+ return graph_module.forward
+
+ return backend
+
+
+def create_tvm_compilation_backend(database: ms.database.Database) -> Callable:
+ """
+ This torchdynamo backend compiles the model using history best record from the
+ Metaschedule database.
+ """
+
+ def backend(graph_module, example_inputs):
+ # pylint: disable=import-outside-toplevel
+ from tvm.ir.transform import PassContext
+
+ # pylint: enable=import-outside-toplevel
+
+ jit_mod = torch.jit.trace(graph_module, example_inputs)
+ shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
+ ir_mod, params = tvm.relay.frontend.from_pytorch(jit_mod, shape_list)
+
+ relay_build = {"graph": tvm.relay.build, "vm": tvm.relay.vm.compile}[ARGS.backend]
+ with ARGS.target, ms.utils.autotvm_silencer(), database:
+ with PassContext(
+ opt_level=3,
+ config={
+ "relay.backend.use_meta_schedule": True,
+ "relay.backend.use_meta_schedule_dispatch": not IS_CUDA,
+ "relay.backend.tir_converter": "default",
+ },
+ ):
+ lib = relay_build(ir_mod, target=ARGS.target, params=params)
+
+ device = tvm.cuda(0) if IS_CUDA else tvm.cpu(0)
+
+ if ARGS.backend == "graph":
+ mod = GraphModule(lib["default"](device))
+ return get_graph_executor_forward(mod, device)
+ elif ARGS.backend == "vm":
+ vm = VirtualMachine(lib, device) # pylint: disable=invalid-name
+ return get_vm_forward(vm, device)
+ else:
+ raise RuntimeError(f"Unknown backend {ARGS.backend}")
+
+ return backend
+
+
+def format_time(seconds: float) -> str:
+ """
+ Format elapsed time based on its value.
+ """
+ if seconds > 1:
+ return f"{seconds:.3g}s"
+ else:
+ return f"{seconds * 1000:.3g}ms"
+
+
+def is_output_correct(output: torch.Tensor, expected: torch.Tensor) -> bool:
+ """
+ Check whether the output is correct.
+ """
+ comparison_metric = ARGS.result_metric
+ if comparison_metric == ResultComparisonMetric.COSINE:
+ return same(expected, output, cosine_similarity=True)
+ elif comparison_metric == ResultComparisonMetric.ALLCLOSE:
+ return same(expected, output, tol=1e-4)
+ else:
+ raise RuntimeError(f"Unknown comparison metric {comparison_metric}")
+
+
+def performance_experiment(
+ model_iter_fn: Callable, model: torch.nn.Module, example_inputs: Tuple[torch.Tensor]
+) -> str:
+ """
+ Performs the actual benchmarking
+ Simplified from https://github.com/pytorch/torchdynamo/blob/c537639f9712621dc04ca09908796dbbe86c354b/benchmarks/common.py#L494 pylint: disable=line-too-long
+ """
+ timings = np.zeros((ARGS.benchmark_repeat, 2), np.float64)
+
+ is_correct = True
+
+ frozen_model_iter_fn = torchdynamo.run(model_iter_fn)
+
+ for _ in range(ARGS.benchmark_warmup_rounds):
+ frozen_model_iter_fn(model, example_inputs)
+ model_iter_fn(model, example_inputs)
+
+ for rep in range(ARGS.benchmark_repeat):
+ # interleave the runs to handle frequency scaling and load changes
+ timings[rep, 0], expected_output = timed(
+ model, model_iter_fn, example_inputs, return_result=True
+ )
+ timings[rep, 1], actual_output = timed(
+ model, frozen_model_iter_fn, example_inputs, return_result=True
+ )
+ is_correct = is_correct and is_output_correct(expected_output, actual_output)
+
+ pvalue = ttest_ind(timings[:, 0], timings[:, 1]).pvalue
+ median = np.median(timings, axis=0)
+ speedup = median[0] / median[1]
+ logger.info(
+ f"eager:{format_time(median[0])} "
+ f"optimized:{format_time(median[1])} "
+ f"speedup:{speedup:.3f}x p:{pvalue:.3f}"
+ )
+ if not is_correct:
+ logger.error("Result is incorrect.")
+ logger.error(f"Expected (PyTorch eager): {expected_output}")
+ logger.error(f"Actual (Optimized): {actual_output}")
+
+ return ""
+
+
+def get_torch_device_type(target: tvm.target.Target) -> str:
+ if target.kind.name == "llvm":
+ return "cpu"
+ elif target.kind.name == "cuda":
+ return "cuda"
+ else:
+ raise RuntimeError(f"Unsupported target {target}")
+
+
+def main():
+ """
+ Entry point of the benchmark
+ """
+ describe()
+
+ if not ARGS.mode.should_tune:
+ ms_database = ms.default_config.database(None, ARGS.work_dir)
+ if len(ms_database) == 0:
+ raise RuntimeError(
+                "Script is running in eval mode while the tuning database is empty. "
+ "Please tune the model first."
+ )
+
+ if IS_CUDA and ARGS.cpu_flush:
+ warnings.warn(
+ "Benchmark is running on CUDA, while --cpu-flush is turned on. "
+ "This flag will have no effect on CUDA."
+ )
+
+ try:
+ _, name, model, example_inputs, batch_size = runner.load_model(
+ get_torch_device_type(ARGS.target),
+ ARGS.model,
+ batch_size=ARGS.batch_size,
+ )
+ logger.info(
+ f"batch size: {batch_size} input shape: {[input.shape for input in example_inputs]}"
+ )
+ except NotImplementedError:
+ logging.exception(f"{ARGS.model} failed to load")
+ return
+
+ tuning_tasks: List[ms.ExtractedTask] = []
+ task_collect_ctx = torchdynamo.optimize(create_tvm_task_collection_backend(tuning_tasks))
+ task_collect_ctx(runner.model_iter_fn)(model, example_inputs)
+
+ database = ms.tune_extracted_tasks(
+ extracted_tasks=tuning_tasks,
+ config=get_tune_config(),
+ work_dir=ARGS.work_dir,
+ runner=get_metaschedule_runner(), # type: ignore
+ )
+
+ if ARGS.mode.should_eval:
+ torchdynamo.reset()
+ model_compile_ctx = torchdynamo.optimize(create_tvm_compilation_backend(database))
+ experiment = functools.partial(performance_experiment, runner.model_iter_fn)
+ runner.run_one_model(name, model, example_inputs, model_compile_ctx, experiment)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/python/tvm/meta_schedule/testing/torchbench/utils.py b/python/tvm/meta_schedule/testing/torchbench/utils.py
new file mode 100644
index 0000000000..f5a745ea00
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/torchbench/utils.py
@@ -0,0 +1,103 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Helper functions for running TorchBench through the benchmark functions
+from TorchDynamo.
+"""
+
+import os
+import sys
+from dataclasses import dataclass
+
+import torch # type: ignore
+
+
+def find_torchdynamo() -> str:
+ """
+ Find the directory of TorchDynamo repo.
+
+ It can't directly import the benchmark runner in TorchDynamo
+    because it isn't designed to be used as a Python package.
+ """
+ candidates = [
+ "torchdynamo",
+ "../torchdynamo",
+ "../../torchdynamo",
+ ]
+ for library_dir in candidates:
+ if os.path.exists(f"{library_dir}/benchmarks"):
+ return library_dir
+
+ raise RuntimeError(
+ """
+ Cannot find directory for torchdynamo.
+ You need to clone https://github.com/pytorch/torchdynamo to the parent directory of cwd.
+ """
+ )
+
+
+DYNAMO_DIR = find_torchdynamo()
+sys.path.append(DYNAMO_DIR)
+sys.path.append(f"{DYNAMO_DIR}/benchmarks")
+
+# pylint: disable=wrong-import-position, unused-import
+from benchmarks.common import same, timed # type: ignore
+from torchbench import TorchBenchmarkRunner # type: ignore
+
+# pylint: enable=wrong-import-position, unused-import
+
+
+def load_torchdynamo_benchmark_runner(
+ is_cuda: bool, cosine_similarity: bool = False
+) -> TorchBenchmarkRunner:
+ """
+ Load the benchmark runner from TorchDynamo.
+ """
+
+ @dataclass
+ class RunnerArgs:
+ """
+ This class simulates the parsed args required by the benchmark code from TorchDynamo.
+ """
+
+ ci: bool = False # Whether runs in CI mode. pylint: disable=invalid-name
+ training: bool = False # Whether it benchmarks training workload.
+ use_eval_mode: bool = True # Whether the model should be in eval mode.
+ dynamic_shapes: bool = False # Whether runs the model in dynamic shape mode.
+ float16: bool = False # Whether to cast model and inputs to float16
+ float32: bool = False # Whether to cast model and inputs to float32
+
+        accuracy: bool = False  # Whether to perform an accuracy test
+ performance: bool = True # Whether to perform a performance test
+
+        cosine: bool = False  # Whether to use cosine similarity to check if output is correct.
+
+ args = RunnerArgs(cosine=cosine_similarity)
+
+ runner = TorchBenchmarkRunner()
+ runner.args = args
+ runner.model_iter_fn = runner.forward_pass
+
+ if is_cuda:
+ # pylint: disable=import-outside-toplevel
+ import benchmarks.common # type: ignore
+
+ # pylint: enable=import-outside-toplevel
+
+ benchmarks.common.synchronize = torch.cuda.synchronize
+
+ return runner