Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/04/26 22:27:44 UTC

[GitHub] [tvm] tkonolige commented on a change in pull request #7927: [TOPI][RELAY][ONNX] Scatter ND

tkonolige commented on a change in pull request #7927:
URL: https://github.com/apache/tvm/pull/7927#discussion_r620690363



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -723,7 +723,7 @@ def update_func(dst_ptr, dst_index, update):
     return out
 
 
-def scatter_nd(data, indices, shape):
+def scatter_nd(data, indices, updates, mode):
     """Scatter elements from a n-dimension array.
 
     Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape

Review comment:
       Can you update this docstring to match the new signature?

##########
File path: python/tvm/topi/x86/scatter.py
##########
@@ -46,62 +46,69 @@ def scatter_nd(data, indices, shape):
     indices : tvm.te.Tensor
         The indices of the values to extract.
 
-    shape : Sequence[int]
-        The output shape. This must be specified because it cannot be inferred.
+    updates : tvm.te.Tensor
+        The updates to apply at the Indices
+
+    mode : string
+        The update mode for the algorith, either "update" or "add"

Review comment:
       ```suggestion
           The update mode for the algorithm, either "update" or "add"
       ```

##########
File path: python/tvm/topi/scatter.py
##########
@@ -248,29 +248,31 @@ def scatter_nd(data, indices, shape):
     indices : tvm.te.Tensor
         The indices of the values to extract.
 
-    shape : Sequence[int]
-        The output shape. This must be specified because it cannot be inferred.
+    updates : tvm.te.Tensor
+        The updates to apply at the Indices
+
+    mode : string
+        The update mode for the algorith, either "update" or "add"

Review comment:
       What do the different modes do?
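
For context, here is a minimal NumPy sketch of what the two modes appear to do, inferred only from the CUDA kernel quoted later in this message: the output starts as a copy of `data`, and each indexed slice is then either overwritten ("update") or accumulated into ("add"). The name `scatter_nd_reference` is hypothetical and not part of TVM. The sequential loop means the last write wins for duplicate indices in "update" mode, which a parallel kernel may not guarantee.

```python
import numpy as np


def scatter_nd_reference(data, indices, updates, mode):
    """Hypothetical reference for scatter_nd, not TVM's implementation.

    indices has shape (M, Y_0, ..., Y_{K-1}); updates has shape
    (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}); data has shape (X_0, ..., X_{N-1}).
    """
    if mode not in ("update", "add"):
        raise NotImplementedError(f"scatter_nd mode not supported: {mode!r}")
    indices = np.asarray(indices)
    updates = np.asarray(updates)
    out = np.array(data, copy=True)  # start from the input data, not zeros
    for y in np.ndindex(*indices.shape[1:]):
        # indices[:, y_0, ..., y_{K-1}] selects the first M coordinates in out.
        target = tuple(indices[(slice(None),) + y])
        if mode == "update":
            out[target] = updates[y]   # overwrite the slice
        else:
            out[target] += updates[y]  # accumulate into the slice
    return out


data = np.ones(4)
indices = np.array([[1, 3, 1]])  # M = 1, three scatter locations
updates = np.array([10.0, 20.0, 30.0])
updated = scatter_nd_reference(data, indices, updates, "update")
added = scatter_nd_reference(data, indices, updates, "add")
```

With the duplicate index 1, "update" keeps only the last value written there, while "add" sums both contributions on top of the original entry.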

##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -789,38 +790,41 @@ def gen_ir(data_ptr, indices_ptr, out_ptr):
         bx = te.thread_axis("blockIdx.x")
         tx = te.thread_axis("threadIdx.x")
         max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-        tdim = min(max_threads, fused_data_dimension)
+        tdim = min(max_threads, fused_updates_dimension)
         ib.scope_attr(tx, "thread_extent", tdim)
-        bdim = ceil_div(fused_data_dimension, tdim)
+        bdim = ceil_div(fused_updates_dimension, tdim)
         ib.scope_attr(bx, "thread_extent", bdim)
 
-        # zero data
-        # TODO(tkonolige): could we use topi.full to zero it instead?
         with ib.for_range(0, ceil_div(fused_shape, bdim)) as i:
-            index = i * fused_data_dimension + bx * tdim + tx
+            index = i * fused_updates_dimension + bx * tdim + tx
             with ib.if_scope(index < fused_shape):
-                out[index] = tvm.tir.Cast(data_ptr.dtype, 0)
+                out[index] = data[index]
 
         with ib.for_range(0, fused_indices_dimension) as i:
             j = bx * tdim + tx
-            with ib.if_scope(j < fused_data_dimension):
-                offset = fused_data_dimension
+            with ib.if_scope(j < fused_updates_dimension):
+                offset = fused_updates_dimension
                 index = j  # This is x_M, .. x_{N-1} part of the index into out.
                 # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
                 # of the index into out.
                 for l in reversed(range(indices_ptr.shape[0].value)):
                     # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
                     index += offset * indices[i + l * fused_indices_dimension]
-                    offset *= shape[l]
-                out[index] += data[i * fused_data_dimension + j]
+                    offset *= data_ptr.shape[l]
+                if mode == "update":
+                    out[index] = updates[i * fused_updates_dimension + j]
+                elif mode == "add":
+                    out[index] += updates[i * fused_updates_dimension + j]
+                else:
+                    raise NotImplementedError("scatter_nd mode not supported:", mode)

Review comment:
       Please add the supported modes to the error message.
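
One possible shape for such a message, as a sketch only: the helper `check_scatter_nd_mode` and the constant `SUPPORTED_MODES` are hypothetical names, not taken from the PR. It also formats the mode into a single string, since passing two positional arguments to the exception, as in the hunk above, makes the message render as a tuple.

```python
# Hypothetical constant listing the modes accepted by the kernel in the diff.
SUPPORTED_MODES = ("update", "add")


def check_scatter_nd_mode(mode):
    """Return mode unchanged, or raise an error that names the valid choices."""
    if mode not in SUPPORTED_MODES:
        raise NotImplementedError(
            f"scatter_nd mode not supported: {mode!r}. "
            f"Supported modes are: {', '.join(SUPPORTED_MODES)}."
        )
    return mode
```

Validating the mode once, before the IR is generated, would also avoid building part of the kernel before the error surfaces.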



