You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/09/18 17:38:27 UTC

[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6512: [ANSOR] Auto-scheduler tutorial for GPU and necessary refactor/fix

comaniac commented on a change in pull request #6512:
URL: https://github.com/apache/incubator-tvm/pull/6512#discussion_r491089637



##########
File path: tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
##########
@@ -0,0 +1,190 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+.. _auto-scheduler-conv-gpu:
+
+Auto-scheduling a convolution layer for GPU
+=============================================

Review comment:
       ```suggestion
   ===========================================
   ```

##########
File path: src/auto_scheduler/search_policy/sketch_policy.cc
##########
@@ -390,135 +385,102 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_popul
   Array<State>* pnow = &states_buf1;
   Array<State>* pnext = &states_buf2;
 
-  // The set of explored states to avoid redundancy.
-  std::unordered_set<std::string> explored_set;
-
-  // The heap to maintain the so far best states.
+  // A heap to keep the best states during evolution
   using StateHeapItem = std::pair<State, float>;
   auto cmp = [](const StateHeapItem& left, const StateHeapItem& right) {
     return left.second > right.second;
   };
-  using StateHeap = std::priority_queue<StateHeapItem, std::vector<StateHeapItem>, decltype(cmp)>;
-  StateHeap heap(cmp);
-  auto update_heap = [&heap, &explored_set](const Array<State>& states,
-                                            const std::vector<float>& scores, const int out_size) {
-    float max_score = 0.0;
-    for (size_t i = 0; i < states.size(); ++i) {
-      const State& state = states[i];
+  std::vector<StateHeapItem> heap;
+  std::unordered_set<std::string> in_heap(measured_states_set_);
+  heap.reserve(out_size);
+
+  // auxiliary global variables
+  std::vector<float> pop_scores;
+  std::vector<double> pop_selection_probs;
+  float max_score = 0.0;
+  pop_scores.reserve(population);
+  pop_selection_probs.reserve(population);
+  std::uniform_real_distribution<> dis(0.0, 1.0);
+
+  // mutation rules
+  int mutation_success_ct, mutation_fail_ct;
+  mutation_success_ct = mutation_fail_ct = 0;
+  std::vector<float> rule_weights;
+  std::vector<double> rule_selection_probs;
+  for (const auto& rule : mutation_rules) {
+    rule_weights.push_back(rule->weight);
+  }
+  ComputePrefixSumProb(rule_weights, &rule_selection_probs);
+
+  // Genetic Algorithm
+  for (int k = 0; k < num_iters + 1; ++k) {
+    // Maintain the heap
+    *pnow = search_task->compute_dag.InferBound(*pnow);
+    PruneInvalidState(search_task, pnow);

Review comment:
       I moved this part to the end of each GA iteration to handle the case where all generated states are pruned by this call. What happens if all states are invalid and `pnow` becomes empty? Should we simply error out and stop tuning, since that most likely indicates a problem?

##########
File path: tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
##########
@@ -0,0 +1,190 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+.. _auto-scheduler-conv-gpu:
+
+Auto-scheduling a convolution layer for GPU
+=============================================
+**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_, \
+            `Chengfan Jia <https://github.com/jcf94/>`_
+
+
+Different from the existing :ref:`autotvm <tutorials-autotvm-sec>` which relies on 
+manual templates to define the search space, the auto-scheduler does not require any templates.
+The auto-scheduler is template-free, so users only need to write the computation declaration without
+any schedule commands or templates.
+The auto-scheduler can automatically generate a large
+search space and find a good schedule in the space.
+
+We use a convolution layer as an example in this tutorial.
+"""
+
+import numpy as np
+import tvm
+from tvm import te, testing, auto_scheduler, topi
+from tvm.topi.testing import conv2d_nchw_python
+
+######################################################################
+# Define the computation
+# ^^^^^^^^^^^^^^^^^^^^^^
+# To begin with, let us define the computation of a convolution layer.
+# The function should return the list of input/output tensors.
+# From these tensors, the auto-scheduler can get the whole computational graph.
+
+
+@auto_scheduler.register_workload
+def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding):
+    data = te.placeholder((N, CI, H, W), name="data")
+    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
+    bias = te.placeholder((1, CO, 1, 1), name="bias")
+    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype="float32")
+    out = topi.nn.relu(conv + bias)
+    return [data, kernel, bias, out]
+
+
+######################################################################
+# Create the search task
+# ^^^^^^^^^^^^^^^^^^^^^^
+# We then create a search task for the last convolution layer in the resnet.
+
+target = tvm.target.Target("cuda")
+
+# the last layer in resnet
+N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
+task = auto_scheduler.create_task(conv2d_layer, (N, H, W, CO, CI, KH, KW, strides, padding), target)
+
+# Inspect the computational graph
+print(task.compute_dag)
+
+######################################################################
+# Next, we set parameters for the auto-scheduler. These parameters
+# mainly specify how we do the measurement during the search and auto-tuning.
+#
+# * `measure_ctx` launches a different process for measurement. This
+#   provides an isolation. It can protect the master process from GPU crashes
+#   happended during measurement and avoid other runtime conflicts.
+# * `min_repeat_ms` defines the minimum duration of one "repeat" in every measurement.
+#   This can warmup the GPU, which is necessary to get accurate measurement results.

Review comment:
       It would be better to mention the recommended value for GPUs (i.e., >300 ms).

##########
File path: tutorials/auto_scheduler/tune_matmul_x86.py
##########
@@ -161,13 +178,16 @@ def resume_search(task, log_file):
 # .. note::
 #   We cannot run the line above because of the conflict between
 #   python's multiprocessing and tvm's thread pool.
-#   After running a tvm generated binary (L112), the python's multiprocessing
-#   library will hang forever.
-#   You have to make sure that you don't run any tvm generated binaries before
-#   calling ansor's search. To run the L156 above, you should comment out L112-114.
+#   After running a tvm generated binary the python's multiprocessing library
+#   will hang forever. You have to make sure that you don't run any tvm
+#   generated binaries before calling auot-scheduler's search.
+#   To run the function above, you should comment out all code in
+#   "Check correctness and evaluate performance" section.
 #
 #   You should be careful about this problem in your applications.
 #   There are other workarounds for this problem.
 #   For example, you can start a new thread/process (with the builtin python library
 #   threading or multiprocessing) and run the tvm binaries in the new thread/process.
 #   This provides an isolation and avoids the conflict in the main thread/process.
+#   You can also use :any:`auto_scheduler.measure.LocalRPCMeasureContext` for auto-scheduler,
+#   as shown in the GPU tutorial (:ref:`auto-scheduler-conv-gpu`).

Review comment:
       Intuitively, if there is no obvious performance impact, we should use the RPC runner on both CPU and GPU, so it would be better to mention why we don't use it in this tutorial.

##########
File path: tutorials/auto_scheduler/tune_matmul_x86.py
##########
@@ -59,6 +59,9 @@ def matmul_add(N, L, M, dtype):
 # Create the search task
 # ^^^^^^^^^^^^^^^^^^^^^^
 # We then create a search task with N=L=M=128 and dtype="float32"
+# If your machine supports avx instructions, you can
+# - replace "llvm" below with "llvm -mcpu=core-avx2" to enable AVX2
+# - replace "llvm" belwo with "llvm -mcpu=skylake-avx512" to enable AVX-512

Review comment:
       ```suggestion
   # - replace "llvm" below with "llvm -mcpu=skylake-avx512" to enable AVX-512
   ```

##########
File path: tutorials/auto_scheduler/tune_matmul_x86.py
##########
@@ -93,25 +96,38 @@ def matmul_add(N, L, M, dtype):
 ######################################################################
 # We can lower the schedule to see the IR after auto-scheduling.
 # The auto-scheduler correctly performs optimizations including multi-level tiling,
-# parallelization, vectorization, unrolling and fusion.
+# parallelization, vectorization, unrolling and operator fusion.
 
 print(tvm.lower(sch, args, simple_mode=True))
 
 ######################################################################
-# Check correctness
-# ^^^^^^^^^^^^^^^^^
-# We build the binary and check its correctness
+# Check correctness and evaluate performance
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# We build the binary and check its correctness and performance.
 
 func = tvm.build(sch, args)
 a_np = np.random.uniform(size=(128, 128)).astype(np.float32)
 b_np = np.random.uniform(size=(128, 128)).astype(np.float32)
 c_np = np.random.uniform(size=(128, 128)).astype(np.float32)
-d_np = a_np.dot(b_np) + c_np
-
-d_tvm = tvm.nd.empty(d_np.shape)
-func(tvm.nd.array(a_np), tvm.nd.array(b_np), tvm.nd.array(c_np), d_tvm)
-
-tvm.testing.assert_allclose(d_np, d_tvm.asnumpy(), rtol=1e-3)
+out_np = a_np.dot(b_np) + c_np
+
+ctx = tvm.cpu()
+a_tvm = tvm.nd.array(a_np, ctx=ctx)
+b_tvm = tvm.nd.array(b_np, ctx=ctx)
+c_tvm = tvm.nd.array(c_np, ctx=ctx)
+out_tvm = tvm.nd.empty(out_np.shape, ctx=ctx)
+func(a_tvm, b_tvm, c_tvm, out_tvm)
+
+# Check results
+tvm.testing.assert_allclose(out_np, out_tvm.asnumpy(), rtol=1e-3)
+
+# Evaluate execution time.
+evaluator = func.time_evaluator(func.entry_name, ctx, min_repeat_ms=500)
+print(
+    "Execution time of this operator: %.3f ms"
+    % (evaluator(a_tvm, b_tvm, c_tvm, out_tvm).mean * 1000)

Review comment:
       Should we use the median instead of the mean, as agreed in our offline discussion (same as in the GPU tutorial)?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org