Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/12/18 18:18:49 UTC

[GitHub] [tvm] comaniac commented on a change in pull request #7132: [Auto Scheduler] Mali Support

comaniac commented on a change in pull request #7132:
URL: https://github.com/apache/tvm/pull/7132#discussion_r545995260



##########
File path: python/tvm/relay/op/strategy/mali.py
##########
@@ -18,8 +18,10 @@
 # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
 import re
 from tvm import topi
+from tvm.auto_scheduler import is_auto_scheduler_enabled
 from .generic import *
 from .. import op as _op
+from .cuda import naive_schedule

Review comment:
       This function has been moved to `.generic`.
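    Since `from .generic import *` is already present above, the wildcard import should pick it up once it lives in `.generic`, so the extra import line can probably just be dropped — a sketch, assuming `naive_schedule` is exported there:
    ```python
    from .generic import *  # also brings in naive_schedule now
    from .. import op as _op
    ```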

##########
File path: python/tvm/relay/op/strategy/mali.py
##########
@@ -105,6 +146,16 @@ def conv2d_winograd_without_weight_transfrom_strategy_mali(attrs, inputs, out_ty
             wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_winograd),
             name="conv2d_nchw_winograd.mali",
         )
+    elif layout == "NHWC":
+        if is_auto_scheduler_enabled():
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform),
+                naive_schedule,  # this implementation should never be picked by autotvm
+                name="conv2d_nhwc_winograd_without_weight_transform",
+                plevel=15,
+            )
+        else:
+            logger.error("AutoTVM doesn't support NHWC winograd on Mali currently")

Review comment:
     Make it consistent with the other cases.
   ```suggestion
           if not is_auto_scheduler_enabled():
               logger.error("Winograd conv2d NHWC is not enabled for mali without auto_scheduler")
           strategy.add_implementation(
               wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform),
               naive_schedule,  # this implementation should never be picked by autotvm
               name="conv2d_nhwc_winograd_without_weight_transform",
               plevel=15,
           )
   ```

##########
File path: src/auto_scheduler/search_task.cc
##########
@@ -90,6 +90,22 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
     int max_vthread_extent = warp_size / 4;
     return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
                           max_threads_per_block, max_vthread_extent, warp_size);
+  } else if (target->kind->device_type == kDLOpenCL) {
+    if (target->GetAttr<String>("device", "") == "mali") {
+      // We can not use device api to get attr like CUDA
+      // because like Mali target is normally on the remote machine

Review comment:
       nit: maybe we can tell users how to get the Mali hardware parameters manually in the tutorial.
   
   ```suggestion
         // We cannot use device API to get hardware attributes like CUDA,
         // because Mali target is normally on the remote machine.
   ```
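    For example, the tutorial could show querying the attributes over RPC — a rough sketch, assuming the board is already registered with a tracker under the device key used elsewhere in the tutorial:
    ```python
    from tvm import rpc

    # Request the Mali board from the RPC tracker (host, port, and key are examples).
    tracker = rpc.connect_tracker("127.0.0.1", 9190)
    remote = tracker.request("rk3399")
    ctx = remote.cl()

    # Query the attributes needed to fill in HardwareParams manually.
    print("max_shared_memory_per_block:", ctx.max_shared_memory_per_block)
    print("max_threads_per_block:", ctx.max_threads_per_block)
    print("warp_size:", ctx.warp_size)
    ```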

##########
File path: tutorials/auto_scheduler/tune_network_mali.py
##########
@@ -0,0 +1,322 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Auto-scheduling a Neural Network for Mali GPU
+=============================================
+**Author**: `Zhao Wu <https://github.com/FrozenGene>`_
+
+Auto-tuning for specific devices and workloads is critical for getting the
+best performance. This is a tutorial on how to tune a whole neural
+network for Mali GPU with the auto-scheduler.
+
+To auto-tune a neural network, we partition the network into small subgraphs and 
+tune them independently. Each subgraph is treated as one search task.
+A task scheduler slices the time and dynamically allocates time resources to
+these tasks. The task scheduler predicts the impact of each task on the end-to-end
+execution time and prioritizes the one that can reduce the execution time the most.
+
+For each subgraph, we use the compute declaration in :code:`tvm/python/topi` to
+get the computational DAG in the tensor expression form.
+We then use the auto-scheduler to construct a search space of this DAG and search
+for good schedules (low-level optimizations).
+
+Different from the template-based :ref:`autotvm <tutorials-autotvm-sec>`, which relies on
+manual templates to define the search space, the auto-scheduler does not require any
+schedule templates. In other words, the auto-scheduler only uses the compute declarations
+in :code:`tvm/python/topi` and does not use existing schedule templates.
+
+Note that this tutorial will not run on Windows or recent versions of macOS. To
+get it to run, you will need to wrap the body of this tutorial in an :code:`if
+__name__ == "__main__":` block.
+"""
+
+import os
+
+import numpy as np
+
+import tvm
+from tvm import relay, auto_scheduler
+import tvm.relay.testing
+from tvm.contrib import graph_runtime
+
+#################################################################
+# Define a Network
+# ----------------
+# First, we need to define the network with relay frontend API.
+# We can also load some pre-defined networks from :code:`tvm.relay.testing`.
+# We can also load models from MXNet, ONNX, PyTorch, and TensorFlow
+# (see :ref:`front end tutorials<tutorial-frontend>`).
+#
+# For convolutional neural networks, although auto-scheduler can work correctly
+# with any layout, we found the best performance is typically achieved with NHWC layout.
+# We also implemented more optimizations for NHWC layout with the auto-scheduler.
+# So it is recommended to convert your models to NHWC layout to use the auto-scheduler.
+# You can use :ref:`ConvertLayout <convert-layout-usage>` pass to do the layout conversion in TVM.
+
+
+def get_network(name, batch_size, layout="NHWC", dtype="float32"):
+    """Get the symbol definition and random weight of a network"""
+
+    # auto-scheduler prefers NHWC layout
+    if layout == "NHWC":
+        image_shape = (224, 224, 3)
+    elif layout == "NCHW":
+        image_shape = (3, 224, 224)
+    else:
+        raise ValueError("Invalid layout: " + layout)
+
+    input_shape = (batch_size,) + image_shape
+    output_shape = (batch_size, 1000)
+
+    if name.startswith("resnet-"):
+        n_layer = int(name.split("-")[1])
+        mod, params = relay.testing.resnet.get_workload(
+            num_layers=n_layer,
+            batch_size=batch_size,
+            layout=layout,
+            dtype=dtype,
+            image_shape=image_shape,
+        )
+    elif name.startswith("resnet3d-"):
+        n_layer = int(name.split("-")[1])
+        mod, params = relay.testing.resnet_3d.get_workload(
+            num_layers=n_layer,
+            batch_size=batch_size,
+            layout=layout,
+            dtype=dtype,
+            image_shape=image_shape,
+        )
+    elif name == "mobilenet":
+        mod, params = relay.testing.mobilenet.get_workload(
+            batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape
+        )
+    elif name == "squeezenet_v1.1":
+        assert layout == "NCHW", "squeezenet_v1.1 only supports NCHW layout"
+        mod, params = relay.testing.squeezenet.get_workload(
+            version="1.1",
+            batch_size=batch_size,
+            dtype=dtype,
+            image_shape=image_shape,
+        )
+    elif name == "inception_v3":
+        input_shape = (batch_size, 3, 299, 299) if layout == "NCHW" else (batch_size, 299, 299, 3)
+        mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
+    elif name == "mxnet":
+        # an example for mxnet model
+        from mxnet.gluon.model_zoo.vision import get_model
+
+        assert layout == "NCHW"
+
+        block = get_model("resnet50_v1", pretrained=True)
+        mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype)
+        net = mod["main"]
+        net = relay.Function(
+            net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
+        )
+        mod = tvm.IRModule.from_expr(net)
+
+    return mod, params, input_shape, output_shape
+
+
+# Define the neural network and compilation target.
+network = "mobilenet"
+batch_size = 1
+layout = "NHWC"
+# replace this with the device key in your tracker

Review comment:
       ```suggestion
   # Replace this with the device key in your tracker
   ```

##########
File path: python/tvm/relay/op/strategy/mali.py
##########
@@ -79,6 +111,15 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target):
                 wrap_topi_schedule(topi.mali.schedule_depthwise_conv2d_nchw),
                 name="depthwise_conv2d_nchw.mali",
             )
+        elif layout == "NHWC":
+            assert kernel_layout == "HWOI"
+            if not is_auto_scheduler_enabled():
+                logger.error("depthwise_conv2d NHWC layout is not enabled for mali with autotvm.")

Review comment:
       ditto.

##########
File path: python/tvm/relay/op/strategy/mali.py
##########
@@ -69,6 +71,36 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target):
                 raise RuntimeError(
                     "Unsupported weight layout {} for conv2d NCHW".format(kernel_layout)
                 )
+        elif layout == "NHWC":
+            assert kernel_layout == "HWIO"
+            if not is_auto_scheduler_enabled():
+                logger.error("conv2d NHWC layout is not enabled for mali with autotvm.")
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True),
+                naive_schedule,
+                name="conv2d_nhwc.mali",
+            )
+            is_winograd_applicable = False
+            if len(kernel.shape) == 4:
+                kernel_h, kernel_w, _, _ = get_const_tuple(kernel.shape)
+                is_winograd_applicable = (
+                    "float" in data.dtype
+                    and "float" in kernel.dtype
+                    and kernel_h == 3
+                    and kernel_w == 3
+                    and stride_h == 1
+                    and stride_w == 1
+                    and dilation_h == 1
+                    and dilation_w == 1
+                )

Review comment:
     nit: this part might be refactored together with the Winograd applicability check in the NCHW branch.
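    A hypothetical shared helper, as a sketch:
    ```python
    def _is_winograd_applicable(data, kernel, kernel_h, kernel_w,
                                stride_h, stride_w, dilation_h, dilation_w):
        """Shared check for the 3x3, stride-1, dilation-1 Winograd path."""
        return (
            "float" in data.dtype
            and "float" in kernel.dtype
            and kernel_h == 3
            and kernel_w == 3
            and stride_h == 1
            and stride_w == 1
            and dilation_h == 1
            and dilation_w == 1
        )
    ```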

##########
File path: python/tvm/relay/op/strategy/mali.py
##########
@@ -69,6 +71,36 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target):
                 raise RuntimeError(
                     "Unsupported weight layout {} for conv2d NCHW".format(kernel_layout)
                 )
+        elif layout == "NHWC":
+            assert kernel_layout == "HWIO"
+            if not is_auto_scheduler_enabled():
+                logger.error("conv2d NHWC layout is not enabled for mali with autotvm.")

Review comment:
     The message might be a bit confusing for new users, because they will see it when building the model even without AutoTVM tuning. In addition, since logging an error won't terminate the execution, it will continue and error out when creating the schedule using `naive_schedule`. It seems to me that we should either change this level to a warning to let users know the failure reason, or simply raise an exception here.
   
   ```suggestion
                   logger.error("conv2d NHWC layout is not enabled for mali without auto_scheduler.")
   ```
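    Alternatively, if we decide to fail fast instead of continuing into `naive_schedule`, roughly:
    ```python
    raise RuntimeError("conv2d NHWC layout is not enabled for mali without auto_scheduler.")
    ```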

##########
File path: tutorials/auto_scheduler/tune_network_mali.py
##########
@@ -0,0 +1,322 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Auto-scheduling a Neural Network for Mali GPU
+=============================================
+**Author**: `Zhao Wu <https://github.com/FrozenGene>`_
+
+Auto-tuning for specific devices and workloads is critical for getting the
+best performance. This is a tutorial on how to tune a whole neural
+network for Mali GPU with the auto-scheduler.
+
+To auto-tune a neural network, we partition the network into small subgraphs and 
+tune them independently. Each subgraph is treated as one search task.
+A task scheduler slices the time and dynamically allocates time resources to
+these tasks. The task scheduler predicts the impact of each task on the end-to-end
+execution time and prioritizes the one that can reduce the execution time the most.
+
+For each subgraph, we use the compute declaration in :code:`tvm/python/topi` to
+get the computational DAG in the tensor expression form.
+We then use the auto-scheduler to construct a search space of this DAG and search
+for good schedules (low-level optimizations).
+
+Different from the template-based :ref:`autotvm <tutorials-autotvm-sec>`, which relies on
+manual templates to define the search space, the auto-scheduler does not require any
+schedule templates. In other words, the auto-scheduler only uses the compute declarations
+in :code:`tvm/python/topi` and does not use existing schedule templates.
+
+Note that this tutorial will not run on Windows or recent versions of macOS. To
+get it to run, you will need to wrap the body of this tutorial in an :code:`if
+__name__ == "__main__":` block.
+"""
+
+import os
+
+import numpy as np
+
+import tvm
+from tvm import relay, auto_scheduler
+import tvm.relay.testing
+from tvm.contrib import graph_runtime
+
+#################################################################
+# Define a Network
+# ----------------
+# First, we need to define the network with relay frontend API.
+# We can also load some pre-defined networks from :code:`tvm.relay.testing`.
+# We can also load models from MXNet, ONNX, PyTorch, and TensorFlow
+# (see :ref:`front end tutorials<tutorial-frontend>`).
+#
+# For convolutional neural networks, although auto-scheduler can work correctly
+# with any layout, we found the best performance is typically achieved with NHWC layout.
+# We also implemented more optimizations for NHWC layout with the auto-scheduler.
+# So it is recommended to convert your models to NHWC layout to use the auto-scheduler.
+# You can use :ref:`ConvertLayout <convert-layout-usage>` pass to do the layout conversion in TVM.
+
+
+def get_network(name, batch_size, layout="NHWC", dtype="float32"):
+    """Get the symbol definition and random weight of a network"""
+
+    # auto-scheduler prefers NHWC layout
+    if layout == "NHWC":
+        image_shape = (224, 224, 3)
+    elif layout == "NCHW":
+        image_shape = (3, 224, 224)
+    else:
+        raise ValueError("Invalid layout: " + layout)
+
+    input_shape = (batch_size,) + image_shape
+    output_shape = (batch_size, 1000)
+
+    if name.startswith("resnet-"):
+        n_layer = int(name.split("-")[1])
+        mod, params = relay.testing.resnet.get_workload(
+            num_layers=n_layer,
+            batch_size=batch_size,
+            layout=layout,
+            dtype=dtype,
+            image_shape=image_shape,
+        )
+    elif name.startswith("resnet3d-"):
+        n_layer = int(name.split("-")[1])
+        mod, params = relay.testing.resnet_3d.get_workload(
+            num_layers=n_layer,
+            batch_size=batch_size,
+            layout=layout,
+            dtype=dtype,
+            image_shape=image_shape,
+        )
+    elif name == "mobilenet":
+        mod, params = relay.testing.mobilenet.get_workload(
+            batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape
+        )
+    elif name == "squeezenet_v1.1":
+        assert layout == "NCHW", "squeezenet_v1.1 only supports NCHW layout"
+        mod, params = relay.testing.squeezenet.get_workload(
+            version="1.1",
+            batch_size=batch_size,
+            dtype=dtype,
+            image_shape=image_shape,
+        )
+    elif name == "inception_v3":
+        input_shape = (batch_size, 3, 299, 299) if layout == "NCHW" else (batch_size, 299, 299, 3)
+        mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
+    elif name == "mxnet":
+        # an example for mxnet model
+        from mxnet.gluon.model_zoo.vision import get_model
+
+        assert layout == "NCHW"
+
+        block = get_model("resnet50_v1", pretrained=True)
+        mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype)
+        net = mod["main"]
+        net = relay.Function(
+            net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
+        )
+        mod = tvm.IRModule.from_expr(net)
+
+    return mod, params, input_shape, output_shape
+
+
+# Define the neural network and compilation target.
+network = "mobilenet"
+batch_size = 1
+layout = "NHWC"
+# replace this with the device key in your tracker
+device_key = "rk3399"
+# Set this to True if you use ndk tools for cross compiling
+use_ndk = True
+# Path to cross compiler
+os.environ["TVM_NDK_CC"] = "/usr/bin/aarch64-linux-gnu-g++"
+target_host = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu")
+target = tvm.target.Target("opencl -device=mali")
+dtype = "float32"
+log_file = "%s-%s-B%d-%s.json" % (network, layout, batch_size, target.kind.name)
+
+#################################################################
+# Extract Search Tasks
+# --------------------
+# Next, we extract the search tasks and their weights from a network.
+# The weight of a task is the number of appearances of the task's subgraph
+# in the whole network.
+# By using the weight, we can approximate the end-to-end latency of the network
+# as :code:`sum(latency[t] * weight[t])`, where :code:`latency[t]` is the
+# latency of a task and :code:`weight[t]` is the weight of the task.
+# The task scheduler will just optimize this objective.
+
+# Extract tasks from the network
+print("Extract tasks...")
+mod, params, input_shape, output_shape = get_network(network, batch_size, layout, dtype=dtype)
+tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target, target_host)
+
+for idx, task in enumerate(tasks):
+    print("========== Task %d  (workload key: %s) ==========" % (idx, task.workload_key))
+    print(task.compute_dag)
+
+#################################################################
+# Tune and Evaluate
+# -----------------
+# Now we set some options for tuning, launch the search tasks, and evaluate the end-to-end performance.
+#
+# * :code:`num_measure_trials` is the number of measurement trials we can use during the tuning.
+#   You can set it to a small number (e.g., 200) for a fast demonstrative run.
+#   In practice, we recommend setting it around :code:`800 * len(tasks)`,
+#   which is typically enough for the search to converge.
+#   For example, there are 29 tasks in resnet-50, so we can set it as 20000.
+#   You can adjust this parameter according to your time budget.
+# * In addition, we use :code:`RecordToFile` to dump measurement records into a log file.
+#   The measurement records can be used to query the history best, resume the search,
+#   and do more analyses later.
+# * See :any:`auto_scheduler.TuningOptions` and
+#   :any:`auto_scheduler.LocalRunner` for more parameters.
+#
+
+
+def tune_and_evaluate():
+    print("Begin tuning...")
+    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
+    tune_option = auto_scheduler.TuningOptions(

Review comment:
       As mentioned in the other comment, maybe we can have a note or a small section about how to identify the right hardware parameters and how to pass them to auto_scheduler.
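    For instance, the tutorial could pass them along these lines — a sketch, assuming the Python `HardwareParams` binding mirrors the 8-argument C++ constructor shown in `search_task.cc`, with placeholder numbers to be replaced by values queried from the board:
    ```python
    from tvm import auto_scheduler

    hardware_params = auto_scheduler.HardwareParams(
        -1,  # num_cores (unused for the GPU search)
        16,  # vector_unit_bytes
        64,  # cache_line_bytes
        32768,  # max_shared_memory_per_block (placeholder)
        2097152,  # max_local_memory_per_block (placeholder)
        256,  # max_threads_per_block (placeholder)
        64,  # max_vthread_extent (placeholder)
        1,  # warp_size (placeholder)
    )
    tasks, task_weights = auto_scheduler.extract_tasks(
        mod["main"], params, target, target_host, hardware_params=hardware_params
    )
    ```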



