Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/12/30 08:50:05 UTC

[GitHub] [tvm] echuraev commented on a diff in pull request #13675: [BENCHMARK][ADRENO] Adreno Benchmarks with texture

echuraev commented on code in PR #13675:
URL: https://github.com/apache/tvm/pull/13675#discussion_r1059292569


##########
tests/scripts/ci.py:
##########
@@ -727,6 +727,13 @@ def add_subparser(
                     "./tests/scripts/task_python_adreno.sh " + os.environ.get("ANDROID_SERIAL", ""),
                 ],
             ),
+            "benchmarks": (

Review Comment:
   Are there any plans to run these benchmarks in CI? @driazati, you probably know: do we have any way to run performance tests periodically (e.g. once per week) to check that no performance regressions have been introduced into TVM mainline?
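One way such a periodic job could decide pass/fail is to compare each network's mean latency against a stored baseline and flag slowdowns above a tolerance. The snippet below is a minimal sketch of that idea only; it is not part of this PR or of TVM's CI. The function name `find_regressions`, the per-network latency lists in milliseconds, and the 5% default tolerance are all hypothetical.

```python
"""Hypothetical regression gate for periodic benchmark runs (sketch only)."""
from statistics import mean


def find_regressions(baseline_ms, current_ms, tolerance=0.05):
    """Return {network: relative_slowdown} for networks slower than `tolerance`.

    baseline_ms and current_ms map a network name (e.g. "resnet-18") to a list
    of measured latencies in milliseconds. Networks missing from either run,
    or with no measurements, are skipped rather than reported.
    """
    regressions = {}
    for name, base_runs in baseline_ms.items():
        runs = current_ms.get(name)
        if not runs or not base_runs:
            continue
        base, cur = mean(base_runs), mean(runs)
        # Relative slowdown of the current run versus the stored baseline.
        slowdown = (cur - base) / base
        if slowdown > tolerance:
            regressions[name] = slowdown
    return regressions


if __name__ == "__main__":
    # Made-up numbers, for illustration only.
    baseline = {"mobilenet": [10.0, 10.2, 9.8], "resnet-18": [30.0, 30.0]}
    current = {"mobilenet": [10.1, 10.0, 10.2], "resnet-18": [33.0, 33.0]}
    for net, slowdown in find_regressions(baseline, current).items():
        print(f"REGRESSION {net}: {slowdown:+.1%}")
```

Comparing means with a relative threshold keeps the gate tolerant of small run-to-run noise on the device; a real job would also need to decide how the baseline is refreshed when schedules change intentionally.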



##########
apps/benchmark/adreno/bench.sh:
##########
@@ -0,0 +1,61 @@
+#!/usr/bin/env bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -euxo pipefail
+
+echo "Bench called"
+
+source tests/scripts/setup-pytest-env.sh
+export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python
+export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}"
+
+export TVM_TRACKER_HOST=127.0.0.1
+export TVM_TRACKER_PORT=$(((RANDOM % 100) + 9100))
+export RPC_DEVICE_KEY="android"
+export TVM_NDK_CC="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang"
+
+env PYTHONPATH=python python3 -m tvm.exec.rpc_tracker --host "${TVM_TRACKER_HOST}" --port "${TVM_TRACKER_PORT}" &
+TRACKER_PID=$!
+sleep 5   # Wait for tracker to bind
+
+export ANDROID_SERIAL=$2
+
+adb shell "mkdir -p /data/local/tmp/tvm_ci"
+adb push build-adreno-target/tvm_rpc /data/local/tmp/tvm_ci/tvm_rpc_ci
+adb push build-adreno-target/libtvm_runtime.so /data/local/tmp/tvm_ci
+
+adb reverse tcp:${TVM_TRACKER_PORT} tcp:${TVM_TRACKER_PORT}
+adb forward tcp:5000 tcp:5000
+adb forward tcp:5001 tcp:5001
+adb forward tcp:5002 tcp:5002
+env adb shell "cd /data/local/tmp/tvm_ci; killall -9 tvm_rpc_ci; sleep 2; LD_LIBRARY_PATH=/data/local/tmp/tvm_ci/ ./tvm_rpc_ci server --host=0.0.0.0 --port=5000 --port-end=5010 --tracker=127.0.0.1:${TVM_TRACKER_PORT} --key=${RPC_DEVICE_KEY}" &
+DEVICE_PID=$!
+sleep 5 # Wait for the device connections
+trap "{ kill ${TRACKER_PID}; kill ${DEVICE_PID}; }" 0
+
+# cleanup pycache
+find . -type f -path "*.pyc" | xargs rm -f
+# Test TVM
+make cython3
+
+if [ "texture" == $1 ] ; then

Review Comment:
   Is there any reason to run this script without the `texture` parameter? If not, we can probably remove this argument.



##########
apps/benchmark/adreno/adreno_gpu_bench_texture.py:
##########
@@ -0,0 +1,277 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Benchmark script for various models on Adreno GPU.
+"""
+import argparse
+
+import numpy as np
+
+import os
+import sys
+import tvm
+from tvm import te
+from tvm.relay import testing
+from tvm.contrib.utils import tempdir
+import tvm.contrib.graph_executor as runtime
+from tvm import relay
+from tvm import autotvm
+from tvm.contrib import utils, ndk
+
+
+def get_network(name, batch_size, dtype="float32"):
+    """Get the symbol definition and random weight of a network
+
+    Parameters
+    ----------
+    name: str
+        The name of the network, can be 'resnet-18', 'resnet-50', 'vgg-16', 'inception_v3', 'mobilenet', ...
+    batch_size: int
+        batch size
+    dtype: str
+        Data type
+
+    Returns
+    -------
+    net: tvm.IRModule
+        The relay function of network definition
+    params: dict
+        The random parameters for benchmark
+    input_shape: tuple
+        The shape of input tensor
+    output_shape: tuple
+        The shape of output tensor
+    """
+    input_shape = (batch_size, 3, 224, 224)
+    output_shape = (batch_size, 1000)
+
+    if name == "mobilenet":
+        net, params = testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
+    elif name == "inception_v3":
+        input_shape = (batch_size, 3, 299, 299)
+        net, params = testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
+    elif "resnet" in name:
+        n_layer = int(name.split("-")[1])
+        net, params = testing.resnet.get_workload(
+            num_layers=n_layer, batch_size=batch_size, dtype=dtype
+        )
+    elif "vgg" in name:
+        n_layer = int(name.split("-")[1])
+        net, params = testing.vgg.get_workload(
+            num_layers=n_layer, batch_size=batch_size, dtype=dtype
+        )
+    elif "densenet" in name:
+        n_layer = int(name.split("-")[1])
+        net, params = testing.densenet.get_workload(
+            densenet_size=n_layer, batch_size=batch_size, dtype=dtype
+        )
+    elif "squeezenet" in name:
+        version = name.split("_v")[1]
+        net, params = testing.squeezenet.get_workload(
+            batch_size=batch_size, version=version, dtype=dtype
+        )
+    elif name == "mxnet":
+        # an example for mxnet model
+        from mxnet.gluon.model_zoo.vision import get_model
+
+        block = get_model("resnet18_v1", pretrained=True)
+        net, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype)
+        net = net["main"]
+        net = relay.Function(
+            net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
+        )
+        net = tvm.IRModule.from_expr(net)
+    else:
+        raise ValueError("Unsupported network: " + name)
+
+    return net, params, input_shape, output_shape
+
+
+def print_progress(msg):
+    """print progress message
+
+    Parameters
+    ----------
+    msg: str
+        The message to print
+    """
+    sys.stdout.write(msg + "\r")
+    sys.stdout.flush()
+
+
+def tune_tasks(

Review Comment:
   Just an idea: if we want to use these benchmarks in CI, we can use the statistics from tophub to measure performance without tuning. If someone changes the Adreno schedules, they would then upload updated statistics to tophub. This way, we can avoid performance regressions.



##########
apps/benchmark/adreno/adreno_gpu_bench_texture.py:
##########
@@ -0,0 +1,277 @@
+def tune_tasks(
+    tasks,
+    measure_option,
+    n_trial=1024,
+    early_stopping=None,
+    log_filename="tuning.log",
+):
+    from tvm.autotvm.tuner import XGBTuner
+
+    tmp_log_file = log_filename + ".tmp"
+
+    for i, tsk in enumerate(reversed(tasks)):
+        print("Task: ", tsk)
+        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
+        tuner_obj = XGBTuner(tsk, loss_type="rank")
+
+        tsk_trial = min(n_trial, len(tsk.config_space))
+        tuner_obj.tune(
+            n_trial=tsk_trial,
+            early_stopping=early_stopping,
+            measure_option=measure_option,
+            callbacks=[
+                autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
+                autotvm.callback.log_to_file(tmp_log_file),
+            ],
+        )
+
+        autotvm.record.pick_best(tmp_log_file, log_filename)
+
+
+def evaluate_network(network, target, target_host, dtype, repeat):
+    print_progress(network)
+    net, params, input_shape, output_shape = get_network(network, batch_size=1, dtype=dtype)
+
+    # Auto Tuning
+    tune_log = "adreno-" + network + "-" + dtype + ".log"
+    tuning_options = {
+        "log_filename": tune_log,
+        "early_stopping": None,
+        "measure_option": autotvm.measure_option(
+            builder=autotvm.LocalBuilder(build_func=ndk.create_shared, timeout=15),
+            runner=autotvm.RPCRunner(
+                args.rpc_key,
+                host=args.host,
+                port=args.port,
+                number=3,
+                timeout=600,
+            ),
+        ),
+    }
+    if args.tune:
+        tasks = autotvm.task.extract_from_program(
+            net, target=target, target_host=target_host, params=params
+        )
+        tune_tasks(tasks, **tuning_options)
+
+    print_progress("%-20s building..." % network)
+
+    # Build the tuning log
+    if os.path.exists(tune_log):
+        with autotvm.apply_history_best(tune_log):
+            with tvm.transform.PassContext(opt_level=3):
+                lib = relay.build(
+                    net, target=tvm.target.Target(target, host=target_host), params=params
+                )
+    else:
+        print("WARNING: Benchmark running with out tuning cache file - ", tune_log)

Review Comment:
   ```suggestion
           print("WARNING: Benchmark running without tuning cache file - ", tune_log)
   ```


