You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mb...@apache.org on 2021/03/11 21:02:16 UTC
[tvm] branch main updated: [TOPI] Sparse Add Op added (#7435)
This is an automated email from the ASF dual-hosted git repository.
mbaret pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 017ff94 [TOPI] Sparse Add Op added (#7435)
017ff94 is described below
commit 017ff94d15df85ea8476f8ad3ce234470072ae84
Author: ANSHUMAN TRIPATHY <an...@huawei.com>
AuthorDate: Fri Mar 12 02:31:59 2021 +0530
[TOPI] Sparse Add Op added (#7435)
* [TOPI] Sparse Add Op added
* lint resolved
* TF frontend support added
* Test case added
* [1] Review comment handled
* [2] Review comment handled
* [3] Review comment handled
* [4] Review comment handled
* [5] Review comment handled
---
python/tvm/relay/frontend/tensorflow.py | 35 ++++++++++++
python/tvm/relay/op/nn/_nn.py | 5 ++
python/tvm/relay/op/nn/nn.py | 47 ++++++++++++++++
python/tvm/relay/op/strategy/generic.py | 23 ++++++++
python/tvm/topi/nn/sparse.py | 69 ++++++++++++++++++++++++
src/relay/op/nn/sparse.cc | 41 ++++++++++++++
tests/python/frontend/tensorflow/test_forward.py | 48 +++++++++++++++++
tests/python/topi/python/test_topi_sparse.py | 28 ++++++++++
8 files changed, 296 insertions(+)
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index c79c495..f56d187 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1286,6 +1286,40 @@ def _sparse_segment_mean_with_num_segments():
return _impl
+def _sparse_tensor_dense_add():
+ # Sparse utility from scipy
+ from scipy.sparse import csr_matrix
+
+ def _impl(inputs, attr, params, mod):
+ assert (
+ len(inputs) == 4
+ ), "There should be 4 input tensors [sparse_indices, sparse_values, sparse_shape, dense]."
+
+ indices_tensor = _infer_value(inputs[0], params, mod).asnumpy()
+ values_tensor = _infer_value(inputs[1], params, mod).asnumpy()
+ dense_shape_tensor = _infer_value(inputs[2], params, mod).asnumpy()
+
+ data = inputs[3]
+
+ rows = [x[0] for x in indices_tensor]
+ cols = [x[1] for x in indices_tensor]
+
+ # Create scipy sparse Tensor(CSR)
+ weight_sp = csr_matrix(
+ (values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist())
+ )
+
+ weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype)
+ weight_indptrs = _expr.const(weight_sp.indptr, weight_sp.indptr.dtype)
+ weight_indices = _expr.const(weight_sp.indices, weight_sp.indices.dtype)
+
+ ret = _op.nn.sparse_add(data, [weight_data, weight_indices, weight_indptrs])
+
+ return ret
+
+ return _impl
+
+
def _identity():
def _impl(inputs, attr, params, mod):
return inputs[0]
@@ -2787,6 +2821,7 @@ _convert_map = {
"SparseSegmentSqrtNWithNumSegments": _sparse_segment_sum_sqrtn_with_num_segments(),
"SparseSegmentMean": _sparse_segment_mean(),
"SparseSegmentMeanWithNumSegments": _sparse_segment_mean_with_num_segments(),
+ "SparseTensorDenseAdd": _sparse_tensor_dense_add(),
"Split": _split(False),
"SplitV": _split(True),
"Sqrt": AttrCvt("sqrt"),
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 6ae86c0..af64873 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -142,6 +142,11 @@ def alter_op_layout_sparse_dense(attrs, inputs, tinfos, out_type):
return topi.nn.sparse_dense_alter_layout(attrs, inputs, tinfos, out_type)
# sparse_add
# Compute/schedule selection is delegated to strategy.sparse_add_strategy.
# The op is registered as OPAQUE; presumably because its TOPI compute is a
# te.extern kernel that cannot be fused with neighbours -- confirm.
reg.register_strategy("nn.sparse_add", strategy.sparse_add_strategy)
reg.register_pattern("nn.sparse_add", reg.OpPattern.OPAQUE)
+
+
@reg.register_compute("nn.internal.sparse_dense_padded")
def compute_sparse_dense_padded(attrs, inputs, out_type):
"""Compute definition of sparse_dense_padded"""
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 5135ac7..a1147fe 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -2148,6 +2148,53 @@ def sparse_transpose(x):
return expr.TupleWrapper(_make.sparse_transpose(x[0], x[1], x[2]), 3)
def sparse_add(dense_mat, sparse_mat):
    r"""
    Computes the matrix addition of `dense_mat` and `sparse_mat`, where `dense_mat` is
    a dense matrix and `sparse_mat` is a sparse (CSR) matrix given either as a namedtuple
    with fields `data`, `indices`, and `indptr`, or as a 3-element sequence.

    .. math::

        \mbox{sparse_add}(D, S)[m, n] = D[m, n] + \mbox{as_dense}(S)[m, n]

    where `D` is `dense_mat`, `S` is `sparse_mat`, and `as_dense` returns the dense
    equivalent of the sparse matrix `S`.

    Parameters
    ----------
    dense_mat : tvm.relay.Expr
        The input dense matrix for the matrix addition.

    sparse_mat : Union[namedtuple, Tuple[tvm.relay.Expr, tvm.relay.Expr, tvm.relay.Expr]]
        The input sparse matrix (CSR) for the matrix addition. Either an object exposing
        ``data``, ``indices`` and ``indptr`` attributes, or a sequence
        ``(data, indices, indptr)`` in exactly that order.

    Returns
    -------
    result: tvm.relay.Expr
        The computed result, with the same shape and dtype as `dense_mat`.

    Examples
    --------
    .. code-block:: python

        dense_data = [[ 3., 4., 4. ],
                      [ 4., 2., 5. ]]
        sparse_data = [4., 8.]
        sparse_indices = [0, 2]
        sparse_indptr = [0, 1, 2]

        output = relay.nn.sparse_add(dense_data, [sparse_data, sparse_indices, sparse_indptr])

        output = [[ 7., 4., 4. ],
                  [ 4., 2., 13. ]]
    """
    # Accept both a CSR-like namedtuple and a plain (data, indices, indptr) sequence.
    if hasattr(sparse_mat, "indices"):
        data, indices, indptr = sparse_mat.data, sparse_mat.indices, sparse_mat.indptr
    else:
        data, indices, indptr = sparse_mat[0], sparse_mat[1], sparse_mat[2]
    return _make.sparse_add(dense_mat, data, indices, indptr)
+
+
def contrib_conv2d_winograd_without_weight_transform(
data,
weight,
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index be86ea9..04f2564 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -799,6 +799,29 @@ def sparse_dense_padded_strategy(attrs, inputs, out_type, target):
raise NotImplementedError("sparse_dense_padded is only implemented for cuda")
+# sparse_add
def wrap_compute_sparse_add(topi_compute):
    """Adapt a TOPI sparse_add compute function to the Relay compute-callback signature.

    The returned callback ignores ``attrs`` and ``out_type`` and forwards the first four
    operator inputs positionally to ``topi_compute``, wrapping the result in a list.
    """

    def _compute_sparse_add(attrs, inputs, out_type):
        dense, sp_data, sp_indices, sp_indptr = inputs[0], inputs[1], inputs[2], inputs[3]
        out_tensor = topi_compute(dense, sp_data, sp_indices, sp_indptr)
        return [out_tensor]

    return _compute_sparse_add
+
+
@override_native_generic_func("sparse_add_strategy")
def sparse_add_strategy(attrs, inputs, out_type, target):
    """Generic (unoptimized) strategy for nn.sparse_add.

    Logs a warning, then registers one implementation that computes with
    ``topi.nn.sparse_add`` and schedules with the generic extern schedule.
    """
    logger.warning("sparse add is not optimized for this platform.")
    compute_fn = wrap_compute_sparse_add(topi.nn.sparse_add)
    schedule_fn = wrap_topi_schedule(topi.generic.schedule_extern)
    op_strategy = _op.OpStrategy()
    op_strategy.add_implementation(compute_fn, schedule_fn, name="sparse_add.generic")
    return op_strategy
+
+
# sparse_transpose
@generic_func
def schedule_sparse_transpose(attrs, outs, target):
diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py
index 1bf18df..7561106 100644
--- a/python/tvm/topi/nn/sparse.py
+++ b/python/tvm/topi/nn/sparse.py
@@ -468,3 +468,72 @@ def try_get_sparse_input(args):
sparse_input_map[sparse_indptr] = sparse_prefix + "W_indptr"
return sparse_input_map
+
+
def sparse_add(dense_data, sparse_data, sparse_indices, sparse_indptr):
    """
    Add a CSR sparse matrix to a dense matrix.

    Parameters
    ----------
    dense_data : tvm.te.Tensor
        2-D with shape [M, N]

    sparse_data : tvm.te.Tensor
        1-D with shape [nnz] (CSR)

    sparse_indices : tvm.te.Tensor
        1-D with shape [nnz] (CSR)

    sparse_indptr : tvm.te.Tensor
        1-D with shape [M + 1] (CSR)

    Returns
    -------
    output : tvm.te.Tensor
        2-D with shape [M, N]
    """
    # TODO(ANSHUMAN87): support BSR format too
    # A 1-D data array identifies CSR; block formats would carry extra dims.
    sparse_rank = len(sparse_data.shape)
    assert sparse_rank == 1, "only CSR format is supported"
    return _sparse_add_csr(
        dense_data_inp=dense_data,
        sparse_data_inp=sparse_data,
        sparse_indices_inp=sparse_indices,
        sparse_indptr_inp=sparse_indptr,
    )
+
+
def _sparse_add_csr(dense_data_inp, sparse_data_inp, sparse_indices_inp, sparse_indptr_inp):
    """Build the extern op computing ``dense + as_dense(csr)`` for a 2-D dense input.

    The single output has the dense input's shape and dtype. It is produced in two
    passes: a copy of the dense input, then accumulation of every stored CSR entry
    into its (row, col) position.
    """
    oshape = get_const_tuple(dense_data_inp.shape)

    def _csr_add_ir(dense_data, sparse_data, sparse_indices, sparse_indptr, out_data):
        irb = tvm.tir.ir_builder.create()
        dense_data_ptr = irb.buffer_ptr(dense_data)
        sparse_data_ptr = irb.buffer_ptr(sparse_data)
        sparse_indices_ptr = irb.buffer_ptr(sparse_indices)
        sparse_indptr_ptr = irb.buffer_ptr(sparse_indptr)

        out_data_ptr = irb.buffer_ptr(out_data)

        # Pass 1: out = dense. Rows are independent, so parallelize the outer loop;
        # the inner loop walks contiguous memory and is left serial for the backend
        # to vectorize (an explicit "vectorize" would create a vector as wide as the
        # loop extent).
        with irb.for_range(0, oshape[0], kind="parallel", name="row") as row:
            with irb.for_range(0, oshape[1], kind="serial", name="col") as col:
                out_data_ptr[row, col] = dense_data_ptr[row, col]

        # Pass 2: out[row, col] += value for every stored entry. Each row is owned by
        # one parallel iteration and its entries are walked serially, so repeated
        # column indices within a row accumulate without races.
        with irb.for_range(0, oshape[0], kind="parallel", name="row") as row:
            offset = sparse_indptr_ptr[row]
            diff = sparse_indptr_ptr[row + 1] - sparse_indptr_ptr[row]
            with irb.for_range(0, diff, kind="serial", name="idx") as idx:
                real_idx = offset + idx
                col = sparse_indices_ptr[real_idx]
                out_data_ptr[row, col] = sparse_data_ptr[real_idx] + out_data_ptr[row, col]

        return irb.get()

    return te.extern(
        shape=oshape,
        inputs=[dense_data_inp, sparse_data_inp, sparse_indices_inp, sparse_indptr_inp],
        fcompute=lambda ins, outs: _csr_add_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
        tag="sparse_add_csr",
        # One output tensor, typed like the dense input. (dtype entries are paired
        # with output shapes, so listing every input dtype here was misleading.)
        dtype=[dense_data_inp.dtype],
        name="sparse_add_csr_output",
    )
diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc
index 6322cff..b1a16f1 100644
--- a/src/relay/op/nn/sparse.cc
+++ b/src/relay/op/nn/sparse.cc
@@ -196,5 +196,46 @@ RELAY_REGISTER_OP("nn.sparse_transpose")
.set_support_level(1)
.add_type_rel("SparseTranspose", SparseTransposeRel);
+// relay.nn.sparse_add
+bool SparseAddRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
+ ICHECK_EQ(types.size(), 5) << "expecting 4 inputs and 1 output.";
+ const auto* dense_data = types[0].as<TensorTypeNode>();
+ const auto* sparse_data = types[1].as<TensorTypeNode>();
+ ICHECK(reporter->Assert(sparse_data->dtype == dense_data->dtype))
+ << "sparse tensor and dense tensor datatype should match.";
+ ICHECK(reporter->Assert(sparse_data->shape.size() == 1)) << "sparse data tensor should be 1D.";
+ const auto* sparse_indices = types[2].as<TensorTypeNode>();
+ ICHECK(reporter->Assert(sparse_indices->shape.size() == 1))
+ << "sparse indices tensor should be 1D.";
+
+ reporter->Assign(types[4], TensorType(dense_data->shape, dense_data->dtype));
+ return true;
+}
+
+Expr MakeSparseAdd(Expr dense_data, Expr sparse_data, Expr sparse_indices, Expr sparse_indptr) {
+ static const Op& op = Op::Get("nn.sparse_add");
+ return Call(op, {dense_data, sparse_data, sparse_indices, sparse_indptr}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_add").set_body_typed(MakeSparseAdd);
+
// Register nn.sparse_add: four inputs (a dense matrix and the data/indices/indptr
// arrays of a CSR matrix), no attributes, output type computed by SparseAddRel.
RELAY_REGISTER_OP("nn.sparse_add")
    .describe(R"code(Add a dense matrix X with sparse matrix Y.

- **dense**: `(M, N)`
- **sparse**: `(M, N)`

- **out**: `(M, N)`.

)code" TVM_ADD_FILELINE)
    .set_num_inputs(4)
    .add_argument("dense_data", "2D Tensor", "Dense data matrix.")
    .add_argument("sparse_data", "1D Tensor", "Sparse data vector.")
    .add_argument("sparse_indices", "1D Tensor", "Sparse indices vector.")
    .add_argument("sparse_indptr", "1D Tensor", "Sparse index pointer vector.")
    .set_support_level(1)
    .add_type_rel("SparseAdd", SparseAddRel);
+
} // namespace relay
} // namespace tvm
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 81aeb5e..fa27dee 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -2353,6 +2353,54 @@ def test_forward_sparse_to_dense_v2():
#######################################################################
+# tensorflow.sparse.add
+# ----------------------------------
+
+
def _test_sparse_add(indices, values, A_shape, B_shape, dtype, flip=False):
    """One iteration of tf.sparse.add.

    Builds a graph that adds a constant ``tf.sparse.SparseTensor`` (``indices``,
    ``values``, ``A_shape``) to a dense placeholder of shape ``B_shape``, feeds the
    placeholder uniform random data below 5.0 cast to ``dtype``, and compares the
    TensorFlow result with TVM (GPU disabled).

    Parameters
    ----------
    indices : list
        Coordinates of the stored entries of the sparse operand.
    values : list
        Values of the stored entries; cast to ``dtype``.
    A_shape : list
        Dense shape of the sparse operand.
    B_shape : list
        Shape of the dense placeholder operand.
    dtype : str
        Element type used for both operands.
    flip : bool
        If True, the dense tensor is passed as the first argument of tf.sparse.add.
    """

    # TODO(ANSHUMAN87): support cuda
    # TODO(ANSHUMAN87): support both sparse input case

    with tf.Graph().as_default():
        A_sp = tf.sparse.SparseTensor(
            indices=indices, values=np.array(values).astype(dtype), dense_shape=A_shape
        )
        B = tf.placeholder(shape=B_shape, dtype=dtype, name="B")

        # TODO(ANSHUMAN87): support user input threshold values
        if flip:
            result = tf.sparse.add(B, A_sp, threshold=0)
        else:
            result = tf.sparse.add(A_sp, B, threshold=0)

        B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)

        compare_tf_with_tvm([B_np], [B.name], result.name, no_gpu=True)
+
+
def test_sparse_add():
    """Test tf.sparse.add (SparseTensorDenseAdd) conversion for several dtypes."""
    ###################################################################
    #
    # A SparseTensor is built from three inputs, e.g.:
    # SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    #
    # whose dense equivalent is:
    # [[1, 0, 0, 0]
    #  [0, 0, 2, 0]
    #  [0, 0, 0, 0]]
    #
    # ------------------------------------------------------------------
    sparse_cases = [
        ([[0, 0], [1, 2]], [4.0, 8.0], [3, 4], [3, 4]),
        ([[0, 0], [1, 3], [4, 3]], [3.0, 6.0, 9.0], [5, 5], [5, 5]),
    ]
    for dtype_inp in ["float32", "float64", "int32"]:
        for sp_indices, sp_values, a_shape, b_shape in sparse_cases:
            # Exercise both operand orders: sparse + dense, then dense + sparse.
            for flip in (False, True):
                _test_sparse_add(sp_indices, sp_values, a_shape, b_shape, dtype_inp, flip)
+
+
+#######################################################################
# StridedSlice
# ------------
diff --git a/tests/python/topi/python/test_topi_sparse.py b/tests/python/topi/python/test_topi_sparse.py
index d5bd7aa..d84bd15 100644
--- a/tests/python/topi/python/test_topi_sparse.py
+++ b/tests/python/topi/python/test_topi_sparse.py
@@ -526,6 +526,33 @@ def test_sparse_dense_padded_alter_op():
x = relay.build(tvm.IRModule.from_expr(f), target=tvm.target.Target("cuda"))
def test_sparse_add_csr():
    """Check topi.nn.sparse_add against a numpy/scipy reference for CSR inputs."""
    for indices_dtype in ["int32", "int64"]:
        for data_dtype in ["float32", "float64"]:
            M, K, density = 3, 49, 0.2
            X_np = np.random.randn(M, K).astype(data_dtype)
            Y_sp_np = sp.random(M, K, density=density, format="csr", dtype=data_dtype)
            # toarray() yields a plain ndarray; todense() would return np.matrix.
            Y_np = Y_sp_np.toarray()
            Z_np = X_np + Y_np

            Y_data = te.placeholder(shape=Y_sp_np.data.shape, dtype=data_dtype)
            Y_indices = te.placeholder(shape=Y_sp_np.indices.shape, dtype=indices_dtype)
            Y_indptr = te.placeholder(shape=Y_sp_np.indptr.shape, dtype=indices_dtype)
            X = te.placeholder(shape=X_np.shape, dtype=data_dtype)
            Z = topi.nn.sparse_add(X, Y_data, Y_indices, Y_indptr)
            s = te.create_schedule(Z.op)
            func = tvm.build(s, [X, Y_data, Y_indices, Y_indptr, Z])
            Z_tvm = tvm.nd.array(np.zeros(Z_np.shape, dtype=Z_np.dtype))
            func(
                tvm.nd.array(X_np),
                tvm.nd.array(Y_sp_np.data.astype(data_dtype)),
                tvm.nd.array(Y_sp_np.indices.astype(indices_dtype)),
                tvm.nd.array(Y_sp_np.indptr.astype(indices_dtype)),
                Z_tvm,
            )
            tvm.testing.assert_allclose(Z_tvm.asnumpy(), Z_np, atol=1e-4, rtol=1e-4)
+
+
if __name__ == "__main__":
test_csrmv()
test_csrmm()
@@ -537,3 +564,4 @@ if __name__ == "__main__":
test_sparse_dense_padded_alter_op()
test_sparse_dense_csr_reverse()
test_sparse_dense_bsr_reverse()
+ test_sparse_add_csr()