Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/09/27 21:58:37 UTC

[GitHub] [tvm] shingjan commented on a diff in pull request #12914: Add a script to tune and benchmark models from TorchBench

shingjan commented on code in PR #12914:
URL: https://github.com/apache/tvm/pull/12914#discussion_r981737528


##########
python/tvm/meta_schedule/testing/torchbench/run.py:
##########
@@ -0,0 +1,511 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+This script is for benchmarking TVM performance on models from TorchBench.
+It uses TorchDynamo as the frontend to ingest models into TVM, and it also
+leverages the benchmark utilities from TorchDynamo.
+
+TorchDynamo (https://github.com/pytorch/torchdynamo) and TorchBench
+(https://github.com/pytorch/benchmark) need to be in the parent directory of TVM.
+We need local clones of these repos because TorchBench and the benchmark runner
+in TorchDynamo aren't designed to be used as Python packages.
+
+To set up the environment, run the following commands in the parent directory of TVM,
+using the appropriate Python environment:
+```bash
+# torchdynamo requires nightly pytorch. If it fails to find the specified version, try
+# installing the latest nightly pytorch.
+pip3 install --pre \
+    --extra-index-url https://download.pytorch.org/whl/nightly/cu116 \
+    torch==1.13.0.dev20220926 \
+    torchvision==0.14.0.dev20220926 \
+    torchtext==0.14.0.dev20220926
+
+git clone https://github.com/pytorch/torchdynamo
+pushd torchdynamo
+git checkout c537639f9712621dc04ca09908796dbbe86c354b
+pip install -e .
+popd
+
+sudo apt install git-lfs  # git lfs is used for TorchBench
+git clone https://github.com/pytorch/benchmark
+pushd benchmark
+python install.py --continue_on_fail  # fambench_xlmr might fail to install
+popd
+```
+
+To run a benchmark, first run the script in 'tune' mode:
+```bash
+python python/tvm/meta_schedule/testing/torchbench/run.py \
+    --mode tune \
+    --model resnet50 \
+    --target "nvidia/geforce-rtx-3070" \
+    --work-dir ../workdir \
+    --num-trials 20000 \
+    --rpc-host <rpc tracker host for tuning> \
+    --rpc-port <rpc tracker port for tuning> \
+    --rpc-key <rpc key>
+```
+
+All available target tags (like nvidia/geforce-rtx-3070) can be found at
+https://github.com/apache/tvm/blob/main/src/target/tag.cc
+
+Then run the script in 'eval' mode to actually benchmark the performance,
+using the tuning database under the work directory. This can be executed on a different
+machine than the one that executed tuning.
+```bash
+python python/tvm/meta_schedule/testing/torchbench/run.py \
+    --mode eval \

Review Comment:
   NIT: As no perf evaluation is done with `--mode tune`, I feel we could combine `tune` and `all`, provided perf evaluation doesn't take much time.
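The suggestion depends on which phases each mode runs. Below is a hypothetical sketch, not the script's control flow.

- `plan_run` and `RunPlan` are invented names.
- The `run_tuning`/`run_eval` flags simplify what the diff does. The diff expresses `eval` by handing MetaSchedule a zero trial budget rather than skipping a phase.
- Only the trial-budget logic mirrors `get_tune_config()` from the diff.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class RunPlan:
    """Which phases a (hypothetical) run performs and its MetaSchedule trial budgets."""

    run_tuning: bool
    run_eval: bool
    max_trials_global: int
    max_trials_per_task: int


def plan_run(mode: str, num_trials: int, max_trials_per_task: Optional[int] = None) -> RunPlan:
    """Hypothetical helper mapping --mode to phases and trial budgets."""
    if mode not in ("tune", "eval", "all"):
        raise ValueError(f"unknown mode: {mode!r}")
    if mode == "eval":
        # Like get_tune_config(): 'eval' gets a zero trial budget.
        global_budget, per_task_budget = 0, 0
    else:
        global_budget = num_trials
        # Like get_tune_config(): the per-task budget defaults to the global one.
        per_task_budget = num_trials if max_trials_per_task is None else max_trials_per_task
    return RunPlan(
        run_tuning=mode in ("tune", "all"),
        run_eval=mode in ("eval", "all"),
        max_trials_global=global_budget,
        max_trials_per_task=per_task_budget,
    )


if __name__ == "__main__":
    print(plan_run("tune", 20000))
```

Under the proposal above, `tune` would simply behave like `all`, so `run_eval` would also be `True` for it.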



##########
python/tvm/meta_schedule/testing/torchbench/run.py:
##########
@@ -0,0 +1,511 @@
+"""
+This script is for benchmarking TVM performance on models from TorchBench.
+It uses TorchDynamo as the frontend to ingest models into TVM, and it also
+leverages the benchmark utilities from TorchDynamo.
+
+TorchDynamo (https://github.com/pytorch/torchdynamo) and TorchBench
+(https://github.com/pytorch/benchmark) need to be in the parent directory of TVM.
+We need local clones of these repos because TorchBench and the benchmark runner
+in TorchDynamo aren't designed to be used as Python packages.
+
+To set up the environment, run the following commands in the parent directory of TVM,
+using the appropriate Python environment:
+```bash
+# torchdynamo requires nightly pytorch. If it fails to find the specified version, try
+# installing the latest nightly pytorch.
+pip3 install --pre \
+    --extra-index-url https://download.pytorch.org/whl/nightly/cu116 \
+    torch==1.13.0.dev20220926 \
+    torchvision==0.14.0.dev20220926 \
+    torchtext==0.14.0.dev20220926
+
+git clone https://github.com/pytorch/torchdynamo
+pushd torchdynamo
+git checkout c537639f9712621dc04ca09908796dbbe86c354b
+pip install -e .
+popd
+
+sudo apt install git-lfs  # git lfs is used for TorchBench
+git clone https://github.com/pytorch/benchmark
+pushd benchmark
+python install.py --continue_on_fail  # fambench_xlmr might fail to install
+popd
+```
+
+To run a benchmark, first run the script in 'tune' mode:
+```bash
+python python/tvm/meta_schedule/testing/torchbench/run.py \
+    --mode tune \
+    --model resnet50 \
+    --target "nvidia/geforce-rtx-3070" \
+    --work-dir ../workdir \
+    --num-trials 20000 \
+    --rpc-host <rpc tracker host for tuning> \
+    --rpc-port <rpc tracker port for tuning> \
+    --rpc-key <rpc key>
+```
+
+All available target tags (like nvidia/geforce-rtx-3070) can be found at
+https://github.com/apache/tvm/blob/main/src/target/tag.cc
+
+Then run the script in 'eval' mode to actually benchmark the performance,
+using the tuning database under the work directory. This can be executed on a different
+machine than the one that executed tuning.
+```bash
+python python/tvm/meta_schedule/testing/torchbench/run.py \
+    --mode eval \
+    --model resnet50 \
+    --target "nvidia/geforce-rtx-3070" \
+    --work-dir ../workdir \
+    --num-trials 0 
+```
+
+Alternatively, both tuning and evaluation can be done in a single run on the same machine:
+```bash
+python python/tvm/meta_schedule/testing/torchbench/run.py \
+    --model resnet50 \
+    --target "llvm -num-cores 6" \
+    --work-dir ../workdir \
+    --num-trials 0
+```
+"""
+
+import argparse
+import functools
+import logging
+import warnings
+
+import numpy as np
+import torch
+from scipy.stats import ttest_ind
+
+import tvm
+import tvm.relay
+from tvm import meta_schedule as ms
+from tvm.contrib import graph_executor
+from tvm.meta_schedule.testing.torchbench.utils import (
+    load_torchdynamo_benchmark_runner,
+    same,
+    timed,
+)
+from tvm.runtime.vm import VirtualMachine
+from tvm.support import describe
+
+
+def parse_args():
+    args = argparse.ArgumentParser()
+
+    args.add_argument(
+        "--mode",
+        type=str,
+        choices=["tune", "eval", "all"],
+        default="all",
+        help="""
+        The running mode of this script. Available values are:
+        - tune: Only tune the model and create the tuning database.
+        - eval: Only benchmark model using pre-existing tuning database.
+        - all: Run both tuning and benchmark
+        """,
+    )
+    args.add_argument(
+        "--batch-size",
+        type=int,
+        default=None,
+        help="The batch size of model input. Use TorchBench's default value if not specified.",
+    )
+    args.add_argument(
+        "--cosine-similarity",
+        action="store_true",
+        help="""
+        Whether to use cosine similarity to determine whether the output is the same as 
+        expected. By default torch.allclose is used.
+        """,
+    )
+    args.add_argument(
+        "--benchmark-repeat",
+        type=int,
+        default=10,
+        help="The number of times to repeat the benchmark measurement.",
+    )
+
+    # Model selection
+    args.add_argument(
+        "--model",
+        type=str,
+        required=True,
+        help="""
+        The name of the model to run. It should be a directory name under
+        https://github.com/pytorch/benchmark/tree/main/torchbenchmark/models.
+        """,
+    )
+
+    # Tuning-related config
+    args.add_argument(
+        "--target",
+        type=tvm.target.Target,
+        required=True,
+        help="The target to tune and run benchmark for.",
+    )
+    args.add_argument(
+        "--work-dir",
+        type=str,
+        required=True,
+        help="The working directory to save intermediate results.",
+    )
+    args.add_argument(
+        "--cache-dir",
+        type=str,
+        default=None,
+        help="""
+        The directory to cache the generated network.
+        If not specified, the cache will be disabled.
+        """,
+    )
+    args.add_argument(
+        "--num-trials",
+        type=int,
+        required=True,
+        help="The max number of trials to run MetaSchedule.",
+    )
+    args.add_argument(
+        "--max-trials-per-task",
+        type=int,
+        default=None,
+        help="""
+        The max number of trials to run per task extracted in MetaSchedule. 
+        By default it's the same as --num-trials.
+        """,
+    )
+    args.add_argument(
+        "--backend",
+        type=str,
+        choices=["graph", "vm"],
+        default="graph",
+        help="The backend to use for Relay compilation (graph / vm).",
+    )
+    # TODO(@yelite): Add a layout arg to transform the network after
+    # ingesting into Relay and before feeding into MetaSchedule.
+
+    # Evaluator-related config
+    args.add_argument(
+        "--number",
+        type=int,
+        default=3,
+        help="The number of times to run the model for taking average in a single measurement.",
+    )
+    args.add_argument(
+        "--repeat",
+        type=int,
+        default=1,
+        help="The number of times to repeat the measurement.",
+    )
+    args.add_argument(
+        "--min-repeat-ms",
+        type=int,
+        default=100,
+        help="""
+        Minimum repeat time in ms. The number of runs will be increased if the actual
+        repeat time is lower than this.
+        """,
+    )
+    args.add_argument(
+        "--adaptive-training",
+        action="store_true",
+        help="Whether to use adaptive training for the cost model.",
+    )
+    args.add_argument(
+        "--cpu-flush",
+        action="store_true",
+        help="Whether to perform CPU cache flush.",
+    )
+
+    # RPC-related args
+    args.add_argument(
+        "--rpc-host",
+        type=str,
+        help="Host of the RPC Tracker for tuning. Use LocalRunner if not provided.",
+    )
+    args.add_argument(
+        "--rpc-port",
+        type=int,
+        help="Port of the RPC Tracker for tuning",
+    )
+    args.add_argument(
+        "--rpc-key",
+        type=str,
+        help="Key of the RPC Tracker for tuning",
+    )
+
+    parsed = args.parse_args()
+    return parsed
+
+
+logging.basicConfig(
+    format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
+)
+logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
+ARGS = parse_args()
+IS_CUDA = ARGS.target.kind.name == "cuda"
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+runner = load_torchdynamo_benchmark_runner(IS_CUDA)
+import torchdynamo
+
+
+def get_metaschedule_runner():
+    if ARGS.rpc_host is not None:
+        return ms.runner.RPCRunner(
+            rpc_config=ms.runner.RPCConfig(
+                tracker_host=ARGS.rpc_host,
+                tracker_port=ARGS.rpc_port,
+                tracker_key=ARGS.rpc_key,
+                session_timeout_sec=600,
+            ),
+            evaluator_config=ms.runner.EvaluatorConfig(
+                number=ARGS.number,
+                repeat=ARGS.repeat,
+                min_repeat_ms=ARGS.min_repeat_ms,
+                enable_cpu_cache_flush=ARGS.cpu_flush,
+            ),
+            alloc_repeat=1,
+        )
+    else:
+        warnings.warn("Falling back to MetaSchedule LocalRunner because --rpc-host isn't provided.")
+        return ms.runner.LocalRunner()
+
+
+def get_tune_config():
+    if ARGS.mode == "eval":
+        max_trials_per_task = 0
+        max_trials_global = 0
+    else:
+        max_trials_per_task = ARGS.max_trials_per_task
+        max_trials_global = ARGS.num_trials
+
+    if max_trials_per_task is None:
+        max_trials_per_task = max_trials_global
+
+    return ms.TuneConfig(
+        strategy="evolutionary",
+        num_trials_per_iter=64,
+        max_trials_per_task=max_trials_per_task,
+        max_trials_global=max_trials_global,
+        adaptive_training=ARGS.adaptive_training,
+    )
+
+
+def get_graph_executor_forward(mod, device):
+    def forward(*args):
+        if IS_CUDA:
+            torch.cuda.synchronize()
+        args = [arg.contiguous() for arg in args]
+        for idx, arg in enumerate(args, 0):
+            mod.set_input(
+                f"inp_{idx}",
+                tvm.nd.from_dlpack(arg),

Review Comment:
   I think this could potentially be a problem, and that is the reason why torchdynamo's TVM backend converts `torch.Tensor` to numpy and then to a TVM NDArray. Also, if `arg` is a `torch.Tensor`, you may need `torch.utils.dlpack.to_dlpack(arg)` for this approach as well.
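The two conversion paths differ in memory ownership:

- A DLPack handoff shares the producer's buffer without copying.
- A tensor -> numpy -> NDArray path materializes a copy.

The sketch below illustrates that difference using only NumPy (assumed >= 1.22, which provides `np.from_dlpack`) as a stand-in for both torch and TVM. It does not call `tvm.nd.from_dlpack` or `torch.utils.dlpack`, and it is not the script's code. One practical consequence of zero-copy is that the source tensor must stay alive, and must not be overwritten, while the consumer still reads from it.

```python
import numpy as np


def via_dlpack(src: np.ndarray) -> np.ndarray:
    """Import `src` through the DLPack protocol: no data is copied."""
    # np.from_dlpack (NumPy >= 1.22) accepts any object implementing
    # __dlpack__ / __dlpack_device__; the result aliases src's buffer.
    return np.from_dlpack(src)


def via_copy(src: np.ndarray) -> np.ndarray:
    """Materialize an independent copy, like a tensor -> numpy -> NDArray path."""
    return np.array(src, copy=True)


if __name__ == "__main__":
    src = np.arange(6, dtype=np.float32).reshape(2, 3)
    shared, copied = via_dlpack(src), via_copy(src)
    src[0, 0] = 42.0  # write to the producer's buffer after conversion
    # The DLPack import observes the write; the copy keeps the old value.
    print(shared[0, 0], copied[0, 0])
```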



##########
python/tvm/meta_schedule/testing/torchbench/run.py:
##########
@@ -0,0 +1,508 @@
+"""
+This script is for benchmarking TVM performance on models from TorchBench.
+It uses TorchDynamo as the frontend to ingest models into TVM, and it also
+leverages the benchmark utilities from TorchDynamo.
+
+TorchDynamo (https://github.com/pytorch/torchdynamo) and TorchBench
+(https://github.com/pytorch/benchmark) need to be in the parent directory of TVM.
+We need local clones of these repos because TorchBench and the benchmark runner
+in TorchDynamo aren't designed to be used as Python packages.
+
+To set up the environment, run the following commands in the parent directory of TVM,
+using the appropriate Python environment:
+```bash
+# torchdynamo requires nightly pytorch. If it fails to find the specified version, try
+# installing the latest nightly pytorch.
+pip3 install --pre \
+    --extra-index-url https://download.pytorch.org/whl/nightly/cu116 \
+    torch==1.13.0.dev20220926 \
+    torchvision==0.14.0.dev20220926 \
+    torchtext==0.14.0.dev20220926
+
+git clone https://github.com/pytorch/torchdynamo
+pushd torchdynamo
+git checkout c537639f9712621dc04ca09908796dbbe86c354b
+pip install -e .
+popd
+
+sudo apt install git-lfs  # git lfs is used for TorchBench
+git clone https://github.com/pytorch/benchmark
+pushd benchmark
+python install.py --continue_on_fail  # fambench_xlmr might fail to install
+popd
+```
+
+To run a benchmark, first run the script in 'tune' mode:
+```bash
+python python/tvm/meta_schedule/testing/torchbench/run.py \
+    --mode tune \
+    --model resnet50 \
+    --target "nvidia/geforce-rtx-3070" \
+    --work-dir ../workdir \
+    --num-trials 20000 \
+    --rpc-host <rpc tracker host for tuning> \
+    --rpc-port <rpc tracker port for tuning> \
+    --rpc-key <rpc key>
+```
+
+Then run the script in 'eval' mode to actually benchmark the performance,
+using the tuning database under the work directory. This can be executed on a different
+machine than the one that executed tuning.
+```bash
+python python/tvm/meta_schedule/testing/torchbench/run.py \
+    --mode eval \
+    --model resnet50 \
+    --target "nvidia/geforce-rtx-3070" \
+    --work-dir ../workdir \
+    --num-trials 0 
+```
+
+Alternatively, both tuning and evaluation can be done in a single run on the same machine:
+```bash
+python python/tvm/meta_schedule/testing/torchbench/run.py \
+    --model resnet50 \
+    --target "llvm -num-cores 6" \
+    --work-dir ../workdir \
+    --num-trials 0
+```
+"""
+
+import argparse
+import functools
+import logging
+
+import numpy as np
+import torch
+from scipy.stats import ttest_ind
+
+import tvm
+import tvm.relay
+from tvm import meta_schedule as ms
+from tvm.contrib import graph_executor
+from tvm.meta_schedule.testing.torchbench.utils import (
+    load_torchdynamo_benchmark_runner,
+    same,
+    timed,
+)
+from tvm.runtime.vm import VirtualMachine
+from tvm.support import describe
+
+
+def parse_args():
+    args = argparse.ArgumentParser()
+
+    args.add_argument(
+        "--mode",
+        type=str,
+        choices=["tune", "eval", "all"],
+        default="all",
+        help="""
+        The running mode of this script. Available values are:
+        - tune: Only tune the model and create the tuning database.
+        - eval: Only benchmark model using pre-existing tuning database.
+        - all: Run both tuning and benchmark
+        """,
+    )
+    args.add_argument(
+        "--batch-size",
+        type=int,
+        default=None,
+        help="The batch size of model input. Use TorchBench's default value if not specified.",
+    )
+    args.add_argument(
+        "--cosine-similarity",
+        action="store_true",
+        help="""
+        Whether to use cosine similarity to determine whether the output is the same as 
+        expected. By default torch.allclose is used.
+        """,
+    )
+    args.add_argument(
+        "--benchmark-repeat",
+        type=int,
+        default=10,
+        help="The number of times to repeat the benchmark measurement.",

Review Comment:
   In torchdynamo there are warm-up rounds. Adding them could be an option here.
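A warm-up option like the one suggested could look roughly like the sketch below. Its limits:

- The helper name `benchmark` and its `warmup` parameter are hypothetical. The diff itself defines only `--benchmark-repeat`.
- It times with wall-clock `time.perf_counter`.
- For CUDA, the device would also need to be synchronized before each clock read, which this sketch does not do.

```python
import statistics
import time
from typing import Callable, List


def benchmark(fn: Callable[[], object], warmup: int = 3, repeat: int = 10) -> List[float]:
    """Return `repeat` wall-clock timings (seconds) of `fn`, after `warmup` untimed calls.

    Hypothetical helper: the untimed warm-up calls absorb one-off costs such as
    lazy initialization, cache population and first-call allocation, so they
    do not skew the recorded samples.
    """
    if warmup < 0 or repeat < 1:
        raise ValueError("warmup must be >= 0 and repeat >= 1")
    for _ in range(warmup):
        fn()  # warm-up round: neither timed nor recorded
    samples = []
    for _ in range(repeat):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return samples


if __name__ == "__main__":
    timings = benchmark(lambda: sum(range(100_000)), warmup=2, repeat=5)
    print(f"median: {statistics.median(timings) * 1e3:.3f} ms")
```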



##########
python/tvm/meta_schedule/testing/torchbench/run.py:
##########
@@ -0,0 +1,504 @@
+"""
+This script is for benchmarking TVM performance on models from TorchBench.
+It uses torchdynamo as the frontend to ingest models into TVM, and it also
+leverages the benchmark utilities from torchdynamo.
+
+torchdynamo (https://github.com/pytorch/torchdynamo) and TorchBench
+(https://github.com/pytorch/benchmark) need to be in the parent directory of TVM.
+We need local clones of these repos because TorchBench and the benchmark runner
+in torchdynamo aren't designed to be used as Python packages.
+
+To set up the environment, run the following commands in the parent directory of TVM,
+using the appropriate Python environment:
+```
+# torchdynamo requires nightly pytorch
+pip3 install --pre \
+    --extra-index-url https://download.pytorch.org/whl/nightly/cu116 \
+    torch==1.13.0.dev20220926 \
+    torchvision==0.14.0.dev20220926 \
+    torchtext==0.14.0.dev20220926 
+
+git clone https://github.com/pytorch/torchdynamo
+pushd torchdynamo
+git checkout c537639f9712621dc04ca09908796dbbe86c354b
+pip install -e .
+popd
+
+sudo apt install git-lfs  # git lfs is used for TorchBench
+git clone https://github.com/pytorch/benchmark
+pushd benchmark
+python install.py --continue_on_fail  # fambench_xlmr might fail to install
+popd
+```
+
+To run a benchmark, first run the script in 'tune' mode:
+```
+python python/tvm/meta_schedule/testing/torchbench/run.py \
+    --mode tune \
+    --model resnet50 \
+    --target "nvidia/geforce-rtx-3070" \
+    --work-dir ../workdir \
+    --num-trials 20000 \
+    --rpc-host <rpc tracker host for tuning> \
+    --rpc-port <rpc tracker port for tuning> \
+    --rpc-key <rpc key>
+```
+
+Then run the script in 'eval' mode to actually benchmark the performance,
+using the tuning database under the work directory. This can be executed on a different
+machine than the one that executed tuning.
+```
+python python/tvm/meta_schedule/testing/torchbench/run.py \
+    --mode eval \
+    --model resnet50 \
+    --target "nvidia/geforce-rtx-3070" \
+    --work-dir ../workdir \
+    --num-trials 0 
+```
+
+Alternatively, both tuning and evaluation can be done in a single run on the same machine:
+```
+python python/tvm/meta_schedule/testing/torchbench/run.py \
+    --model resnet50 \
+    --target "llvm -num-cores 6" \
+    --work-dir ../workdir \
+    --num-trials 0
+```
+"""
+
+import argparse
+import functools
+import logging
+
+import numpy as np
+import torch
+from scipy.stats import ttest_ind
+
+import tvm
+import tvm.relay
+from tvm import meta_schedule as ms
+from tvm.contrib import graph_executor
+from tvm.meta_schedule.testing.torchbench.utils import (
+    load_torchdynamo_benchmark_runner,
+    same,
+    timed,
+)
+from tvm.runtime.vm import VirtualMachine
+from tvm.support import describe
+
+
+def parse_args():
+    args = argparse.ArgumentParser()
+
+    args.add_argument(
+        "--mode",
+        type=str,
+        choices=["tune", "eval", "all"],
+        default="all",
+        help="""
+        The running mode of this script. Available values are:
+        - tune: Only tune the model and create the tuning database.
+        - eval: Only benchmark model using pre-existing tuning database.
+        - all: Run both tuning and benchmark
+        """,
+    )
+    args.add_argument(
+        "--batch-size",
+        type=int,
+        default=None,
+        help="The batch size of model input. Use TorchBench's default value if not specified.",
+    )
+    args.add_argument(
+        "--cosine-similarity",
+        action="store_true",
+        help="""
+        Whether to use cosine similarity to determine whether the output is the same as 
+        expected. By default torch.allclose is used.
+        """,
+    )

Review Comment:
   I think the `--cosine-similarity` idea comes from torchdynamo. I am okay with this argument, as some models on CUDA will fail `allclose` anyway, and we need to specify `--cosine` for those models.
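To make the trade-off concrete, here is a NumPy-only sketch of the two checks. The function names and the 0.99 threshold are assumptions for illustration. This is neither the script's `same()` helper nor torchdynamo's exact implementation.

The two checks react differently to a uniform 0.1% drift:

- An element-wise `allclose` with tight tolerances rejects it.
- Cosine similarity accepts it, because it only compares the direction of the flattened outputs.

That looseness is also why cosine similarity fits better as an opt-in flag than as the default.

```python
import numpy as np


def same_allclose(ref, out, rtol=1e-4, atol=1e-4) -> bool:
    # Element-wise: every entry needs |out - ref| <= atol + rtol * |ref|.
    return bool(np.allclose(out, ref, rtol=rtol, atol=atol))


def same_cosine(ref, out, threshold=0.99) -> bool:
    # Whole-tensor: compare the direction of the flattened outputs only.
    ref = np.asarray(ref, dtype=np.float64).ravel()
    out = np.asarray(out, dtype=np.float64).ravel()
    denom = np.linalg.norm(ref) * np.linalg.norm(out)
    if denom == 0.0:
        # Cosine similarity is undefined for zero vectors; fall back to equality.
        return bool(np.array_equal(ref, out))
    return bool(np.dot(ref, out) / denom >= threshold)


if __name__ == "__main__":
    ref = np.array([1.0, 2.0, 3.0, 4.0])
    drifted = ref * 1.001  # uniform 0.1% drift
    print(same_allclose(ref, drifted), same_cosine(ref, drifted))
```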



##########
python/tvm/meta_schedule/testing/torchbench/run.py:
##########
@@ -0,0 +1,511 @@
+"""
+This script is for benchmarking TVM performance on models from TorchBench.
+It uses TorchDynamo as the frontend to ingest models into TVM, and it also
+leverages the benchmark utilities from TorchDynamo.
+
+TorchDynamo (https://github.com/pytorch/torchdynamo) and TorchBench
+(https://github.com/pytorch/benchmark) need to be in the parent directory of TVM.
+We need local clones of these repos because TorchBench and the benchmark runner
+in TorchDynamo aren't designed to be used as Python packages.
+
+To set up the environment, run the following commands in the parent directory of TVM,
+using the appropriate Python environment:
+```bash
+# torchdynamo requires nightly pytorch. If it fails to find the specified version, try
+# installing the latest nightly pytorch.
+pip3 install --pre \
+    --extra-index-url https://download.pytorch.org/whl/nightly/cu116 \
+    torch==1.13.0.dev20220926 \
+    torchvision==0.14.0.dev20220926 \
+    torchtext==0.14.0.dev20220926
+
+git clone https://github.com/pytorch/torchdynamo
+pushd torchdynamo
+git checkout c537639f9712621dc04ca09908796dbbe86c354b
+pip install -e .
+popd
+
+sudo apt install git-lfs  # git lfs is used for TorchBench
+git clone https://github.com/pytorch/benchmark
+pushd benchmark
+python install.py --continue_on_fail  # fambench_xlmr might fail to install
+popd
+```
+
+To run a benchmark, first run the script in 'tune' mode:
+```bash
+python python/tvm/meta_schedule/testing/torchbench/run.py \
+    --mode tune \
+    --model resnet50 \
+    --target "nvidia/geforce-rtx-3070" \
+    --work-dir ../workdir \
+    --num-trials 20000 \
+    --rpc-host <rpc tracker host for tuning> \
+    --rpc-port <rpc tracker port for tuning> \
+    --rpc-key <rpc key>
+```
+
+All available target tags (like nvidia/geforce-rtx-3070) can be found at
+https://github.com/apache/tvm/blob/main/src/target/tag.cc
+
+Then run the script in 'eval' mode to actually benchmark the performance,
+using the tuning database under the work directory. This can be executed on a different
+machine than the one that executed tuning.
+```bash
+python python/tvm/meta_schedule/testing/torchbench/run.py \
+    --mode eval \
+    --model resnet50 \
+    --target "nvidia/geforce-rtx-3070" \
+    --work-dir ../workdir \
+    --num-trials 0 
+```
+
+Alternatively, both tuning and evaluation can be done in a single run on the same machine:
+```bash
+python python/tvm/meta_schedule/testing/torchbench/run.py \
+    --model resnet50 \

Review Comment:
   It may be better to specify `--mode all` explicitly here.
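`parse_args()` declares `--mode` with `default="all"`, so omitting the flag already selects `all`. Passing it explicitly only makes the example self-describing. The trimmed-down parser below shows this. It keeps only `--mode` and `--num-trials` from the diff, and the `build_parser`/`parse` wrappers are illustrative names.

```python
import argparse
from typing import List


def build_parser() -> argparse.ArgumentParser:
    # Keeps only --mode and --num-trials, copied from parse_args() in the diff.
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, choices=["tune", "eval", "all"], default="all")
    parser.add_argument("--num-trials", type=int, required=True)
    return parser


def parse(argv: List[str]) -> argparse.Namespace:
    return build_parser().parse_args(argv)


if __name__ == "__main__":
    implicit = parse(["--num-trials", "0"])
    explicit = parse(["--mode", "all", "--num-trials", "0"])
    print(implicit.mode, explicit.mode)
```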


