Posted to commits@tvm.apache.org by wu...@apache.org on 2021/12/01 18:53:45 UTC

[tvm] branch main updated: Fix cuDNN call for NHWC layout (#9600)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 70de68a  Fix cuDNN call for NHWC layout (#9600)
70de68a is described below

commit 70de68a38f6738255cc71fd9a38964b23f2c5f55
Author: masahi <ma...@gmail.com>
AuthorDate: Thu Dec 2 03:53:20 2021 +0900

    Fix cuDNN call for NHWC layout (#9600)
---
 python/tvm/relay/op/strategy/cuda.py      | 35 ++++++++++-------
 src/runtime/contrib/cudnn/conv_forward.cc | 64 +++++++++++++++++++++++--------
 tests/python/contrib/test_cudnn.py        |  4 +-
 3 files changed, 70 insertions(+), 33 deletions(-)

diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 80f1fe1..7bc04b4 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -180,8 +180,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                 wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
                 name="conv2d_hwcn.cuda",
             )
-        elif layout == "NHWC":
-            assert kernel_layout == "HWIO"
+        elif layout == "NHWC" and kernel_layout == "HWIO":
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.gpu.conv2d_nhwc),
                 wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc),
@@ -304,19 +303,27 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                 wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
                 name="conv2d_NCHWc_int8.cuda",
             )
-        else:
+        elif target.kind.name == "cuda" and "cudnn" not in target.libs:
+            # No TVM native kernel applicable
             raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
-        # add cudnn implementation
-        if target.kind.name == "cuda" and "cudnn" in target.libs:
-            if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and padding[1] == padding[3]:
-                strategy.add_implementation(
-                    wrap_compute_conv2d(
-                        topi.cuda.conv2d_cudnn, need_data_layout=True, has_groups=True
-                    ),
-                    wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn),
-                    name="conv2d_cudnn.cuda",
-                    plevel=25,
-                )
+
+        if (
+            target.kind.name == "cuda"
+            and "cudnn" in target.libs
+            and layout in ["NCHW", "NHWC"]
+            and padding[0] == padding[2]
+            and padding[1] == padding[3]
+        ):
+            # add cudnn implementation
+            if layout == "NHWC":
+                assert kernel_layout == "OHWI"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.cuda.conv2d_cudnn, need_data_layout=True, has_groups=True),
+                wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn),
+                name="conv2d_cudnn.cuda",
+                plevel=25,
+            )
+
     elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
         if layout == "NCHW":
             assert kernel_layout == "OIHW"
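
For context, the hunk above means the cuDNN conv2d implementation for NHWC is now only
registered when the kernel layout is OHWI, padding is symmetric, and "cudnn" is in
target.libs. A minimal sketch of a Relay graph that would take this path, assuming a
cuDNN-enabled CUDA build of TVM; shapes and dtypes are illustrative:

    import tvm
    from tvm import relay

    # NHWC activations with OHWI weights -- the layouts the updated strategy asserts
    # for the cuDNN path.
    data = relay.var("data", shape=(1, 32, 32, 16), dtype="float32")
    weight = relay.var("weight", shape=(32, 3, 3, 16), dtype="float32")
    conv = relay.nn.conv2d(
        data,
        weight,
        padding=(1, 1, 1, 1),  # symmetric: padding[0] == padding[2] and padding[1] == padding[3]
        data_layout="NHWC",
        kernel_layout="OHWI",
    )
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

    # "-libs=cudnn" puts "cudnn" into target.libs, so conv2d_cudnn.cuda (plevel=25)
    # becomes a candidate implementation for the strategy above.
    target = tvm.target.Target("cuda -libs=cudnn")
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target)
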
diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc
index 2d7f826..9770498 100644
--- a/src/runtime/contrib/cudnn/conv_forward.cc
+++ b/src/runtime/contrib/cudnn/conv_forward.cc
@@ -157,7 +157,7 @@ void OutputShape(int format, int dims, int groups, const int pad[], const int st
                                              entry_ptr->conv_entry.data_type));
 
   if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
-    ICHECK_EQ(full_dims, 4) << "Use of layout CUDNN_TENSOR_NHWC is only defined for 4d tensors";
+    ICHECK_EQ(full_dims, 4) << "Use of layout CUDNN_TENSOR_NHWC is only supported for 4d tensors";
 
     // Set Input
     CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
@@ -206,23 +206,53 @@ void FindAlgo(int format, int dims, int groups, const int pad[], const int strid
 
   // conv desc
   CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
-  CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
-                                             dilation, CUDNN_CROSS_CORRELATION,
-                                             entry_ptr->conv_entry.data_type));
 
-  std::vector<int> tensor_stride(full_dims);
-  // input desc
-  GetCudnnStride(full_dims, x_dim, tensor_stride.data());
-  CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
-                                        x_dim, tensor_stride.data()));
-  // filter desc
-  CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
-                                        entry_ptr->conv_entry.tensor_format, full_dims, w_dim));
-
-  // output desc
-  GetCudnnStride(full_dims, y_dim, tensor_stride.data());
-  CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims,
-                                        y_dim, tensor_stride.data()));
+  if (format == 1) {
+    ICHECK_EQ(full_dims, 4) << "Use of layout CUDNN_TENSOR_NHWC is only supported for 4d tensors";
+    int ni = 0;
+    int ci = 3;
+    int hi = 1;
+    int wi = 2;
+
+    // Set Input
+    CUDNN_CALL(cudnnSetTensor4dDescriptor(
+        entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type,
+        static_cast<int>(x_dim[ni]), static_cast<int>(x_dim[ci]), static_cast<int>(x_dim[hi]),
+        static_cast<int>(x_dim[wi])));
+
+    CUDNN_CALL(cudnnSetFilter4dDescriptor(
+        entry_ptr->conv_entry.filter_desc, data_type, entry_ptr->conv_entry.tensor_format,
+        static_cast<int>(w_dim[ni]), static_cast<int>(w_dim[ci]), static_cast<int>(w_dim[hi]),
+        static_cast<int>(w_dim[wi])));
+    // Set Output
+    CUDNN_CALL(cudnnSetTensor4dDescriptor(
+        entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, data_type,
+        static_cast<int>(y_dim[ni]), static_cast<int>(y_dim[ci]), static_cast<int>(y_dim[hi]),
+        static_cast<int>(y_dim[wi])));
+
+    CUDNN_CALL(cudnnSetConvolution2dDescriptor(
+        entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0],
+        dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type));
+  } else {
+    CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
+                                               dilation, CUDNN_CROSS_CORRELATION,
+                                               entry_ptr->conv_entry.data_type));
+
+    std::vector<int> tensor_stride(full_dims);
+    // input desc
+    GetCudnnStride(full_dims, x_dim, tensor_stride.data());
+    CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
+                                          x_dim, tensor_stride.data()));
+    // filter desc
+    CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
+                                          entry_ptr->conv_entry.tensor_format, full_dims, w_dim));
+
+    // output desc
+    GetCudnnStride(full_dims, y_dim, tensor_stride.data());
+    CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims,
+                                          y_dim, tensor_stride.data()));
+  }
+
   if (cudnnGetVersion() > 7000) {
     CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH))
   }
diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py
index 839c6f5..7f504fc 100644
--- a/tests/python/contrib/test_cudnn.py
+++ b/tests/python/contrib/test_cudnn.py
@@ -102,8 +102,8 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1):
 def test_conv2d():
     verify_conv2d("float32", "float32", tensor_format=0)
     verify_conv2d("float16", "float32", tensor_format=1)
-    # This test is flaky, disable for now
-    # verify_conv2d("float16", "float16", tensor_format=0)
+    verify_conv2d("float16", "float16", tensor_format=0)
+    verify_conv2d("float16", "float16", tensor_format=1)
     verify_conv2d("int8", "int32", tensor_format=1)
 
     verify_conv2d("float32", "float32", tensor_format=0, groups=2)
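
With the flaky-test workaround removed, the re-enabled float16 accumulation case and the
new NHWC (tensor_format=1) case can be exercised locally on a cuDNN-enabled GPU build
with, for example:

    pytest tests/python/contrib/test_cudnn.py::test_conv2d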