You are viewing a plain-text version of this content. The canonical link to the original was not preserved in this copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/12/06 01:59:28 UTC

[tvm] branch main updated: [ROCm][Auto scheduler] Support Auto scheduler and NHWC convolution on ROCm (#7038)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3ec739c  [ROCm][Auto scheduler] Support Auto scheduler and NHWC convolution on ROCm (#7038)
3ec739c is described below

commit 3ec739c650b2b49c15528f301242e98f00f0493e
Author: masahi <ma...@gmail.com>
AuthorDate: Sun Dec 6 10:59:09 2020 +0900

    [ROCm][Auto scheduler] Support Auto scheduler and NHWC convolution on ROCm (#7038)
    
    * add nhwc + winograd support to rocm strategy
    
    * support rocm hw parameters in search task
    
    * run analysis pass for rocm too
    
    * run black
    
    * pylint fix
    
    * use IsGPUTask function
    
    Co-authored-by: Masahiro Masuda <ma...@gmail.com>
---
 python/tvm/relay/op/strategy/rocm.py | 52 +++++++++++++++++++++++++++++++-----
 src/auto_scheduler/feature.cc        |  3 ++-
 src/auto_scheduler/search_task.cc    | 13 +++++----
 3 files changed, 55 insertions(+), 13 deletions(-)

diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py
index f52bbc3..c52da54 100644
--- a/python/tvm/relay/op/strategy/rocm.py
+++ b/python/tvm/relay/op/strategy/rocm.py
@@ -17,8 +17,10 @@
 """Definition of ROCm operator strategy."""
 # pylint: disable=invalid-name,unused-argument,unused-wildcard-import,wildcard-import
 from tvm import topi
+from tvm.auto_scheduler import is_auto_scheduler_enabled
 from .generic import *
 from .. import op as _op
+from .cuda import judge_winograd, naive_schedule
 
 
 @schedule_lrn.register("rocm")
@@ -67,6 +69,49 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
                     name="conv2d_nchw_winograd.cuda",
                     plevel=5,
                 )
+        elif layout == "NHWC":
+            assert kernel_layout == "HWIO"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
+                wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
+                name="conv2d_nhwc.cuda",
+            )
+            N, H, W, _ = get_const_tuple(data.shape)
+            KH, KW, CI, CO = get_const_tuple(kernel.shape)
+
+            (_, judge_winograd_autotvm, judge_winograd_auto_scheduler,) = judge_winograd(
+                N,
+                H,
+                W,
+                KH,
+                KW,
+                CI,
+                CO,
+                padding,
+                stride_h,
+                stride_w,
+                dilation_h,
+                dilation_w,
+                data.dtype,
+                kernel.dtype,
+                pre_flag=False,
+            )
+
+            if judge_winograd_autotvm:
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_direct),
+                    wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_winograd_direct),
+                    name="conv2d_nhwc_winograd_direct.cuda",
+                    plevel=5,
+                )
+
+            if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
+                    naive_schedule,  # this implementation should never be picked by autotvm
+                    name="conv2d_nhwc.winograd",
+                    plevel=15,
+                )
         elif layout == "HWCN":
             assert kernel_layout == "HWIO"
             strategy.add_implementation(
@@ -74,13 +119,6 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
                 wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
                 name="conv2d_hwcn.cuda",
             )
-        # TODO(@alexgl-github): Re-enable this after fix the conv2d_nhwc for cuda
-        # elif layout == "NHWC":
-        #     assert kernel_layout == "HWIO"
-        #     strategy.add_implementation(
-        #         wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
-        #         wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
-        #         name="conv2d_nhwc.cuda")
         elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
             assert kernel_layout == "OIHW4o4i"
             strategy.add_implementation(
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index 0a3d705..53287a0 100755
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -41,6 +41,7 @@
 #include <unordered_map>
 #include <vector>
 
+#include "search_policy/utils.h"
 #include "utils.h"
 
 namespace tvm {
@@ -1296,7 +1297,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
     }
     auto mod = IRModule(Map<GlobalVar, BaseFunc>({{global_var, f}}));
 
-    if (task->target->kind->device_type == kDLGPU) {
+    if (IsGPUTask(task)) {
       auto pass_list = Array<tvm::transform::Pass>();
       // Phase 0
       pass_list.push_back(tir::transform::InjectPrefetch());
diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc
index 4c8cc6d..5a34755 100755
--- a/src/auto_scheduler/search_task.cc
+++ b/src/auto_scheduler/search_task.cc
@@ -22,6 +22,7 @@
  * \brief Meta information and hardware parameters for a search task.
  */
 
+#include <dlpack/dlpack.h>
 #include <tvm/auto_scheduler/search_task.h>
 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/registry.h>
@@ -52,11 +53,13 @@ HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_l
 
 HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target,
                                                             const Target& target_host) {
-  if (target->kind->device_type == kDLCPU) {
+  const auto device_type = target->kind->device_type;
+  if (device_type == kDLCPU) {
     return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64, 0, 0, 0, 0, 0);
-  } else if (target->kind->device_type == kDLGPU) {
-    auto ctx = TVMContext{kDLGPU, 0};
-    auto func = tvm::runtime::Registry::Get("device_api.gpu");
+  } else if (device_type == kDLGPU || device_type == kDLROCM) {
+    auto ctx = TVMContext{static_cast<DLDeviceType>(device_type), 0};
+    auto device_name = device_type == kDLGPU ? "device_api.gpu" : "device_api.rocm";
+    auto func = tvm::runtime::Registry::Get(device_name);
     ICHECK(func != nullptr) << "Cannot find GPU device_api in registry";
     auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());
 
@@ -77,7 +80,7 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
     int max_vthread_extent = warp_size / 4;
     return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
                           max_threads_per_block, max_vthread_extent, warp_size);
-  } else if (target->kind->device_type == kDLMetal) {
+  } else if (device_type == kDLMetal) {
     // Reference: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
     // This setting looks working for Metal GPUs later than A10
     int max_shared_memory_per_block = 32 * 1024;