You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mb...@apache.org on 2021/03/11 21:02:16 UTC

[tvm] branch main updated: [TOPI] Sparse Add Op added (#7435)

This is an automated email from the ASF dual-hosted git repository.

mbaret pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 017ff94  [TOPI] Sparse Add Op added (#7435)
017ff94 is described below

commit 017ff94d15df85ea8476f8ad3ce234470072ae84
Author: ANSHUMAN TRIPATHY <an...@huawei.com>
AuthorDate: Fri Mar 12 02:31:59 2021 +0530

    [TOPI] Sparse Add Op added (#7435)
    
    * [TOPI] Sparse Add Op added
    
    * lint resolved
    
    * TF frontend support added
    
    * Test case added
    
    * [1] Review comment handled
    
    * [2] Review comment handled
    
    * [3] Review comment handled
    
    * [4] Review comment handled
    
    * [5] Review comment handled
---
 python/tvm/relay/frontend/tensorflow.py          | 35 ++++++++++++
 python/tvm/relay/op/nn/_nn.py                    |  5 ++
 python/tvm/relay/op/nn/nn.py                     | 47 ++++++++++++++++
 python/tvm/relay/op/strategy/generic.py          | 23 ++++++++
 python/tvm/topi/nn/sparse.py                     | 69 ++++++++++++++++++++++++
 src/relay/op/nn/sparse.cc                        | 41 ++++++++++++++
 tests/python/frontend/tensorflow/test_forward.py | 48 +++++++++++++++++
 tests/python/topi/python/test_topi_sparse.py     | 28 ++++++++++
 8 files changed, 296 insertions(+)

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index c79c495..f56d187 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1286,6 +1286,40 @@ def _sparse_segment_mean_with_num_segments():
     return _impl
 
 
+def _sparse_tensor_dense_add():
+    def _impl(inputs, attr, params, mod):
+        # Imported lazily so scipy is only needed when this op is actually converted.
+        from scipy.sparse import csr_matrix
+
+        assert (
+            len(inputs) == 4
+        ), "There should be 4 input tensors [sparse_indices, sparse_values, sparse_shape, dense]."
+
+        indices_tensor = _infer_value(inputs[0], params, mod).asnumpy()
+        values_tensor = _infer_value(inputs[1], params, mod).asnumpy()
+        dense_shape_tensor = _infer_value(inputs[2], params, mod).asnumpy()
+        # CSR is inherently 2-D; reject higher-rank SparseTensors with a clear message.
+        assert len(dense_shape_tensor) == 2, "Only 2-D SparseTensorDenseAdd is supported."
+
+        data = inputs[3]
+
+        # Each row of `indices` is one [row, col] coordinate of a non-zero value.
+        rows, cols = indices_tensor[:, 0], indices_tensor[:, 1]
+
+        # Create scipy sparse Tensor(CSR)
+        weight_sp = csr_matrix(
+            (values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist())
+        )
+
+        weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype)
+        weight_indptrs = _expr.const(weight_sp.indptr, weight_sp.indptr.dtype)
+        weight_indices = _expr.const(weight_sp.indices, weight_sp.indices.dtype)
+
+        return _op.nn.sparse_add(data, [weight_data, weight_indices, weight_indptrs])
+
+    return _impl
+
+
 def _identity():
     def _impl(inputs, attr, params, mod):
         return inputs[0]
@@ -2787,6 +2821,7 @@ _convert_map = {
     "SparseSegmentSqrtNWithNumSegments": _sparse_segment_sum_sqrtn_with_num_segments(),
     "SparseSegmentMean": _sparse_segment_mean(),
     "SparseSegmentMeanWithNumSegments": _sparse_segment_mean_with_num_segments(),
+    "SparseTensorDenseAdd": _sparse_tensor_dense_add(),
     "Split": _split(False),
     "SplitV": _split(True),
     "Sqrt": AttrCvt("sqrt"),
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 6ae86c0..af64873 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -142,6 +142,11 @@ def alter_op_layout_sparse_dense(attrs, inputs, tinfos, out_type):
     return topi.nn.sparse_dense_alter_layout(attrs, inputs, tinfos, out_type)
 
 
+# sparse_add: the generic TOPI compute is a te.extern, hence the OPAQUE (non-fusible) pattern.
+reg.register_strategy("nn.sparse_add", strategy.sparse_add_strategy)
+reg.register_pattern("nn.sparse_add", reg.OpPattern.OPAQUE)
+
+
 @reg.register_compute("nn.internal.sparse_dense_padded")
 def compute_sparse_dense_padded(attrs, inputs, out_type):
     """Compute definition of sparse_dense_padded"""
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 5135ac7..a1147fe 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -2148,6 +2148,53 @@ def sparse_transpose(x):
     return expr.TupleWrapper(_make.sparse_transpose(x[0], x[1], x[2]), 3)
 
 
+# pylint: disable=no-else-return,inconsistent-return-statements
+def sparse_add(dense_mat, sparse_mat):
+    r"""
+    Computes the matrix addition of `dense_mat` and `sparse_mat`, where `dense_mat` is
+    a dense matrix and `sparse_mat` is a sparse (CSR) namedtuple with
+    fields `data`, `indices`, and `indptr`.
+
+    .. math::
+
+        \mbox{sparse_add}(D, S)[m, n] = \mbox{add}(D, \mbox{as_dense}(S))[m, n]
+
+    where `D` is `dense_mat`, `S` is `sparse_mat`, and `as_dense` returns the dense
+    equivalent of the given sparse matrix. The result has the shape of `dense_mat`.
+
+    Parameters
+    ----------
+    dense_mat : tvm.relay.Expr
+        The input dense matrix for the matrix addition
+
+    sparse_mat : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]].
+        The input sparse matrix (CSR): an object with `data`/`indices`/`indptr`, or a tuple.
+
+    Returns
+    -------
+    result: tvm.relay.Expr
+        The computed result.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        dense_data = [[ 3.,   4.,   4. ],
+                      [ 4.,   2.,   5. ]]
+        sparse_data = [4., 8.]
+        sparse_indices = [0, 2]
+        sparse_indptr = [0, 1, 2]
+        output = relay.nn.sparse_add(dense_data, [sparse_data, sparse_indices, sparse_indptr])
+
+        output = [[ 7.,   4.,   4. ],
+                  [ 4.,   2.,  13. ]]
+    """
+    if hasattr(sparse_mat, "indices"):
+        return _make.sparse_add(dense_mat, sparse_mat.data, sparse_mat.indices, sparse_mat.indptr)
+    else:
+        return _make.sparse_add(dense_mat, sparse_mat[0], sparse_mat[1], sparse_mat[2])
+
+
 def contrib_conv2d_winograd_without_weight_transform(
     data,
     weight,
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index be86ea9..04f2564 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -799,6 +799,29 @@ def sparse_dense_padded_strategy(attrs, inputs, out_type, target):
     raise NotImplementedError("sparse_dense_padded is only implemented for cuda")
 
 
+# sparse_add
+def wrap_compute_sparse_add(topi_compute):
+    """Wrap a TOPI sparse_add compute: pass inputs (dense, data, indices, indptr) positionally."""
+
+    def _compute_sparse_add(attrs, inputs, out_type):
+        return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3])]
+
+    return _compute_sparse_add
+
+
+@override_native_generic_func("sparse_add_strategy")
+def sparse_add_strategy(attrs, inputs, out_type, target):
+    """Generic sparse_add strategy: topi.nn.sparse_add with the extern schedule (untuned)."""
+    logger.warning("sparse add is not optimized for this platform.")
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_sparse_add(topi.nn.sparse_add),
+        wrap_topi_schedule(topi.generic.schedule_extern),
+        name="sparse_add.generic",
+    )
+    return strategy
+
+
 # sparse_transpose
 @generic_func
 def schedule_sparse_transpose(attrs, outs, target):
diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py
index 1bf18df..7561106 100644
--- a/python/tvm/topi/nn/sparse.py
+++ b/python/tvm/topi/nn/sparse.py
@@ -468,3 +468,72 @@ def try_get_sparse_input(args):
     sparse_input_map[sparse_indptr] = sparse_prefix + "W_indptr"
 
     return sparse_input_map
+
+
+def sparse_add(dense_data, sparse_data, sparse_indices, sparse_indptr):
+    """
+    Computes dense_data + as_dense(CSR matrix given by sparse_data/indices/indptr).
+
+    Parameters
+    ----------
+    dense_data : tvm.te.Tensor
+        2-D with shape [M, N]
+
+    sparse_data : tvm.te.Tensor
+        1-D with shape [nnz]: the non-zero values (CSR data)
+
+    sparse_indices : tvm.te.Tensor
+        1-D with shape [nnz]: the column index of each value (CSR indices)
+
+    sparse_indptr : tvm.te.Tensor
+        1-D with shape [M + 1]: row i spans [indptr[i], indptr[i + 1]) (CSR indptr)
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        2-D with shape [M, N]
+    """
+    # TODO(ANSHUMAN87): support BSR format too
+    assert len(sparse_data.shape) == 1, "only CSR format is supported"
+    return _sparse_add_csr(dense_data, sparse_data, sparse_indices, sparse_indptr)
+
+
+def _sparse_add_csr(dense_data_inp, sparse_data_inp, sparse_indices_inp, sparse_indptr_inp):
+    """Build the extern op computing dense + as_dense(CSR), shaped like the dense input."""
+    oshape = get_const_tuple(dense_data_inp.shape)
+
+    def _csr_add_ir(dense_data, sparse_data, sparse_indices, sparse_indptr, out_data):
+        irb = tvm.tir.ir_builder.create()
+        dense_data_ptr = irb.buffer_ptr(dense_data)
+        sparse_data_ptr = irb.buffer_ptr(sparse_data)
+        sparse_indices_ptr = irb.buffer_ptr(sparse_indices)
+        sparse_indptr_ptr = irb.buffer_ptr(sparse_indptr)
+        out_data_ptr = irb.buffer_ptr(out_data)
+
+        # Pass 1: copy the dense input. Rows run in parallel; the column loop stays serial
+        # (vectorizing the row loop would create vectors as wide as the number of rows).
+        with irb.for_range(0, oshape[0], kind="parallel", name="row") as row:
+            with irb.for_range(0, oshape[1], kind="serial", name="col") as col:
+                out_data_ptr[row, col] = dense_data_ptr[row, col]
+
+        # Pass 2: accumulate the CSR entries. Each row only writes its own output row,
+        # so rows can run in parallel without write conflicts.
+        with irb.for_range(0, oshape[0], kind="parallel", name="row") as row:
+            offset = sparse_indptr_ptr[row]
+            diff = sparse_indptr_ptr[row + 1] - sparse_indptr_ptr[row]
+            with irb.for_range(0, diff, kind="serial", name="idx") as idx:
+                real_idx = offset + idx
+                col = sparse_indices_ptr[real_idx]
+                out_data_ptr[row, col] = sparse_data_ptr[real_idx] + out_data_ptr[row, col]
+
+        return irb.get()
+
+    return te.extern(
+        shape=oshape,
+        inputs=[dense_data_inp, sparse_data_inp, sparse_indices_inp, sparse_indptr_inp],
+        fcompute=lambda ins, outs: _csr_add_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
+        tag="sparse_add_csr",
+        # A single output tensor, so exactly one dtype: that of the dense input.
+        dtype=[dense_data_inp.dtype],
+        name="sparse_add_csr_output",
+    )
diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc
index 6322cff..b1a16f1 100644
--- a/src/relay/op/nn/sparse.cc
+++ b/src/relay/op/nn/sparse.cc
@@ -196,5 +196,46 @@ RELAY_REGISTER_OP("nn.sparse_transpose")
     .set_support_level(1)
     .add_type_rel("SparseTranspose", SparseTransposeRel);
 
+// relay.nn.sparse_add: output has the dense input's shape and dtype.
+bool SparseAddRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  ICHECK_EQ(types.size(), 5) << "expecting 4 inputs and 1 output.";
+  const auto* dense_data = types[0].as<TensorTypeNode>();
+  const auto* sparse_data = types[1].as<TensorTypeNode>();
+  const auto* sparse_indices = types[2].as<TensorTypeNode>();
+  if (!dense_data || !sparse_data || !sparse_indices) return false;  // not inferred yet
+  ICHECK(reporter->Assert(sparse_data->dtype == dense_data->dtype))
+      << "sparse tensor and dense tensor datatype should match.";
+  ICHECK(reporter->Assert(sparse_data->shape.size() == 1)) << "sparse data tensor should be 1D.";
+  ICHECK(reporter->Assert(sparse_indices->shape.size() == 1))
+      << "sparse indices tensor should be 1D.";
+  reporter->Assign(types[4], TensorType(dense_data->shape, dense_data->dtype));
+  return true;
+}
+
+Expr MakeSparseAdd(Expr dense_data, Expr sparse_data, Expr sparse_indices, Expr sparse_indptr) {
+  static const Op& op = Op::Get("nn.sparse_add");
+  return Call(op, {dense_data, sparse_data, sparse_indices, sparse_indptr}, Attrs(), {});
+}
+// Python entry point: relay.op.nn._make.sparse_add (called by relay.nn.sparse_add).
+TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_add").set_body_typed(MakeSparseAdd);
+
+RELAY_REGISTER_OP("nn.sparse_add")
+    .describe(R"code(Add a dense matrix X to a sparse matrix Y given in CSR format.
+
+- **dense_data**: `(M, N)`
+- **sparse_data**, **sparse_indices**: `(nnz,)`; **sparse_indptr**: `(M + 1,)`
+
+- **out**: `(M, N)`, same shape and dtype as `dense_data`.
+
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(4)
+    .add_argument("dense_data", "2D Tensor", "Dense data matrix.")
+    .add_argument("sparse_data", "1D Tensor", "Sparse data vector.")
+    .add_argument("sparse_indices", "1D Tensor", "Sparse indices vector.")
+    .add_argument("sparse_indptr", "1D Tensor", "Sparse index pointer vector.")
+    .set_support_level(1)
+    .add_type_rel("SparseAdd", SparseAddRel);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 81aeb5e..fa27dee 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -2353,6 +2353,54 @@ def test_forward_sparse_to_dense_v2():
 
 
 #######################################################################
+# tensorflow.sparse.add
+# ----------------------------------
+
+
+def _test_sparse_add(indices, values, A_shape, B_shape, dtype, flip=False):
+    """One tf.sparse.add check: constant sparse A plus dense placeholder B (flip swaps order)."""
+
+    # TODO(ANSHUMAN87): support cuda
+    # TODO(ANSHUMAN87): support the case where both inputs are sparse
+
+    with tf.Graph().as_default():
+        A_sp = tf.sparse.SparseTensor(
+            indices=indices, values=np.array(values).astype(dtype), dense_shape=A_shape
+        )
+        B = tf.placeholder(shape=B_shape, dtype=dtype, name="B")
+
+        # TODO(ANSHUMAN87): support user input threshold values
+        if flip:
+            result = tf.sparse.add(B, A_sp, threshold=0)
+        else:
+            result = tf.sparse.add(A_sp, B, threshold=0)
+
+        B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
+
+        compare_tf_with_tvm([B_np], [B.name], result.name, no_gpu=True)
+
+
+def test_sparse_add():
+    """Test tf.sparse.add conversion across dtypes and both operand orders."""
+    ###################################################################
+    #
+    # In order to create a SparseTensor, it requires 3 inputs as below:
+    #    SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
+    #
+    # The above sparse tensor is represented in dense form as:
+    #    [[1, 0, 0, 0]
+    #     [0, 0, 2, 0]
+    #     [0, 0, 0, 0]]
+    #
+    # ------------------------------------------------------------------
+    for dtype_inp in ["float32", "float64", "int32"]:
+        _test_sparse_add([[0, 0], [1, 2]], [4.0, 8.0], [3, 4], [3, 4], dtype_inp)
+        _test_sparse_add([[0, 0], [1, 2]], [4.0, 8.0], [3, 4], [3, 4], dtype_inp, True)
+        _test_sparse_add([[0, 0], [1, 3], [4, 3]], [3.0, 6.0, 9.0], [5, 5], [5, 5], dtype_inp)
+        _test_sparse_add([[0, 0], [1, 3], [4, 3]], [3.0, 6.0, 9.0], [5, 5], [5, 5], dtype_inp, True)
+
+
+#######################################################################
 # StridedSlice
 # ------------
 
diff --git a/tests/python/topi/python/test_topi_sparse.py b/tests/python/topi/python/test_topi_sparse.py
index d5bd7aa..d84bd15 100644
--- a/tests/python/topi/python/test_topi_sparse.py
+++ b/tests/python/topi/python/test_topi_sparse.py
@@ -526,6 +526,33 @@ def test_sparse_dense_padded_alter_op():
             x = relay.build(tvm.IRModule.from_expr(f), target=tvm.target.Target("cuda"))
 
 
+def test_sparse_add_csr():
+    """topi.nn.sparse_add with CSR input matches NumPy dense + densified-sparse addition."""
+    for indices_dtype in ["int32", "int64"]:
+        for data_dtype in ["float32", "float64"]:
+            M, K, density = 3, 49, 0.2
+            X_np = np.random.randn(M, K).astype(data_dtype)
+            Y_sp_np = sp.random(M, K, density=density, format="csr", dtype=data_dtype)
+            Z_np = X_np + Y_sp_np.toarray()  # plain ndarray rather than np.matrix
+
+            Y_data = te.placeholder(shape=Y_sp_np.data.shape, dtype=data_dtype)
+            Y_indices = te.placeholder(shape=Y_sp_np.indices.shape, dtype=indices_dtype)
+            Y_indptr = te.placeholder(shape=Y_sp_np.indptr.shape, dtype=indices_dtype)
+            X = te.placeholder(shape=X_np.shape, dtype=data_dtype)
+            Z = topi.nn.sparse_add(X, Y_data, Y_indices, Y_indptr)
+            s = te.create_schedule(Z.op)
+            func = tvm.build(s, [X, Y_data, Y_indices, Y_indptr, Z])
+            Z_tvm = tvm.nd.array(np.zeros(Z_np.shape, dtype=Z_np.dtype))
+            func(
+                tvm.nd.array(X_np.astype(data_dtype)),
+                tvm.nd.array(Y_sp_np.data.astype(data_dtype)),
+                tvm.nd.array(Y_sp_np.indices.astype(indices_dtype)),
+                tvm.nd.array(Y_sp_np.indptr.astype(indices_dtype)),
+                Z_tvm,
+            )
+            tvm.testing.assert_allclose(Z_tvm.asnumpy(), Z_np, atol=1e-4, rtol=1e-4)
+
+
 if __name__ == "__main__":
     test_csrmv()
     test_csrmm()
@@ -537,3 +564,4 @@ if __name__ == "__main__":
     test_sparse_dense_padded_alter_op()
     test_sparse_dense_csr_reverse()
     test_sparse_dense_bsr_reverse()
+    test_sparse_add_csr()