Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/04/26 22:16:44 UTC

[GitHub] [tvm] mbrookhart opened a new pull request #7927: [TOPI][RELAY][ONNX] Scatter ND

mbrookhart opened a new pull request #7927:
URL: https://github.com/apache/tvm/pull/7927


   This PR refactors the Relay scatter_nd op in two ways:
   1) scatter_nd now takes a data tensor, and the updates are scattered into a copy of it at the positions given by indices. The previous API took an output shape instead and assumed the output was zero-initialized.
   2) scatter_nd now has a "mode" argument that determines how the input data is updated: either "add" or "update". This supports its use as the gradient of gather_nd, as well as the ONNX and PyTorch APIs. TensorFlow also supports "max", "min", and "sub", which are not yet supported here.
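   To make the two modes concrete, here is a minimal NumPy sketch of the semantics described above. It assumes the index layout used by the CUDA kernel in this PR (indices of shape (M, Y_0, ..., Y_{K-1}), updates of shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})); scatter_nd_ref and gather_nd_ref are hypothetical helper names, not TVM APIs.

```python
import numpy as np


def scatter_nd_ref(data, indices, updates, mode="update"):
    """Reference sketch of scatter_nd semantics; not TVM's implementation.

    Assumed layout:
      data:    (X_0, ..., X_{N-1})
      indices: (M, Y_0, ..., Y_{K-1}); indices[:, y] addresses data[x_0, ..., x_{M-1}]
      updates: (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})
    """
    out = np.array(data, copy=True)        # the result starts from data, not from zeros
    idx = tuple(np.asarray(indices))       # M index arrays, each of shape (Y_0, ..., Y_{K-1})
    if mode == "update":
        out[idx] = updates                 # overwrite; duplicate indices are not meaningful here
    elif mode == "add":
        np.add.at(out, idx, updates)       # unbuffered add, so duplicate indices accumulate
    else:
        raise NotImplementedError(
            f"scatter_nd mode {mode!r} not supported; expected 'update' or 'add'"
        )
    return out


def gather_nd_ref(data, indices):
    """Matching gather_nd sketch under the same assumed layout."""
    return np.asarray(data)[tuple(np.asarray(indices))]


data = np.arange(5)                        # [0, 1, 2, 3, 4]
indices = np.array([[1, 3]])               # M = 1, K = 1: positions 1 and 3
updates = np.array([10, 20])
print(scatter_nd_ref(data, indices, updates, "update"))  # [ 0 10  2 20  4]
print(scatter_nd_ref(data, indices, updates, "add"))     # [ 0 11  2 23  4]
```

   With mode="add" and zero-initialized data, this sketch is the adjoint of gather_nd_ref: for any x and g, sum(gather_nd_ref(x, i) * g) equals sum(x * scatter_nd_ref(zeros_like(x), i, g, "add")). That is why "add", not "update", is the mode a gather_nd gradient needs when indices repeat.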
   
   This also adds ScatterND to the ONNX importer.
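   ONNX's ScatterND puts each index tuple on the last axis of indices (shape (..., M)), whereas the Relay op in this PR, as read from its CUDA kernel, expects it on the first axis (shape (M, ...)). A minimal sketch of the conversion an importer would need, assuming that layout difference, is to move the last axis of indices to the front and use "update" mode, since the ScatterND versions current at the time (opsets 11 and 13) have no reduction attribute. The helper names below are hypothetical; this is not the importer's actual code.

```python
import numpy as np


def onnx_scatter_nd(data, indices, updates):
    """Transcription of the ONNX ScatterND reference pseudocode (no reduction)."""
    output = np.copy(data)
    for idx in np.ndindex(indices.shape[:-1]):
        output[tuple(indices[idx])] = updates[idx]
    return output


def onnx_indices_to_relay_layout(indices):
    """Hypothetical helper: move the index-tuple axis from last (ONNX) to first (assumed Relay layout)."""
    return np.moveaxis(indices, -1, 0)


def relay_style_scatter_nd_update(data, indices_m_first, updates):
    """'update' mode with indices of shape (M, Y_0, ..., Y_{K-1}); sketch only."""
    out = np.copy(data)
    out[tuple(indices_m_first)] = updates
    return out


# Example 1 from the ONNX ScatterND operator documentation.
data = np.array([1, 2, 3, 4, 5, 6, 7, 8])
indices = np.array([[4], [3], [1], [7]])   # shape (4, 1): four index tuples of length M = 1
updates = np.array([9, 10, 11, 12])

expected = onnx_scatter_nd(data, indices, updates)
converted = relay_style_scatter_nd_update(data, onnx_indices_to_relay_layout(indices), updates)
print(expected)                             # [ 1 11  3 10  9  6  7 12]
print(np.array_equal(expected, converted))  # True
```

   Moving one axis is all the conversion needs because the updates tensor already has the same layout in both conventions: the batch dimensions (Y_0, ..., Y_{K-1}) first, followed by the trailing slice dimensions.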
   
   cc @tkonolige @altanh @jwfromm @masahi 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] jcf94 merged pull request #7927: [TOPI][RELAY][ONNX] Scatter ND

jcf94 merged pull request #7927:
URL: https://github.com/apache/tvm/pull/7927


   





[GitHub] [tvm] tkonolige commented on a change in pull request #7927: [TOPI][RELAY][ONNX] Scatter ND

tkonolige commented on a change in pull request #7927:
URL: https://github.com/apache/tvm/pull/7927#discussion_r620690363



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -723,7 +723,7 @@ def update_func(dst_ptr, dst_index, update):
     return out
 
 
-def scatter_nd(data, indices, shape):
+def scatter_nd(data, indices, updates, mode):
     """Scatter elements from a n-dimension array.
 
     Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape

Review comment:
       Can you update this?

##########
File path: python/tvm/topi/x86/scatter.py
##########
@@ -46,62 +46,69 @@ def scatter_nd(data, indices, shape):
     indices : tvm.te.Tensor
         The indices of the values to extract.
 
-    shape : Sequence[int]
-        The output shape. This must be specified because it cannot be inferred.
+    updates : tvm.te.Tensor
+        The updates to apply at the Indices
+
+    mode : string
+        The update mode for the algorith, either "update" or "add"

Review comment:
       ```suggestion
           The update mode for the algorithm, either "update" or "add"
   ```

##########
File path: python/tvm/topi/scatter.py
##########
@@ -248,29 +248,31 @@ def scatter_nd(data, indices, shape):
     indices : tvm.te.Tensor
         The indices of the values to extract.
 
-    shape : Sequence[int]
-        The output shape. This must be specified because it cannot be inferred.
+    updates : tvm.te.Tensor
+        The updates to apply at the Indices
+
+    mode : string
+        The update mode for the algorith, either "update" or "add"

Review comment:
       What do the different modes do?

##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -789,38 +790,41 @@ def gen_ir(data_ptr, indices_ptr, out_ptr):
         bx = te.thread_axis("blockIdx.x")
         tx = te.thread_axis("threadIdx.x")
         max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-        tdim = min(max_threads, fused_data_dimension)
+        tdim = min(max_threads, fused_updates_dimension)
         ib.scope_attr(tx, "thread_extent", tdim)
-        bdim = ceil_div(fused_data_dimension, tdim)
+        bdim = ceil_div(fused_updates_dimension, tdim)
         ib.scope_attr(bx, "thread_extent", bdim)
 
-        # zero data
-        # TODO(tkonolige): could we use topi.full to zero it instead?
         with ib.for_range(0, ceil_div(fused_shape, bdim)) as i:
-            index = i * fused_data_dimension + bx * tdim + tx
+            index = i * fused_updates_dimension + bx * tdim + tx
             with ib.if_scope(index < fused_shape):
-                out[index] = tvm.tir.Cast(data_ptr.dtype, 0)
+                out[index] = data[index]
 
         with ib.for_range(0, fused_indices_dimension) as i:
             j = bx * tdim + tx
-            with ib.if_scope(j < fused_data_dimension):
-                offset = fused_data_dimension
+            with ib.if_scope(j < fused_updates_dimension):
+                offset = fused_updates_dimension
                 index = j  # This is x_M, .. x_{N-1} part of the index into out.
                 # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
                 # of the index into out.
                 for l in reversed(range(indices_ptr.shape[0].value)):
                     # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
                     index += offset * indices[i + l * fused_indices_dimension]
-                    offset *= shape[l]
-                out[index] += data[i * fused_data_dimension + j]
+                    offset *= data_ptr.shape[l]
+                if mode == "update":
+                    out[index] = updates[i * fused_updates_dimension + j]
+                elif mode == "add":
+                    out[index] += updates[i * fused_updates_dimension + j]
+                else:
+                    raise NotImplementedError("scatter_nd mode not supported:", mode)

Review comment:
       Please add the supported modes to the error message.
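   The flattened index arithmetic in this hunk is easier to check against a serial model. The sketch below replays the kernel's loops over flattened buffers in plain Python (on the GPU, j is spread across threads; here it is just an inner loop), and its error message lists the supported modes, which is one way to address the review comment above. scatter_nd_flat and SUPPORTED_MODES are hypothetical names, not TVM code.

```python
import numpy as np

SUPPORTED_MODES = ("update", "add")  # hypothetical constant used in the error message


def scatter_nd_flat(data, indices, updates, mode):
    """Serial sketch of the kernel's flattened-buffer loops; not TVM's code."""
    if mode not in SUPPORTED_MODES:
        raise NotImplementedError(
            f"scatter_nd mode {mode!r} is not supported; supported modes: {', '.join(SUPPORTED_MODES)}"
        )
    data, indices, updates = np.asarray(data), np.asarray(indices), np.asarray(updates)
    m = indices.shape[0]                                         # index depth M
    fused_indices_dimension = int(np.prod(indices.shape[1:]))    # Y_0 * ... * Y_{K-1}
    fused_updates_dimension = int(np.prod(data.shape[m:]))       # X_M * ... * X_{N-1}
    out = data.ravel().copy()   # mirrors the kernel's first loop: out starts as a copy of data
    ind, upd = indices.ravel(), updates.ravel()
    for i in range(fused_indices_dimension):
        for j in range(fused_updates_dimension):  # on the GPU, j = bx * tdim + tx
            offset = fused_updates_dimension
            index = j  # the x_M, ..., x_{N-1} part of the flat output index
            for l in reversed(range(m)):
                # indices[l, y_0, ..., y_{K-1}] sits at flat position i + l * fused_indices_dimension
                index += offset * int(ind[i + l * fused_indices_dimension])
                offset *= data.shape[l]  # strides come from data's shape, which is the output shape
            if mode == "update":
                out[index] = upd[i * fused_updates_dimension + j]
            else:  # mode == "add"
                out[index] += upd[i * fused_updates_dimension + j]
    return out.reshape(data.shape)


data = np.arange(12).reshape(3, 4)
indices = np.array([[0, 2, 0]])           # M = 1; row 0 is targeted twice
updates = np.array([[1] * 4, [10] * 4, [100] * 4])
print(scatter_nd_flat(data, indices, updates, "add")[0])     # [101 102 103 104]
print(scatter_nd_flat(data, indices, updates, "update")[0])  # [100 100 100 100]
```

   In this serial model, "update" with repeated indices keeps the value from the last i. The sketch makes no claim about ordering guarantees in the parallel kernel.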







[GitHub] [tvm] jcf94 commented on pull request #7927: [TOPI][RELAY][ONNX] Scatter ND

jcf94 commented on pull request #7927:
URL: https://github.com/apache/tvm/pull/7927#issuecomment-828149228


   Thanks! @mbrookhart @tkonolige @jwfromm @altanh 

