Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/12/16 02:11:08 UTC

[GitHub] [tvm] kevinthesun opened a new pull request #7117: [TOPI] Fix GPU Dynamic Op Schedule

kevinthesun opened a new pull request #7117:
URL: https://github.com/apache/tvm/pull/7117


   This PR limits the resources used by dynamic-shape GPU kernels to avoid runtime errors. It also skips ```CallPacked``` in the VM if a kernel has only one output and that output is empty, e.g. of shape (1, 0, 6).
   
   After this PR, TF and PyTorch object detection models should be runnable on NVIDIA GPUs.
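
   As a rough, illustrative sketch of the empty-output case described above (not the actual VM change, which lives in the C++ runtime): an output shape such as (1, 0, 6) contains a zero extent and therefore has zero elements, so there is nothing for the kernel to compute and the packed call can be skipped.

   ```python
   import numpy as np

   def output_is_empty(shape):
       # A shape with any zero extent, e.g. (1, 0, 6), describes a tensor
       # with zero elements, so the kernel invocation can be skipped.
       return int(np.prod(shape)) == 0

   print(output_is_empty((1, 0, 6)))  # True
   print(output_is_empty((1, 5, 6)))  # False
   ```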
   
   @zhiics @Laurawly @mbrookhart 
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] mbrookhart commented on a change in pull request #7117: [TOPI] Fix GPU Dynamic Op Schedule

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #7117:
URL: https://github.com/apache/tvm/pull/7117#discussion_r544451189



##########
File path: python/tvm/topi/cuda/nms.py
##########
@@ -95,23 +94,23 @@ def rearrange_indices_out_ir(data, output, valid_box_count):
     with ib.new_scope():
         i = te.thread_axis("blockIdx.x")
         ib.scope_attr(i, "thread_extent", batch_size)
-        valid_idx = ib.allocate("int32", (1), name="valid_idx", scope="local")
-        valid_idx[0] = 0
+        valid_idx = ib.allocate("int32", (batch_size,), name="valid_idx", scope="local")

Review comment:
       We can't allocate a buffer with a dynamic shape; that's why this test is failing. I'm not sure I understand why this change is needed. Since we're threading over the batch size, a thread-local variable of size 1 effectively acts as a scratch pad of size batch_size distributed over the threads.
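
       A minimal runnable sketch of the pattern under discussion (illustrative only, not the nms.py code itself): with blockIdx.x bound to a thread extent of batch_size, a size-1 allocation in "local" scope is private to each thread, so it already behaves like one scratch slot per batch element; allocating (batch_size,) instead requires a constant size, which is exactly what fails for dynamic shapes.

   ```python
   import tvm
   from tvm import te

   def per_thread_scratch(batch_size):
       # One thread per batch element; each thread gets a private size-1
       # int32 scratch slot in "local" scope.
       ib = tvm.tir.ir_builder.create()
       bx = te.thread_axis("blockIdx.x")
       ib.scope_attr(bx, "thread_extent", batch_size)
       valid_idx = ib.allocate("int32", (1,), name="valid_idx", scope="local")
       valid_idx[0] = 0
       return ib.get()

   print(per_thread_scratch(4))
   ```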







[GitHub] [tvm] kevinthesun commented on a change in pull request #7117: [TOPI] Fix GPU Dynamic Op Schedule

Posted by GitBox <gi...@apache.org>.
kevinthesun commented on a change in pull request #7117:
URL: https://github.com/apache/tvm/pull/7117#discussion_r544558368



##########
File path: python/tvm/topi/cuda/nms.py
##########
@@ -754,7 +782,22 @@ def non_max_suppression(
     )
     score_axis = score_index
     score_shape = (batch_size, num_anchors)
-    score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
+    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)

Review comment:
       When the NMS workload is large, as in RCNN models, the general CUDA injective schedule can still cause runtime errors even with the improvements in this PR. In general, any dynamic injective op can hit runtime issues with the current uniform CUDA injective schedule.
   
   This problem is not directly related to NMS but to the CUDA injective schedule. Later we may need to revisit this part for GPU dynamic ops and come up with a better, more general solution (together with more tests).
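
   For reference, a rough sketch of what a generic CUDA injective schedule does (a simplified stand-in, not the actual TOPI implementation): all axes are fused, split by a fixed thread-block size, and bound to blockIdx.x / threadIdx.x. With a dynamic extent the resulting grid size is only known at run time, which is how a large workload can exceed device limits and hit the runtime errors mentioned above.

   ```python
   import tvm
   from tvm import te

   n = te.var("n")                           # dynamic extent
   data = te.placeholder((n,), name="data")
   out = te.compute((n,), lambda i: data[i] + 1.0, name="out")

   s = te.create_schedule(out.op)
   fused = s[out].fuse(*s[out].op.axis)
   num_thread = 1024                         # illustrative block size
   bx, tx = s[out].split(fused, factor=num_thread)
   s[out].bind(bx, te.thread_axis("blockIdx.x"))
   s[out].bind(tx, te.thread_axis("threadIdx.x"))

   print(tvm.lower(s, [data, out], simple_mode=True))
   ```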







[GitHub] [tvm] mbrookhart commented on a change in pull request #7117: [TOPI] Fix GPU Dynamic Op Schedule

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #7117:
URL: https://github.com/apache/tvm/pull/7117#discussion_r544452199



##########
File path: tests/python/relay/test_any.py
##########
@@ -1430,6 +1439,21 @@ def test_non_max_suppression():
         disable_targets=["nvptx"],
     )
 
+    np_data = np.zeros((1, 0, 6)).astype("float32")
+    np_valid_count = np.array([0]).astype("int32")
+    np_indices = np.zeros((1, 0)).astype("int32")
+    np_max_output_size = -1
+    np_indices_result = np.zeros((1, 0))
+    np_valid_box_count = np.array([[0]]).astype("int32")
+
+    check_result(
+        [np_data, np_valid_count, np_indices, np_max_output_size],
+        mod,
+        [np_indices_result, np_valid_box_count],
+        only_vm=False,
+        disable_targets=["nvptx"],

Review comment:
       This tests the empty output VM change :+1: 
   Why disable nvptx?

##########
File path: tests/python/relay/test_any.py
##########
@@ -199,6 +199,15 @@ def test_any_concat():
     ref = np.concatenate([x_np - 3.0, y_np * 5.0], axis=0)
     check_result([x_np, y_np], mod, ref)
 
+    num_inputs = 25
+    x = [relay.var("x", shape=(relay.Any(),), dtype="float32") for _ in range(num_inputs)]
+    z = relay.op.concatenate(x, axis=0)

Review comment:
       This tests the injective schedule :+1: 

##########
File path: python/tvm/topi/cuda/nms.py
##########
@@ -754,7 +782,22 @@ def non_max_suppression(
     )
     score_axis = score_index
     score_shape = (batch_size, num_anchors)
-    score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
+    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)

Review comment:
       This looks fine, but I'm a little surprised it's necessary. Do you have a test case that breaks the current code, or is this mostly for performance?

##########
File path: python/tvm/topi/cuda/conv2d_transpose_nchw.py
##########
@@ -194,6 +197,8 @@ def _callback(op):
 
             if cfg.is_fallback:
                 N, F, Y, X = get_const_tuple(conv.shape)
+                if not isinstance(N, int):
+                    N = 1

Review comment:
       Can we add a test that hits this change?







[GitHub] [tvm] kevinthesun commented on a change in pull request #7117: [TOPI] Fix GPU Dynamic Op Schedule

Posted by GitBox <gi...@apache.org>.
kevinthesun commented on a change in pull request #7117:
URL: https://github.com/apache/tvm/pull/7117#discussion_r544551395



##########
File path: tests/python/relay/test_any.py
##########
@@ -1430,6 +1439,21 @@ def test_non_max_suppression():
         disable_targets=["nvptx"],
     )
 
+    np_data = np.zeros((1, 0, 6)).astype("float32")
+    np_valid_count = np.array([0]).astype("int32")
+    np_indices = np.zeros((1, 0)).astype("int32")
+    np_max_output_size = -1
+    np_indices_result = np.zeros((1, 0))
+    np_valid_box_count = np.array([[0]]).astype("int32")
+
+    check_result(
+        [np_data, np_valid_count, np_indices, np_max_output_size],
+        mod,
+        [np_indices_result, np_valid_box_count],
+        only_vm=False,
+        disable_targets=["nvptx"],

Review comment:
       There is an issue causing a segfault in dynamic NMS on nvptx, and in general we need Thrust for any dynamic-shape sorting. For now nvptx is not ready for these operations.
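
   As a side note, and purely as an assumption about how one might probe a build (not something this PR adds): whether the Thrust-backed sort is available can be checked by looking up the registered packed function.

   ```python
   import tvm

   def thrust_sort_available():
       # The registered function name is an assumption based on TVM's
       # thrust contrib module; returns True only in Thrust-enabled builds.
       return tvm.get_global_func("tvm.contrib.thrust.sort", allow_missing=True) is not None

   print("thrust sort available:", thrust_sort_available())
   ```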







[GitHub] [tvm] kevinthesun commented on pull request #7117: [TOPI] Fix GPU Dynamic Op Schedule

Posted by GitBox <gi...@apache.org>.
kevinthesun commented on pull request #7117:
URL: https://github.com/apache/tvm/pull/7117#issuecomment-747725450


   Thanks @mbrookhart 





[GitHub] [tvm] kevinthesun merged pull request #7117: [TOPI] Fix GPU Dynamic Op Schedule

Posted by GitBox <gi...@apache.org>.
kevinthesun merged pull request #7117:
URL: https://github.com/apache/tvm/pull/7117


   





[GitHub] [tvm] mbrookhart commented on a change in pull request #7117: [TOPI] Fix GPU Dynamic Op Schedule

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #7117:
URL: https://github.com/apache/tvm/pull/7117#discussion_r544553996



##########
File path: tests/python/relay/test_any.py
##########
@@ -1430,6 +1439,21 @@ def test_non_max_suppression():
         disable_targets=["nvptx"],
     )
 
+    np_data = np.zeros((1, 0, 6)).astype("float32")
+    np_valid_count = np.array([0]).astype("int32")
+    np_indices = np.zeros((1, 0)).astype("int32")
+    np_max_output_size = -1
+    np_indices_result = np.zeros((1, 0))
+    np_valid_box_count = np.array([[0]]).astype("int32")
+
+    check_result(
+        [np_data, np_valid_count, np_indices, np_max_output_size],
+        mod,
+        [np_indices_result, np_valid_box_count],
+        only_vm=False,
+        disable_targets=["nvptx"],

Review comment:
       Makes sense. I'm trying to fix the default sort kernel in #7099, if you want to take a look.







[GitHub] [tvm] kevinthesun commented on a change in pull request #7117: [TOPI] Fix GPU Dynamic Op Schedule

Posted by GitBox <gi...@apache.org>.
kevinthesun commented on a change in pull request #7117:
URL: https://github.com/apache/tvm/pull/7117#discussion_r544558622



##########
File path: python/tvm/topi/cuda/conv2d_transpose_nchw.py
##########
@@ -194,6 +197,8 @@ def _callback(op):
 
             if cfg.is_fallback:
                 N, F, Y, X = get_const_tuple(conv.shape)
+                if not isinstance(N, int):
+                    N = 1

Review comment:
       Yeah, we do have a test for this. I've now enabled all targets.



