You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/03/03 16:32:34 UTC

[GitHub] [tvm] tkonolige commented on a change in pull request #7562: Add segment sum Op to relay and 7 corresponding TF Ops , fix scatter_add dynamic bug

tkonolige commented on a change in pull request #7562:
URL: https://github.com/apache/tvm/pull/7562#discussion_r586565340



##########
File path: python/tvm/relay/op/transform.py
##########
@@ -1450,6 +1450,75 @@ def sparse_reshape(sparse_indices, prev_shape, new_shape):
     return TupleWrapper(_make.sparse_reshape(sparse_indices, prev_shape, new_shape), 2)
 
 
+def segment_sum(data, segment_ids, num_segments=None):
+    """
+    Computes the sum along segment_ids along axis 0. If multiple segment_ids reference the same
+    location their contributions add up.
+    result[index] = Σi... data[i...] where index = segment_ids[i]

Review comment:
       This expression seems wrong in the multidimensional case. Shouldn't it be `result[index, j, k, ...] = \sum_{i : segment_ids[i] = index} data[i, j, k, ...]`?

##########
File path: tests/python/frontend/tensorflow/test_forward.py
##########
@@ -2080,6 +2080,140 @@ def test_forward_sparse_reshape(
     _test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, use_dyn)
 
 
+#######################################################################
+# Sparse SegmentSum
+# ------------
+
+
+def _test_sparse_segment_sum(data_np, indices_np, segment_ids_np, num_segments, use_dyn=False):
+    with tf.Graph().as_default():
+        if use_dyn:
+            data = tf.placeholder(
+                shape=[None for _ in data_np.shape], dtype=data_np.dtype, name="data"
+            )
+            indices = tf.placeholder(shape=[None], dtype=indices_np.dtype, name="indices")
+            segment_ids = tf.placeholder(
+                shape=(None), dtype=segment_ids_np.dtype, name="segment_ids"
+            )
+        else:
+            data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name="data")
+            indices = tf.placeholder(shape=indices_np.shape, dtype=indices_np.dtype, name="indices")
+            segment_ids = tf.placeholder(
+                shape=segment_ids_np.shape, dtype=segment_ids_np.dtype, name="segment_ids"
+            )
+
+        _ = tf.sparse.segment_sum(
+            data, indices, segment_ids, num_segments=num_segments, name="sparse_segment_sum"
+        )
+        compare_tf_with_tvm(
+            [data_np, indices_np, segment_ids_np],
+            [data.name, indices.name, segment_ids.name],
+            ["sparse_segment_sum:0"],
+            mode="vm",
+        )
+
+
+@pytest.mark.parametrize(
+    "data_np, indices_np, segment_ids_np, num_segments",

Review comment:
       Could you switch one of these tests to int64?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org