You are viewing a plain text version of this content. The canonical link to the original HTML version is available in the mailing-list archive.
Posted to commits@tvm.apache.org by wu...@apache.org on 2021/12/01 18:53:45 UTC
[tvm] branch main updated: Fix cuDNN call for NHWC layout (#9600)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 70de68a Fix cuDNN call for NHWC layout (#9600)
70de68a is described below
commit 70de68a38f6738255cc71fd9a38964b23f2c5f55
Author: masahi <ma...@gmail.com>
AuthorDate: Thu Dec 2 03:53:20 2021 +0900
Fix cuDNN call for NHWC layout (#9600)
---
python/tvm/relay/op/strategy/cuda.py | 35 ++++++++++-------
src/runtime/contrib/cudnn/conv_forward.cc | 64 +++++++++++++++++++++++--------
tests/python/contrib/test_cudnn.py | 4 +-
3 files changed, 70 insertions(+), 33 deletions(-)
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 80f1fe1..7bc04b4 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -180,8 +180,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
name="conv2d_hwcn.cuda",
)
- elif layout == "NHWC":
- assert kernel_layout == "HWIO"
+ elif layout == "NHWC" and kernel_layout == "HWIO":
strategy.add_implementation(
wrap_compute_conv2d(topi.gpu.conv2d_nhwc),
wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc),
@@ -304,19 +303,27 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
name="conv2d_NCHWc_int8.cuda",
)
- else:
+ elif target.kind.name == "cuda" and "cudnn" not in target.libs:
+ # No TVM native kernel applicable
raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
- # add cudnn implementation
- if target.kind.name == "cuda" and "cudnn" in target.libs:
- if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and padding[1] == padding[3]:
- strategy.add_implementation(
- wrap_compute_conv2d(
- topi.cuda.conv2d_cudnn, need_data_layout=True, has_groups=True
- ),
- wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn),
- name="conv2d_cudnn.cuda",
- plevel=25,
- )
+
+ if (
+ target.kind.name == "cuda"
+ and "cudnn" in target.libs
+ and layout in ["NCHW", "NHWC"]
+ and padding[0] == padding[2]
+ and padding[1] == padding[3]
+ ):
+ # add cudnn implementation
+ if layout == "NHWC":
+ assert kernel_layout == "OHWI"
+ strategy.add_implementation(
+ wrap_compute_conv2d(topi.cuda.conv2d_cudnn, need_data_layout=True, has_groups=True),
+ wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn),
+ name="conv2d_cudnn.cuda",
+ plevel=25,
+ )
+
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
if layout == "NCHW":
assert kernel_layout == "OIHW"
diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc
index 2d7f826..9770498 100644
--- a/src/runtime/contrib/cudnn/conv_forward.cc
+++ b/src/runtime/contrib/cudnn/conv_forward.cc
@@ -157,7 +157,7 @@ void OutputShape(int format, int dims, int groups, const int pad[], const int st
entry_ptr->conv_entry.data_type));
if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
- ICHECK_EQ(full_dims, 4) << "Use of layout CUDNN_TENSOR_NHWC is only defined for 4d tensors";
+ ICHECK_EQ(full_dims, 4) << "Use of layout CUDNN_TENSOR_NHWC is only supported for 4d tensors";
// Set Input
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
@@ -206,23 +206,53 @@ void FindAlgo(int format, int dims, int groups, const int pad[], const int strid
// conv desc
CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
- CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
- dilation, CUDNN_CROSS_CORRELATION,
- entry_ptr->conv_entry.data_type));
- std::vector<int> tensor_stride(full_dims);
- // input desc
- GetCudnnStride(full_dims, x_dim, tensor_stride.data());
- CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
- x_dim, tensor_stride.data()));
- // filter desc
- CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
- entry_ptr->conv_entry.tensor_format, full_dims, w_dim));
-
- // output desc
- GetCudnnStride(full_dims, y_dim, tensor_stride.data());
- CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims,
- y_dim, tensor_stride.data()));
+ if (format == 1) {
+ ICHECK_EQ(full_dims, 4) << "Use of layout CUDNN_TENSOR_NHWC is only supported for 4d tensors";
+ int ni = 0;
+ int ci = 3;
+ int hi = 1;
+ int wi = 2;
+
+ // Set Input
+ CUDNN_CALL(cudnnSetTensor4dDescriptor(
+ entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type,
+ static_cast<int>(x_dim[ni]), static_cast<int>(x_dim[ci]), static_cast<int>(x_dim[hi]),
+ static_cast<int>(x_dim[wi])));
+
+ CUDNN_CALL(cudnnSetFilter4dDescriptor(
+ entry_ptr->conv_entry.filter_desc, data_type, entry_ptr->conv_entry.tensor_format,
+ static_cast<int>(w_dim[ni]), static_cast<int>(w_dim[ci]), static_cast<int>(w_dim[hi]),
+ static_cast<int>(w_dim[wi])));
+ // Set Output
+ CUDNN_CALL(cudnnSetTensor4dDescriptor(
+ entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, data_type,
+ static_cast<int>(y_dim[ni]), static_cast<int>(y_dim[ci]), static_cast<int>(y_dim[hi]),
+ static_cast<int>(y_dim[wi])));
+
+ CUDNN_CALL(cudnnSetConvolution2dDescriptor(
+ entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0],
+ dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type));
+ } else {
+ CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
+ dilation, CUDNN_CROSS_CORRELATION,
+ entry_ptr->conv_entry.data_type));
+
+ std::vector<int> tensor_stride(full_dims);
+ // input desc
+ GetCudnnStride(full_dims, x_dim, tensor_stride.data());
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
+ x_dim, tensor_stride.data()));
+ // filter desc
+ CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
+ entry_ptr->conv_entry.tensor_format, full_dims, w_dim));
+
+ // output desc
+ GetCudnnStride(full_dims, y_dim, tensor_stride.data());
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims,
+ y_dim, tensor_stride.data()));
+ }
+
if (cudnnGetVersion() > 7000) {
CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH))
}
diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py
index 839c6f5..7f504fc 100644
--- a/tests/python/contrib/test_cudnn.py
+++ b/tests/python/contrib/test_cudnn.py
@@ -102,8 +102,8 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1):
def test_conv2d():
verify_conv2d("float32", "float32", tensor_format=0)
verify_conv2d("float16", "float32", tensor_format=1)
- # This test is flaky, disable for now
- # verify_conv2d("float16", "float16", tensor_format=0)
+ verify_conv2d("float16", "float16", tensor_format=0)
+ verify_conv2d("float16", "float16", tensor_format=1)
verify_conv2d("int8", "int32", tensor_format=1)
verify_conv2d("float32", "float32", tensor_format=0, groups=2)