Posted to commits@tvm.apache.org by wu...@apache.org on 2020/06/10 17:07:50 UTC

[incubator-tvm] branch master updated: [topi] block sparse dense on cuda (#5746)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new ed58309  [topi] block sparse dense on cuda (#5746)
ed58309 is described below

commit ed583092dbeb4f1b0458ad015f607f0746d61e80
Author: Zijing Gu <ji...@live.com>
AuthorDate: Wed Jun 10 13:07:36 2020 -0400

    [topi] block sparse dense on cuda (#5746)
---
 topi/python/topi/cuda/__init__.py     |  1 +
 topi/python/topi/cuda/sparse.py       | 94 +++++++++++++++++++++++++++++++++++
 topi/python/topi/nn/sparse.py         |  2 +-
 topi/tests/python/test_topi_sparse.py | 70 ++++++++++++++++++--------
 4 files changed, 146 insertions(+), 21 deletions(-)
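
A note on the input layout before the diff: the new kernel consumes the weight matrix in block-sparse-row (BSR) storage, i.e. the three arrays that scipy.sparse.bsr_matrix exposes. A minimal sketch of those arrays (assuming scipy; not part of the commit):

    import scipy.sparse as sp

    # Build a small BSR matrix with 4x4 blocks; its three storage arrays
    # map onto the weight_data / weight_indices / weight_indptr inputs of
    # the new sparse_dense compute below.
    W = sp.bsr_matrix(sp.random(16, 12, density=0.5, format="csr",
                                dtype="float32"), blocksize=(4, 4))
    print(W.data.shape)     # (num_blocks, 4, 4)      -> weight_data
    print(W.indices.shape)  # (num_blocks,)           -> weight_indices
    print(W.indptr.shape)   # (num_block_rows + 1,)   -> weight_indptr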

diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py
index ba5c54b..78e3680 100644
--- a/topi/python/topi/cuda/__init__.py
+++ b/topi/python/topi/cuda/__init__.py
@@ -50,3 +50,4 @@ from .conv2d_nhwc_tensorcore import *
 from .conv3d_ndhwc_tensorcore import *
 from .dense_tensorcore import *
 from .correlation import *
+from .sparse import *
diff --git a/topi/python/topi/cuda/sparse.py b/topi/python/topi/cuda/sparse.py
new file mode 100644
index 0000000..037eea4
--- /dev/null
+++ b/topi/python/topi/cuda/sparse.py
@@ -0,0 +1,94 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Sparse operators"""
+from tvm import te
+from tvm import autotvm
+from tvm.autotvm.task.space import SplitEntity
+from ..util import traverse_inline
+from .. import nn
+
+
+@autotvm.register_topi_compute("sparse_dense.cuda")
+def sparse_dense(cfg, data, weight_data, weight_indices, weight_indptr):
+    """
+    Computes sparse-dense matrix multiplication of `data` and
+    `(weight_data, weight_indices, weight_indptr).T`
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        2-D with shape [M, K], float32
+
+    weight_data : tvm.te.Tensor
+        1-D with shape [nnz] (CSR) or
+        3-D with shape [num_blocks, bs_r, bs_c] (BSR)
+
+    weight_indices : tvm.te.Tensor
+        1-D with shape [nnz] (CSR) or
+        1-D with shape [num_blocks] (BSR)
+
+    weight_indptr : tvm.te.Tensor
+        1-D with shape [N + 1] (CSR) or
+        1-D with shape [N // bs_r + 1] (BSR)
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        2-D with shape [M, N]
+    """
+    # pylint:disable=unused-argument
+    return nn.sparse_dense(data, weight_data, weight_indices, weight_indptr)
+
+
+@autotvm.register_topi_schedule("sparse_dense.cuda")
+def schedule_sparse_dense(cfg, outs):
+    """Create schedule for sparse dense"""
+    # pylint:disable=invalid-name
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == "sparse_dense_bsrmm":
+            y_bsrmm = op.input_tensors[0]
+            assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block"
+            out = s.outputs[0].output(0)
+            (_, c) = s[y_bsrmm].op.reduce_axis
+
+            (m_o, n_o) = s[out].op.axis
+            s[out].bind(m_o, te.thread_axis("blockIdx.x"))
+            s[out].bind(n_o, te.thread_axis("blockIdx.y"))
+            s[y_bsrmm].compute_at(s[out], n_o)
+
+            thread_x = te.thread_axis("threadIdx.x")
+
+            cfg.define_split("tile_c", c, num_outputs=2)
+            if cfg.is_fallback:
+                cfg["tile_c"] = SplitEntity([-1, 8])
+            _, ci = cfg['tile_c'].apply(s, y_bsrmm, c)
+
+            y_bsrmm_factored = s.rfactor(y_bsrmm, ci)
+            tx = s[y_bsrmm].op.reduce_axis[0]
+            s[y_bsrmm].bind(tx, thread_x)
+            s[y_bsrmm_factored].compute_at(s[y_bsrmm], tx)
+            s[y_bsrmm].set_store_predicate(thread_x.var.equal(0))
+            s[out].set_store_predicate(thread_x.var.equal(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
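
The schedule above is a standard cross-thread reduction via rfactor: the reduction axis c is split, the inner part ci is rfactored into per-thread partial sums, the remaining reduction is bound to threadIdx.x, and a store predicate lets only lane 0 write the result. A minimal standalone sketch of the same pattern on a plain 1-D sum (an illustration of the technique, not the BSR matmul itself):

    import tvm
    from tvm import te

    n = 1024
    A = te.placeholder((n,), name="A")
    k = te.reduce_axis((0, n), name="k")
    B = te.compute((1,), lambda i: te.sum(A[k], axis=k), name="B")

    s = te.create_schedule(B.op)
    ko, ki = s[B].split(B.op.reduce_axis[0], factor=32)
    BF = s.rfactor(B, ki)                        # per-thread partial sums
    tx = te.thread_axis("threadIdx.x")
    s[B].bind(s[B].op.axis[0], te.thread_axis("blockIdx.x"))
    s[B].bind(s[B].op.reduce_axis[0], tx)        # cross-thread reduction
    s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
    s[B].set_store_predicate(tx.var.equal(0))    # only lane 0 stores
    print(tvm.lower(s, [A, B], simple_mode=True))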
diff --git a/topi/python/topi/nn/sparse.py b/topi/python/topi/nn/sparse.py
index b37bac2..b24121b 100644
--- a/topi/python/topi/nn/sparse.py
+++ b/topi/python/topi/nn/sparse.py
@@ -30,7 +30,7 @@ def sparse_dense(data, weight_data, weight_indices, weight_indptr):
 
     Parameters
     ----------
-    x : tvm.te.Tensor
+    data : tvm.te.Tensor
         2-D with shape [M, K], float32
 
     weight_data : tvm.te.Tensor
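
The parameter rename above also pins down the operator's semantics: it computes Y = data * W.T with W stored as (weight_data, weight_indices, weight_indptr). A hedged numpy/scipy reference check for the BSR case (not part of the commit):

    import numpy as np
    import scipy.sparse as sp

    M, N, K = 8, 16, 12
    X = np.random.randn(M, K).astype("float32")   # the dense `data` input
    W = sp.bsr_matrix(sp.random(N, K, density=0.3, format="csr",
                                dtype="float32"), blocksize=(4, 4))
    Y_ref = X @ W.toarray().T                     # expected output, [M, N]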
diff --git a/topi/tests/python/test_topi_sparse.py b/topi/tests/python/test_topi_sparse.py
index fc2d26b..3290fc0 100644
--- a/topi/tests/python/test_topi_sparse.py
+++ b/topi/tests/python/test_topi_sparse.py
@@ -26,6 +26,12 @@ from collections import namedtuple
 import time
 import scipy.sparse as sp
 
+_sparse_dense_implement = {
+    "generic": (topi.nn.sparse_dense, topi.generic.schedule_sparse_dense),
+    "cuda": (topi.cuda.sparse_dense, topi.cuda.schedule_sparse_dense),
+    "x86": (topi.nn.sparse_dense, topi.x86.schedule_sparse_dense)
+}
+
 def verify_dynamic_csrmv(batch, in_dim, out_dim, use_bias=True):
     nr, nc, n = te.var("nr"), te.var("nc"), te.var("n")
     dtype = 'float32'
@@ -293,16 +299,28 @@ def test_sparse_dense_bsr():
     W_indices = te.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype))
     W_indptr = te.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype))
     X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype))
-    Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
-    s = te.create_schedule(Y.op)
-    func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
-    Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
-    func(tvm.nd.array(X_np),
-         tvm.nd.array(W_sp_np.data),
-         tvm.nd.array(W_sp_np.indices),
-         tvm.nd.array(W_sp_np.indptr),
-         Y_tvm)
-    tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        fcompute, fschedule = topi.testing.dispatch(device, _sparse_dense_implement)
+        with tvm.target.create(device):
+            Y = fcompute(X, W_data, W_indices, W_indptr)
+            s = fschedule([Y])
+            func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
+            Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx)
+            func(tvm.nd.array(X_np, ctx=ctx),
+                 tvm.nd.array(W_sp_np.data, ctx=ctx),
+                 tvm.nd.array(W_sp_np.indices, ctx=ctx),
+                 tvm.nd.array(W_sp_np.indptr, ctx=ctx),
+                 Y_tvm)
+            tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
+
+    for device in ['llvm', 'cuda']:
+        check_device(device)
 
 def test_sparse_dense_bsr_randomized():
     for _ in range(20):
@@ -322,16 +340,28 @@ def test_sparse_dense_bsr_randomized():
         W_indices = te.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype))
         W_indptr = te.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype))
         X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype))
-        Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
-        s = te.create_schedule(Y.op)
-        func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
-        Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
-        func(tvm.nd.array(X_np),
-             tvm.nd.array(W_sp_np.data),
-             tvm.nd.array(W_sp_np.indices),
-             tvm.nd.array(W_sp_np.indptr),
-             Y_tvm)
-        tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5)
+
+        def check_device(device):
+            ctx = tvm.context(device, 0)
+            if not ctx.exist:
+                print("Skip because %s is not enabled" % device)
+                return
+            print("Running on target: %s" % device)
+            fcompute, fschedule = topi.testing.dispatch(device, _sparse_dense_implement)
+            with tvm.target.create(device):
+                Y = fcompute(X, W_data, W_indices, W_indptr)
+                s = fschedule([Y])
+                func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
+                Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx)
+                func(tvm.nd.array(X_np, ctx=ctx),
+                     tvm.nd.array(W_sp_np.data, ctx=ctx),
+                     tvm.nd.array(W_sp_np.indices, ctx=ctx),
+                     tvm.nd.array(W_sp_np.indptr, ctx=ctx),
+                     Y_tvm)
+                tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5)
+
+        for device in ['llvm', 'cuda']:
+            check_device(device)
 
 
 def test_sparse_dense():
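
On how the tests route to a backend: topi.testing.dispatch looks up the target's keys in the _sparse_dense_implement table and falls back to the "generic" pair. A rough sketch of that lookup, under the assumption that the helper behaves as these tests rely on (dispatch_sketch is a hypothetical stand-in, not TVM's actual source):

    import tvm

    def dispatch_sketch(device, impl_map):
        # Hypothetical stand-in for topi.testing.dispatch: match each key
        # of the created target (e.g. "cuda" or "cpu") against the table,
        # falling back to the "generic" entry.
        target = tvm.target.create(device)
        for key in target.keys:
            if key in impl_map:
                return impl_map[key]
        return impl_map["generic"]

    fcompute, fschedule = dispatch_sketch("cuda", _sparse_dense_implement)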