Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/09/22 22:28:13 UTC

[GitHub] [incubator-tvm] mbrookhart opened a new pull request #6533: Scatter on Cuda

mbrookhart opened a new pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533


   I was unable to get scatter working with TE schedules on CUDA because of the two loops that update values in place, so I resorted to using ir_builder directly.
   
   Attempts to parallelize the algorithm further have produced some strange behavior. If I can get a correct, faster implementation working, I'll submit another PR; see the sketch below for the pattern this PR uses.
   
   Thanks to @tkonolige for very helpful discussions. @notoraptor, this may be useful for your scatter_add work.
   
   cc @zhiics 
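   
   For context, a minimal sketch of the ir_builder pattern used here, simplified from the 1-d case in `python/tvm/topi/cuda/scatter.py` (negative-index handling omitted): one fully parallel kernel copies `data` into `out`, then a single sequential kernel applies the updates in order.
   
   ```python
   import tvm
   from tvm import te
   
   
   def scatter_1d_ir(data, indices, updates, out):
       """Sketch of the scatter IR: parallel copy, then sequential update."""
       ib = tvm.tir.ir_builder.create()
       n = data.shape[0]
       out_ptr = ib.buffer_ptr(out)
       data_ptr = ib.buffer_ptr(data)
       indices_ptr = ib.buffer_ptr(indices)
       updates_ptr = ib.buffer_ptr(updates)
   
       # Kernel 1: copy data into out, one element per block.
       with ib.new_scope():
           bx = te.thread_axis("blockIdx.x")
           ib.scope_attr(bx, "thread_extent", n)
           out_ptr[bx] = data_ptr[bx]
   
       # Kernel 2: apply the updates sequentially, so later writes to a
       # duplicated index win, matching scatter semantics.
       with ib.new_scope():
           bx = te.thread_axis("blockIdx.x")
           ib.scope_attr(bx, "thread_extent", 1)
           with ib.for_range(0, indices.shape[0], name="i") as i:
               out_ptr[indices_ptr[i]] = updates_ptr[i]
   
       return ib.get()
   ```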


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r501815195



##########
File path: tests/python/relay/test_op_level3.py
##########
@@ -903,8 +904,8 @@ def verify_scatter(dshape, ishape, axis=0):
         indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64")
 
         ref_res = ref_scatter(data_np, indices_np, updates_np, axis)
-        # TODO(mbrookhart): expand testing when adding more backend schedules
-        for target, ctx in [("llvm", tvm.cpu())]:
+

Review comment:
       @zhiics, added dynamic-shaped tests to test_op_level3.py. It was pretty painless; everything worked :+1: 

##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -0,0 +1,444 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator """
+import tvm
+from tvm import te
+
+
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    """Generate scatter ir for 1d inputs
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", n)
+        out_ptr[bx] = data_ptr[bx]
+
+    indices_ptr = ib.buffer_ptr(indices)
+    updates_ptr = ib.buffer_ptr(updates)
+
+    ni = indices.shape[0]
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", 1)

Review comment:
       If I don't define at least one block in the scope, the generated CUDA code fails to compile.
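   
   To spell that out, a sketch mirroring the diff above: the sequential update kernel still pins a dummy `blockIdx.x` with extent 1.
   
   ```python
   with ib.new_scope():
       bx = te.thread_axis("blockIdx.x")
       # A single dummy block: without this scope_attr, the
       # generated CUDA source fails to compile.
       ib.scope_attr(bx, "thread_extent", 1)
       with ib.for_range(0, ni, name="i") as i:
           ...  # sequential in-place updates
   ```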







[GitHub] [incubator-tvm] zhiics commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
zhiics commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r501884877



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -0,0 +1,444 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator """
+import tvm
+from tvm import te
+
+
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    """Generate scatter ir for 1d inputs
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", n)
+        out_ptr[bx] = data_ptr[bx]
+
+    indices_ptr = ib.buffer_ptr(indices)
+    updates_ptr = ib.buffer_ptr(updates)
+
+    ni = indices.shape[0]
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", 1)
+        with ib.for_range(0, ni, name="i") as i:
+            index = indices_ptr[i]
+            with ib.if_scope(index < 0):
+                out_ptr[index + n] = updates_ptr[i]
+            with ib.else_scope():
+                out_ptr[index] = updates_ptr[i]
+
+    return ib.get()
+
+
+def gen_ir_2d(data, indices, updates, axis, out):
+    """Generate scatter ir for 2d inputs
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    warp_size = tvm.target.Target.current(False).thread_warp_size
+
+    n = data.shape[0]
+    c = data.shape[1]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", n)
+        tx = te.thread_axis("threadIdx.x")
+        ib.scope_attr(tx, "thread_extent", warp_size)
+        i = bx
+        with ib.for_range(0, ceil_div(c, warp_size), name="j") as j_:

Review comment:
       BTW, I am not very sure, but can we just use the maximum number of threads per block, let the number of blocks (`bx`) be `(n * c) // max_threads + 1`, and then use a thread id to iterate from 0 to `n * c`? We could then just have something like
   
   ```python
   tid = bx * max_threads + tx
   with ib.if_scope(tid < (n * c)):
       out_ptr[tid] = data_ptr[tid]
   ```
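   
   A fuller sketch of that suggestion, assuming `max_threads` is taken from the target's `max_num_threads` attribute, and reusing the `ceil_div` helper plus the `n`, `c`, `out_ptr`, and `data_ptr` names from `gen_ir_2d` above:
   
   ```python
   max_threads = int(tvm.target.Target.current(False).max_num_threads)
   
   with ib.new_scope():
       # Flatten the 2-d copy into a 1-d grid of blocks; the guard below
       # handles the tail block when n * c is not a multiple of max_threads.
       bx = te.thread_axis("blockIdx.x")
       tx = te.thread_axis("threadIdx.x")
       ib.scope_attr(bx, "thread_extent", ceil_div(n * c, max_threads))
       ib.scope_attr(tx, "thread_extent", max_threads)
       tid = bx * max_threads + tx
       with ib.if_scope(tid < n * c):
           out_ptr[tid] = data_ptr[tid]
   ```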







[GitHub] [incubator-tvm] zhiics commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
zhiics commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r499027042



##########
File path: tests/python/relay/test_op_level3.py
##########
@@ -903,8 +904,8 @@ def verify_scatter(dshape, ishape, axis=0):
         indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64")
 
         ref_res = ref_scatter(data_np, indices_np, updates_np, axis)
-        # TODO(mbrookhart): expand testing when adding more backend schedules
-        for target, ctx in [("llvm", tvm.cpu())]:
+

Review comment:
       Maybe we can add a test to the dynamic op_level3 tests as well, since the shape func is elemwise?
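   
   A rough sketch of what such a test might look like (hypothetical; the real test would reuse `ref_scatter` and the existing harness), using `relay.Any()` for the shapes and the VM executor:
   
   ```python
   import numpy as np
   import tvm
   import tvm.testing
   from tvm import relay
   
   
   def test_dyn_scatter(dshape=(10,), ishape=(10,), axis=0):
       d = relay.var("d", relay.TensorType([relay.Any()] * len(dshape), "float32"))
       i = relay.var("i", relay.TensorType([relay.Any()] * len(ishape), "int64"))
       u = relay.var("u", relay.TensorType([relay.Any()] * len(ishape), "float32"))
       mod = tvm.IRModule()
       mod["main"] = relay.Function([d, i, u], relay.op.scatter(d, i, u, axis))
   
       data_np = np.random.uniform(size=dshape).astype("float32")
       updates_np = np.random.uniform(size=ishape).astype("float32")
       indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64")
   
       ref = data_np.copy()
       ref[indices_np] = updates_np  # 1-d, axis-0 numpy reference scatter
   
       for target, ctx in [("llvm", tvm.cpu()), ("cuda", tvm.gpu())]:
           ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
           res = ex.evaluate()(data_np, indices_np, updates_np)
           tvm.testing.assert_allclose(res.asnumpy(), ref)
   ```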







[GitHub] [incubator-tvm] masahi edited a comment on pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#issuecomment-717603822


   @zhiics Can we merge this? I want to send a PR to enable `scatter_add` on GPU, building off of this PR.





[GitHub] [incubator-tvm] zhiics edited a comment on pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
zhiics edited a comment on pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#issuecomment-697017469


   cc @Laurawly @vinx13 @icemelon9 






[GitHub] [incubator-tvm] zhiics commented on pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
zhiics commented on pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#issuecomment-717604926


   Yeah, let's merge this. Before that, we probably want to rebase again, just in case.





[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r499881267



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -0,0 +1,444 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator """
+import tvm
+from tvm import te
+
+
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    """Generate scatter ir for 1d inputs
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", n)
+        out_ptr[bx] = data_ptr[bx]
+
+    indices_ptr = ib.buffer_ptr(indices)
+    updates_ptr = ib.buffer_ptr(updates)
+
+    ni = indices.shape[0]
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", 1)

Review comment:
       If I don't define at least one block in the scope, the generated CUDA code fails to compile.







[GitHub] [incubator-tvm] masahi commented on pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#issuecomment-717603822


   @zhiics Can we merge this?





[GitHub] [incubator-tvm] zhiics commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
zhiics commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r503468549



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -0,0 +1,444 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator """
+import tvm
+from tvm import te
+
+
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    """Generate scatter ir for 1d inputs
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", n)
+        out_ptr[bx] = data_ptr[bx]
+
+    indices_ptr = ib.buffer_ptr(indices)
+    updates_ptr = ib.buffer_ptr(updates)
+
+    ni = indices.shape[0]
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", 1)
+        with ib.for_range(0, ni, name="i") as i:
+            index = indices_ptr[i]
+            with ib.if_scope(index < 0):
+                out_ptr[index + n] = updates_ptr[i]
+            with ib.else_scope():
+                out_ptr[index] = updates_ptr[i]
+
+    return ib.get()
+
+
+def gen_ir_2d(data, indices, updates, axis, out):
+    """Generate scatter ir for 2d inputs
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    warp_size = tvm.target.Target.current(False).thread_warp_size
+
+    n = data.shape[0]
+    c = data.shape[1]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", n)
+        tx = te.thread_axis("threadIdx.x")
+        ib.scope_attr(tx, "thread_extent", warp_size)
+        i = bx

Review comment:
       i is not used?







[GitHub] [incubator-tvm] zhiics commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
zhiics commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r501341759



##########
File path: tests/python/relay/test_op_level3.py
##########
@@ -903,8 +904,8 @@ def verify_scatter(dshape, ishape, axis=0):
         indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64")
 
         ref_res = ref_scatter(data_np, indices_np, updates_np, axis)
-        # TODO(mbrookhart): expand testing when adding more backend schedules
-        for target, ctx in [("llvm", tvm.cpu())]:
+

Review comment:
       Yeah, I think test_any is already pretty large, but we can probably put it there for now. We may need to split the tests later as well.







[GitHub] [incubator-tvm] zhiics commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
zhiics commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r493072350



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator """
+import tvm
+from tvm import te
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    bx = te.thread_axis("blockIdx.x")

Review comment:
       so we only use a single thread?

##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator """
+import tvm
+from tvm import te
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    bx = te.thread_axis("blockIdx.x")

Review comment:
       Great! Thanks.







[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r499885677



##########
File path: tests/python/relay/test_op_level3.py
##########
@@ -903,8 +904,8 @@ def verify_scatter(dshape, ishape, axis=0):
         indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64")
 
         ref_res = ref_scatter(data_np, indices_np, updates_np, axis)
-        # TODO(mbrookhart): expand testing when adding more backend schedules
-        for target, ctx in [("llvm", tvm.cpu())]:
+

Review comment:
       Hmm, that's a good idea. I'm not sure the dyn namespace is the right place, though. Maybe just add a second test with dynamic shapes here, or in test_any?







[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r493074030



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator """
+import tvm
+from tvm import te
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    bx = te.thread_axis("blockIdx.x")

Review comment:
       I've spent most of today trying to get it to use multiple threads, but I'm running into an odd situation where the output gets padded with zeros if the output shape is larger than the input shape on any axis other than the updating axis. @tkonolige and I will keep looking at it to see if we can get a faster version to pass the tests.









[GitHub] [incubator-tvm] mbrookhart commented on pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#issuecomment-717611462


   I rebased; I'll keep an eye on it this evening to make sure it still passes CI. Thanks!





[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r502485425



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -0,0 +1,444 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator """
+import tvm
+from tvm import te
+
+
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    """Generate scatter ir for 1d inputs
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", n)
+        out_ptr[bx] = data_ptr[bx]
+
+    indices_ptr = ib.buffer_ptr(indices)
+    updates_ptr = ib.buffer_ptr(updates)
+
+    ni = indices.shape[0]
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", 1)
+        with ib.for_range(0, ni, name="i") as i:
+            index = indices_ptr[i]
+            with ib.if_scope(index < 0):
+                out_ptr[index + n] = updates_ptr[i]
+            with ib.else_scope():
+                out_ptr[index] = updates_ptr[i]
+
+    return ib.get()
+
+
+def gen_ir_2d(data, indices, updates, axis, out):
+    """Generate scatter ir for 2d inputs
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    warp_size = tvm.target.Target.current(False).thread_warp_size
+
+    n = data.shape[0]
+    c = data.shape[1]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", n)
+        tx = te.thread_axis("threadIdx.x")
+        ib.scope_attr(tx, "thread_extent", warp_size)
+        i = bx
+        with ib.for_range(0, ceil_div(c, warp_size), name="j") as j_:

Review comment:
       Hmm, I am not a CUDA expert; I used this pattern on @tkonolige's recommendation. Perhaps he has a thought?
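   
   For reference, the pattern under discussion (the loop body is truncated in the hunk above, so this is a sketch of what it presumably does): each block owns row `i`, and the warp's threads stride across that row's columns.
   
   ```python
   with ib.for_range(0, ceil_div(c, warp_size), name="j") as j_:
       j = j_ * warp_size + tx  # this thread's column within row i
       with ib.if_scope(j < c):
           idx = i * c + j  # flat offset of element (i, j)
           out_ptr[idx] = data_ptr[idx]
   ```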







[GitHub] [incubator-tvm] zhiics commented on pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
zhiics commented on pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#issuecomment-697017469


   cc @Laurawly as well





[GitHub] [incubator-tvm] zhiics merged pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
zhiics merged pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533


   





[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r494554915



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator """
+import tvm
+from tvm import te
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    bx = te.thread_axis("blockIdx.x")

Review comment:
       Thanks to help from @tkonolige and @tqchen, we figured out that we were accidentally launching one CUDA kernel instead of two. I fixed it and added threading to the PR.







[GitHub] [incubator-tvm] Laurawly commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
Laurawly commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r499001184



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -0,0 +1,444 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator """
+import tvm
+from tvm import te
+
+
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    """Generate scatter ir for 1d inputs
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", n)
+        out_ptr[bx] = data_ptr[bx]
+
+    indices_ptr = ib.buffer_ptr(indices)
+    updates_ptr = ib.buffer_ptr(updates)
+
+    ni = indices.shape[0]
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", 1)

Review comment:
       if bx is not used in this context, we can remove the above two lines.







[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r493074030



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator """
+import tvm
+from tvm import te
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    bx = te.thread_axis("blockIdx.x")

Review comment:
       I've spent most of today trying to get it to use multiple threads, but I'm running into an odd situation where the output gets padded with zeros if the output shape is larger than the indices shape on any axis other than the updating axis. @tkonolige and I will keep looking at it to see if we can get a faster version to pass the tests.







[GitHub] [incubator-tvm] zhiics commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
zhiics commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r501884877



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -0,0 +1,444 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator """
+import tvm
+from tvm import te
+
+
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    """Generate scatter ir for 1d inputs
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", n)
+        out_ptr[bx] = data_ptr[bx]
+
+    indices_ptr = ib.buffer_ptr(indices)
+    updates_ptr = ib.buffer_ptr(updates)
+
+    ni = indices.shape[0]
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", 1)
+        with ib.for_range(0, ni, name="i") as i:
+            index = indices_ptr[i]
+            with ib.if_scope(index < 0):
+                out_ptr[index + n] = updates_ptr[i]
+            with ib.else_scope():
+                out_ptr[index] = updates_ptr[i]
+
+    return ib.get()
+
+
+def gen_ir_2d(data, indices, updates, axis, out):
+    """Generate scatter ir for 2d inputs
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    warp_size = tvm.target.Target.current(False).thread_warp_size
+
+    n = data.shape[0]
+    c = data.shape[1]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", n)
+        tx = te.thread_axis("threadIdx.x")
+        ib.scope_attr(tx, "thread_extent", warp_size)
+        i = bx
+        with ib.for_range(0, ceil_div(c, warp_size), name="j") as j_:

Review comment:
       BTW, I am not very sure, but can we just use the maximum number of threads per block, let the number of blocks (`bx`) be `(n * c) // max_threads + 1`, and then use a thread id to iterate from 0 to `n * c`? We could then just have something like
   
   ```python
   tid = bx * max_threads + tx
   with ib.if_scope(tid < (n * c)):
       out_ptr[tid] = data_ptr[tid]
   ```
   
   cc @Laurawly, who may have a better idea









[GitHub] [incubator-tvm] zhiics commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
zhiics commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r493076021



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator """
+import tvm
+from tvm import te
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    bx = te.thread_axis("blockIdx.x")

Review comment:
       Great! Thanks.







[GitHub] [incubator-tvm] zhiics commented on pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
zhiics commented on pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#issuecomment-717690788


   Thanks @mbrookhart @Laurawly @tkonolige @masahi 





[GitHub] [incubator-tvm] zhiics commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
zhiics commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r501884877



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -0,0 +1,444 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator """
+import tvm
+from tvm import te
+
+
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    """Generate scatter ir for 1d inputs
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", n)
+        out_ptr[bx] = data_ptr[bx]
+
+    indices_ptr = ib.buffer_ptr(indices)
+    updates_ptr = ib.buffer_ptr(updates)
+
+    ni = indices.shape[0]
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", 1)
+        with ib.for_range(0, ni, name="i") as i:
+            index = indices_ptr[i]
+            with ib.if_scope(index < 0):
+                out_ptr[index + n] = updates_ptr[i]
+            with ib.else_scope():
+                out_ptr[index] = updates_ptr[i]
+
+    return ib.get()
+
+
+def gen_ir_2d(data, indices, updates, axis, out):
+    """Generate scatter ir for 2d inputs
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    warp_size = tvm.target.Target.current(False).thread_warp_size
+
+    n = data.shape[0]
+    c = data.shape[1]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", n)
+        tx = te.thread_axis("threadIdx.x")
+        ib.scope_attr(tx, "thread_extent", warp_size)
+        i = bx
+        with ib.for_range(0, ceil_div(c, warp_size), name="j") as j_:

Review comment:
       BTW, I am not very sure, but could we just use the maximum number of threads per block, let the number of blocks be `(n * c) // max_threads + 1`, and then use a thread id that runs from 0 to `n * c`? We could then just have something like
   
   ```python
   tid = bx * max_threads + tx
   with ib.if_scope(tid < (n * c)):
       out_ptr[tid] = data_ptr[tid]
   ```
   
   cc @Laurawly may have better idea
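   
   For reference, a minimal sketch of how that flattened copy might look with ir_builder — assuming `ib`, `out_ptr`, `data_ptr`, `n`, `c`, and `ceil_div` are set up as in the diff above, and taking `max_threads` from the current target:
   
   ```python
   max_threads = tvm.target.Target.current(False).max_num_threads
   
   with ib.new_scope():
       bx = te.thread_axis("blockIdx.x")
       tx = te.thread_axis("threadIdx.x")
       # one thread per element of the flattened (n, c) buffer
       ib.scope_attr(bx, "thread_extent", ceil_div(n * c, max_threads))
       ib.scope_attr(tx, "thread_extent", max_threads)
       tid = bx * max_threads + tx
       # guard the tail block: threads past n * c do nothing
       with ib.if_scope(tid < n * c):
           out_ptr[tid] = data_ptr[tid]
   ```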







[GitHub] [incubator-tvm] zhiics commented on pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
zhiics commented on pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#issuecomment-697017469


   cc @Laurawly as well





[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r493074030



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator """
+import tvm
+from tvm import te
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    bx = te.thread_axis("blockIdx.x")

Review comment:
       I've spent most of today trying to get it to use multiple threads, but I'm running into an odd situation where the output gets padded with zeros whenever the output shape is larger than the indices shape on any axis other than the updating axis. @tkonolige and I will keep looking at it to see if we can get a faster version that passes the tests.







[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r494554915



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator """
+import tvm
+from tvm import te
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    bx = te.thread_axis("blockIdx.x")

Review comment:
       Thanks to help from @tkonolige and @tqchen, we figured out that we were accidentally launching one CUDA kernel instead of two. Fixed it and added threading to the PR.
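   
   For anyone who hits the same issue: the key is wrapping each phase in its own `ib.new_scope()` so that it lowers to a separate kernel launch. A rough sketch of the 1d pattern, using the names from this file (negative-index handling omitted for brevity):
   
   ```python
   ib = tvm.tir.ir_builder.create()
   data_ptr = ib.buffer_ptr(data)
   out_ptr = ib.buffer_ptr(out)
   indices_ptr = ib.buffer_ptr(indices)
   updates_ptr = ib.buffer_ptr(updates)
   n = data.shape[0]
   ni = indices.shape[0]
   
   # kernel 1: copy data into out, one block per element
   with ib.new_scope():
       bx = te.thread_axis("blockIdx.x")
       ib.scope_attr(bx, "thread_extent", n)
       out_ptr[bx] = data_ptr[bx]
   
   # kernel 2: apply the updates serially, once the copy has finished
   with ib.new_scope():
       bx = te.thread_axis("blockIdx.x")
       ib.scope_attr(bx, "thread_extent", 1)
       with ib.for_range(0, ni, name="i") as i:
           out_ptr[indices_ptr[i]] = updates_ptr[i]
   ```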







[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r508055980



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -0,0 +1,444 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator """
+import tvm
+from tvm import te
+
+
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    """Generate scatter ir for 1d inputs
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", n)
+        out_ptr[bx] = data_ptr[bx]
+
+    indices_ptr = ib.buffer_ptr(indices)
+    updates_ptr = ib.buffer_ptr(updates)
+
+    ni = indices.shape[0]
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", 1)
+        with ib.for_range(0, ni, name="i") as i:
+            index = indices_ptr[i]
+            with ib.if_scope(index < 0):
+                out_ptr[index + n] = updates_ptr[i]
+            with ib.else_scope():
+                out_ptr[index] = updates_ptr[i]
+
+    return ib.get()
+
+
+def gen_ir_2d(data, indices, updates, axis, out):
+    """Generate scatter ir for 2d inputs
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    warp_size = tvm.target.Target.current(False).thread_warp_size
+
+    n = data.shape[0]
+    c = data.shape[1]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", n)
+        tx = te.thread_axis("threadIdx.x")
+        ib.scope_attr(tx, "thread_extent", warp_size)
+        i = bx

Review comment:
       Removed! Thanks for the catch.







[GitHub] [incubator-tvm] tkonolige commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
tkonolige commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r502532481



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -0,0 +1,444 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator """
+import tvm
+from tvm import te
+
+
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    """Generate scatter ir for 1d inputs
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", n)
+        out_ptr[bx] = data_ptr[bx]
+
+    indices_ptr = ib.buffer_ptr(indices)
+    updates_ptr = ib.buffer_ptr(updates)
+
+    ni = indices.shape[0]
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", 1)
+        with ib.for_range(0, ni, name="i") as i:
+            index = indices_ptr[i]
+            with ib.if_scope(index < 0):
+                out_ptr[index + n] = updates_ptr[i]
+            with ib.else_scope():
+                out_ptr[index] = updates_ptr[i]
+
+    return ib.get()
+
+
+def gen_ir_2d(data, indices, updates, axis, out):
+    """Generate scatter ir for 2d inputs
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    warp_size = tvm.target.Target.current(False).thread_warp_size
+
+    n = data.shape[0]
+    c = data.shape[1]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", n)
+        tx = te.thread_axis("threadIdx.x")
+        ib.scope_attr(tx, "thread_extent", warp_size)
+        i = bx
+        with ib.for_range(0, ceil_div(c, warp_size), name="j") as j_:

Review comment:
       Yeah, that would work too. I'm not sure which would be faster. My guess is that they are equivalent unless `n` is really small and `c` is large (in which case the current code would be worse).
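   
   To make that concrete with hypothetical sizes (warp size 32 and 1024 max threads per block are typical CUDA values):
   
   ```python
   n, c = 4, 4096          # hypothetical: few rows, many columns
   warp_size, max_threads = 32, 1024
   
   def ceil_div(a, b):
       return (a + b - 1) // b
   
   # current scheme: one block per row, one warp per block,
   # each thread loops over ceil(c / warp_size) columns
   threads_current = n * warp_size            # 128 threads total
   iters_per_thread = ceil_div(c, warp_size)  # 128 loop iterations each
   
   # flattened scheme: one element per thread
   threads_flat = ceil_div(n * c, max_threads) * max_threads  # 16384 threads
   
   print(threads_current, iters_per_thread, threads_flat)
   ```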







[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #6533: Scatter on Cuda

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #6533:
URL: https://github.com/apache/incubator-tvm/pull/6533#discussion_r499881267



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -0,0 +1,444 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator """
+import tvm
+from tvm import te
+
+
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    """Generate scatter ir for 1d inputs
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", n)
+        out_ptr[bx] = data_ptr[bx]
+
+    indices_ptr = ib.buffer_ptr(indices)
+    updates_ptr = ib.buffer_ptr(updates)
+
+    ni = indices.shape[0]
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", 1)

Review comment:
       If I don't define at least one block in the scope, the generated CUDA code fails to compile.



