You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by cj...@apache.org on 2017/11/05 15:36:46 UTC

[incubator-mxnet] 02/04: dtype default to source_array.dtype for sparse ndarrays (#8403)

This is an automated email from the ASF dual-hosted git repository.

cjolivier01 pushed a commit to branch v0.12.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit beaf8ec06c319b9872c74342b67230d0e0c2e33a
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Tue Oct 24 21:09:05 2017 -0700

    dtype default to source_array.dtype for sparse ndarrays (#8403)
    
    * derive default dtype/ctx from input for sparse ndarrays
    
    * add gpu tests
    
    * fix lint. add doc
    
    * remove default_ctx code
    
    * bug fix when passing dtype to array()
    
    * update doc
    
    * remove extra line
    
    * also check ctx
---
 python/mxnet/ndarray/sparse.py               | 96 +++++++++++++++++++---------
 tests/python/gpu/test_operator_gpu.py        |  2 +
 tests/python/unittest/test_sparse_ndarray.py | 95 ++++++++++++++++++++++-----
 3 files changed, 148 insertions(+), 45 deletions(-)

diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py
index a1a3ba8..a31ec01 100644
--- a/python/mxnet/ndarray/sparse.py
+++ b/python/mxnet/ndarray/sparse.py
@@ -727,6 +727,18 @@ def _prepare_src_array(source_array, dtype):
             raise TypeError('values must be array like object')
     return source_array
 
+def _prepare_default_dtype(src_array, dtype):
+    """Prepare the value of dtype if `dtype` is None. If `src_array` is an NDArray, numpy.ndarray
+    or scipy.sparse.csr.csr_matrix, return src_array.dtype. float32 is returned otherwise."""
+    if dtype is None:
+        if isinstance(src_array, (NDArray, np.ndarray)):
+            dtype = src_array.dtype
+        elif spsp and isinstance(src_array, spsp.csr.csr_matrix):
+            dtype = src_array.dtype
+        else:
+            dtype = mx_real_t
+    return dtype
+
 def _check_shape(s1, s2):
     """check s1 == s2 if both are not None"""
     if s1 and s2 and s1 != s2:
@@ -749,12 +761,11 @@ def csr_matrix(arg1, shape=None, ctx=None, dtype=None):
 
     - csr_matrix(S)
         to construct a CSRNDArray with a sparse 2D array ``S``
-            -  **S** (*CSRNDArray or scipy.sparse.csr_matrix*) - A sparse matrix.
+            -  **S** (*CSRNDArray or scipy.sparse.csr.csr_matrix*) - A sparse matrix.
             - **ctx** (*Context, optional*) - Device context \
             (default is the current default context).
             - **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \
-            The default dtype is ``D.dtype`` if ``D`` is an NDArray or numpy.ndarray, \
-            float32 otherwise.
+            The default dtype is ``S.dtype``.
 
     - csr_matrix((M, N))
         to construct an empty CSRNDArray with shape ``(M, N)``
@@ -784,19 +795,20 @@ def csr_matrix(arg1, shape=None, ctx=None, dtype=None):
             - **ctx** (*Context, optional*) - Device context \
             (default is the current default context).
             - **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \
-            The default dtype is float32.
+            The default dtype is ``data.dtype`` if ``data`` is an NDArray or numpy.ndarray, \
+            float32 otherwise.
 
     Parameters
     ----------
-    arg1: tuple of int, tuple of array_like, array_like, CSRNDArray or scipy.sparse.csr_matrix
+    arg1: NDArray, CSRNDArray, numpy.ndarray, scipy.sparse.csr.csr_matrix, tuple of int or tuple \
+    of array_like
         The argument to help instantiate the csr matrix. See above for further details.
-    shape : tuple of int
+    shape : tuple of int, optional
         The shape of the csr matrix.
     ctx: Context, optional
         Device context (default is the current default context).
     dtype: str or numpy.dtype, optional
-        The data type of the output array. The default dtype is ``values.dtype``
-        if `values` is an `NDArray`, `float32` otherwise.
+        The data type of the output array.
 
     Returns
     -------
@@ -839,7 +851,14 @@ def csr_matrix(arg1, shape=None, ctx=None, dtype=None):
             raise ValueError("Unexpected input type: RowSparseNDArray")
         else:
             # construct a csr matrix from a dense one
-            dns = _array(arg1, ctx=ctx, dtype=dtype)
+            # prepare default dtype since mx.nd.array doesn't use default values
+            # based on source_array
+            dtype = _prepare_default_dtype(arg1, dtype)
+            # create dns array with provided dtype. ctx is not passed since copy across
+            # ctx requires dtype to be the same
+            dns = _array(arg1, dtype=dtype)
+            if ctx is not None and dns.context != ctx:
+                dns = dns.as_in_context(ctx)
             _check_shape(dns.shape, shape)
             return dns.tostype('csr')
 
@@ -848,10 +867,9 @@ def _csr_matrix_from_definition(data, indices, indptr, shape=None, ctx=None,
     """Create a `CSRNDArray` based on data, indices and indptr"""
     storage_type = 'csr'
     # context
-    if ctx is None:
-        ctx = Context.default_ctx
+    ctx = Context.default_ctx if ctx is None else ctx
     # types
-    dtype = mx_real_t if dtype is None else dtype
+    dtype = _prepare_default_dtype(data, dtype)
     indptr_type = _STORAGE_AUX_TYPES[storage_type][0] if indptr_type is None else indptr_type
     indices_type = _STORAGE_AUX_TYPES[storage_type][1] if indices_type is None else indices_type
     # prepare src array and types
@@ -906,8 +924,7 @@ def row_sparse_array(arg1, shape=None, ctx=None, dtype=None):
             - **ctx** (*Context, optional*) - Device context \
             (default is the current default context).
             - **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \
-            The default dtype is ``D.dtype`` if ``D`` is an NDArray or numpy.ndarray, \
-            float32 otherwise.
+            The default dtype is ``S.dtype``.
 
     - row_sparse_array((D0, D1 .. Dn))
         to construct an empty RowSparseNDArray with shape ``(D0, D1, ... Dn)``
@@ -931,20 +948,21 @@ def row_sparse_array(arg1, shape=None, ctx=None, dtype=None):
             stores the row index for each row slice with non-zero elements.
             - **shape** (*tuple of int, optional*) - The shape of the array. The default \
             shape is inferred from the indices and indptr arrays.
+            - **ctx** (*Context, optional*) - Device context \
+            (default is the current default context).
             - **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \
             The default dtype is float32.
 
     Parameters
     ----------
-    arg1: tuple of int, tuple of array_like, array_like or RowSparseNDArray
+    arg1: NDArray, numpy.ndarray, RowSparseNDArray, tuple of int or tuple of array_like
         The argument to help instantiate the row sparse ndarray. See above for further details.
-    shape : tuple of int
+    shape : tuple of int, optional
         The shape of the row sparse ndarray.
     ctx : Context, optional
         Device context (default is the current default context).
     dtype : str or numpy.dtype, optional
-        The data type of the output array. The default dtype is ``data.dtype``
-        if `data` is an `NDArray`, `float32` otherwise.
+        The data type of the output array.
 
     Returns
     -------
@@ -995,7 +1013,14 @@ def row_sparse_array(arg1, shape=None, ctx=None, dtype=None):
             raise ValueError("Unexpected input type: CSRNDArray")
         else:
             # construct a csr matrix from a dense one
-            dns = _array(arg1, ctx=ctx, dtype=dtype)
+            # prepare default dtype since mx.nd.array doesn't use default values
+            # based on source_array
+            dtype = _prepare_default_dtype(arg1, dtype)
+            # create dns array with provided dtype. ctx is not passed since copy across
+            # ctx requires dtype to be the same
+            dns = _array(arg1, dtype=dtype)
+            if ctx is not None and dns.context != ctx:
+                dns = dns.as_in_context(ctx)
             _check_shape(dns.shape, shape)
             return dns.tostype('row_sparse')
 
@@ -1004,10 +1029,9 @@ def _row_sparse_ndarray_from_definition(data, indices, shape=None, ctx=None,
     """Create a `RowSparseNDArray` based on data and indices"""
     storage_type = 'row_sparse'
     # context
-    if ctx is None:
-        ctx = Context.default_ctx
+    ctx = Context.default_ctx if ctx is None else ctx
     # types
-    dtype = mx_real_t if dtype is None else dtype
+    dtype = _prepare_default_dtype(data, dtype)
     indices_type = _STORAGE_AUX_TYPES[storage_type][0] if indices_type is None else indices_type
     # prepare src array and types
     data = _prepare_src_array(data, dtype)
@@ -1022,7 +1046,9 @@ def _row_sparse_ndarray_from_definition(data, indices, shape=None, ctx=None,
         indices = _array(indices, ctx, indices_type)
     if shape is None:
         num_indices = indices.shape[0]
-        dim0 = 0 if num_indices == 0 else indices[num_indices - 1].asscalar() + 1
+        if num_indices == 0:
+            raise ValueError('invalid shape')
+        dim0 = indices[num_indices - 1].asscalar() + 1
         shape = (dim0, ) + data.shape[1:]
     # verify shapes
     if data.ndim != len(shape) or indices.ndim != 1 or np.prod(shape[1:]) == 0:
@@ -1127,10 +1153,12 @@ def array(source_array, ctx=None, dtype=None):
     source_array : RowSparseNDArray, CSRNDArray or scipy.sparse.csr.csr_matrix
         The source sparse array
     ctx : Context, optional
-        Device context (default is the current default context).
+        Device context (default is the current default context, regardless of \
+        ``source_array.context``).
     dtype : str or numpy.dtype, optional
         The data type of the output array. The default dtype is ``source_array.dtype``
-        if `source_array` is an `NDArray`, `float32` otherwise.
+        if `source_array` is an `NDArray`, `numpy.ndarray` or `scipy.sparse.csr.csr_matrix`, \
+        `float32` otherwise.
 
     Returns
     -------
@@ -1148,19 +1176,29 @@ def array(source_array, ctx=None, dtype=None):
     >>> mx.nd.sparse.array(mx.nd.sparse.zeros('row_sparse', (3, 2)))
     <RowSparseNDArray 3x2 @cpu(0)>
     """
+    ctx = Context.default_ctx if ctx is None else ctx
     if isinstance(source_array, NDArray):
         assert(source_array.stype != 'default'), \
                "Please use `tostype` to create RowSparseNDArray or CSRNDArray from an NDArray"
-        dtype = source_array.dtype if dtype is None else dtype
-        arr = empty(source_array.stype, source_array.shape, ctx=ctx, dtype=dtype)
-        arr[:] = source_array
+        # prepare dtype and ctx based on source_array, if not provided
+        dtype = _prepare_default_dtype(source_array, dtype)
+        # if both dtype and ctx are different from source_array, we cannot copy directly
+        if source_array.dtype != dtype and source_array.context != ctx:
+            arr = empty(source_array.stype, source_array.shape, dtype=dtype)
+            arr[:] = source_array
+            arr = arr.as_in_context(ctx)
+        else:
+            arr = empty(source_array.stype, source_array.shape, dtype=dtype, ctx=ctx)
+            arr[:] = source_array
         return arr
     elif spsp and isinstance(source_array, spsp.csr.csr_matrix):
         # TODO(haibin) implement `_sync_copy_from` with scipy csr object to reduce a copy
         # preprocess scipy csr to canonical form
         csr = source_array.sorted_indices()
         csr.sum_duplicates()
-        return csr_matrix((csr.data, csr.indices, csr.indptr), shape=csr.shape, dtype=dtype)
+        dtype = _prepare_default_dtype(source_array, dtype)
+        return csr_matrix((csr.data, csr.indices, csr.indptr), shape=csr.shape, \
+                          dtype=dtype, ctx=ctx)
     elif isinstance(source_array, (np.ndarray, np.generic)):
         raise ValueError("Please use mx.nd.array to create an NDArray with source_array of type ",
                          type(source_array))
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index b1f43f3..4126a8e 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -35,6 +35,8 @@ from test_loss import *
 #from test_rnn import *
 from test_gluon_rnn import *
 from test_sparse_ndarray import test_create_csr, test_create_row_sparse, test_sparse_nd_slice
+from test_sparse_ndarray import test_create_sparse_nd_empty, test_create_sparse_nd_from_sparse
+from test_sparse_ndarray import test_create_sparse_nd_from_dense, test_create_sparse_nd_infer_shape
 from test_sparse_operator import *
 from test_ndarray import *
 
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 8da24ef..0bbb5ff 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -19,6 +19,7 @@ import pickle as pkl
 
 from mxnet.ndarray import NDArray
 from mxnet.test_utils import *
+from mxnet.base import mx_real_t
 from numpy.testing import assert_allclose
 import numpy.random as rnd
 
@@ -463,9 +464,10 @@ def test_sparse_nd_unsupported():
             pass
 
 def test_create_csr():
-    def check_create_csr_from_nd(shape, density):
+    def check_create_csr_from_nd(shape, density, dtype):
         matrix = rand_ndarray(shape, 'csr', density)
-        data = matrix.data
+        # create data array with provided dtype
+        data = mx.nd.array(matrix.data.asnumpy(), dtype=dtype)
         indptr = matrix.indptr
         indices = matrix.indices
         csr_created = mx.nd.sparse.csr_matrix((data, indices, indptr), shape=shape)
@@ -473,6 +475,10 @@ def test_create_csr():
         assert same(csr_created.data.asnumpy(), data.asnumpy())
         assert same(csr_created.indptr.asnumpy(), indptr.asnumpy())
         assert same(csr_created.indices.asnumpy(), indices.asnumpy())
+        # verify csr matrix dtype and ctx are consistent with the ones provided
+        assert csr_created.dtype == dtype
+        assert csr_created.data.dtype == dtype
+        assert csr_created.context == Context.default_ctx
         csr_copy = mx.nd.array(csr_created)
         assert(same(csr_copy.asnumpy(), csr_created.asnumpy()))
 
@@ -481,6 +487,7 @@ def test_create_csr():
             assert_almost_equal(nd.data.asnumpy(), sp.data)
             assert_almost_equal(nd.indptr.asnumpy(), sp.indptr)
             assert_almost_equal(nd.indices.asnumpy(), sp.indices)
+
         try:
             import scipy.sparse as spsp
             # random canonical csr
@@ -500,12 +507,13 @@ def test_create_csr():
         except ImportError:
             print("Could not import scipy.sparse. Skipping unit tests for scipy csr creation")
 
-    dim0 = 50
-    dim1 = 50
+    dim0 = 20
+    dim1 = 20
     densities = [0, 0.5]
+    dtype = np.float64
     for density in densities:
         shape = rand_shape_2d(dim0, dim1)
-        check_create_csr_from_nd(shape, density)
+        check_create_csr_from_nd(shape, density, dtype)
         check_create_csr_from_scipy(shape, density, mx.nd.sparse.array)
         check_create_csr_from_scipy(shape, density, mx.nd.array)
 
@@ -569,16 +577,57 @@ def test_create_sparse_nd_infer_shape():
         check_create_rsp_infer_shape(shape_3d, density, dtype)
 
 def test_create_sparse_nd_from_dense():
-    def check_create_from_dns(shape, f, dense_arr, dtype):
-        arr = f(dense_arr, dtype=dtype)
+    def check_create_from_dns(shape, f, dense_arr, dtype, default_dtype, ctx):
+        arr = f(dense_arr, dtype=dtype, ctx=ctx)
         assert(same(arr.asnumpy(), np.ones(shape)))
         assert(arr.dtype == dtype)
+        assert(arr.context == ctx)
+        # verify the default dtype inferred from dense arr
+        arr2 = f(dense_arr)
+        assert(arr2.dtype == default_dtype)
+        assert(arr2.context == Context.default_ctx)
     shape = rand_shape_2d()
     dtype = np.int32
-    dense_arrs = [mx.nd.ones(shape), np.ones(shape), np.ones(shape).tolist()]
+    src_dtype = np.float64
+    ctx = mx.cpu(1)
+    dense_arrs = [mx.nd.ones(shape, dtype=src_dtype), np.ones(shape, dtype=src_dtype), \
+                  np.ones(shape, dtype=src_dtype).tolist()]
     for f in [mx.nd.sparse.csr_matrix, mx.nd.sparse.row_sparse_array]:
         for dense_arr in dense_arrs:
-            check_create_from_dns(shape, f, dense_arr, dtype)
+            default_dtype = dense_arr.dtype if isinstance(dense_arr, (NDArray, np.ndarray)) \
+                            else np.float32
+            check_create_from_dns(shape, f, dense_arr, dtype, default_dtype, ctx)
+
+def test_create_sparse_nd_from_sparse():
+    def check_create_from_sp(shape, f, sp_arr, dtype, src_dtype, ctx):
+        arr = f(sp_arr, dtype=dtype, ctx=ctx)
+        assert(same(arr.asnumpy(), np.ones(shape)))
+        assert(arr.dtype == dtype)
+        assert(arr.context == ctx)
+        # verify the default dtype inferred from sparse arr
+        arr2 = f(sp_arr)
+        assert(arr2.dtype == src_dtype)
+        assert(arr2.context == Context.default_ctx)
+
+    shape = rand_shape_2d()
+    src_dtype = np.float64
+    dtype = np.int32
+    ctx = mx.cpu(1)
+    ones = mx.nd.ones(shape, dtype=src_dtype)
+    csr_arrs = [ones.tostype('csr')]
+    rsp_arrs = [ones.tostype('row_sparse')]
+    try:
+        import scipy.sparse as spsp
+        csr_sp = spsp.csr_matrix(np.ones(shape, dtype=src_dtype))
+        csr_arrs.append(csr_sp)
+    except ImportError:
+        print("Could not import scipy.sparse. Skipping unit tests for scipy csr creation")
+    f_csr = mx.nd.sparse.csr_matrix
+    f_rsp = mx.nd.sparse.row_sparse_array
+    for sp_arr in csr_arrs:
+        check_create_from_sp(shape, f_csr, sp_arr, dtype, src_dtype, ctx)
+    for sp_arr in rsp_arrs:
+        check_create_from_sp(shape, f_rsp, sp_arr, dtype, src_dtype, ctx)
 
 def test_create_sparse_nd_empty():
     def check_empty(shape, stype):
@@ -586,24 +635,38 @@ def test_create_sparse_nd_empty():
         assert(arr.stype == stype)
         assert same(arr.asnumpy(), np.zeros(shape))
 
-    def check_csr_empty(shape):
-        arr = mx.nd.sparse.csr_matrix(shape)
+    def check_csr_empty(shape, dtype, ctx):
+        arr = mx.nd.sparse.csr_matrix(shape, dtype=dtype, ctx=ctx)
         assert(arr.stype == 'csr')
+        assert(arr.dtype == dtype)
+        assert(arr.context == ctx)
         assert same(arr.asnumpy(), np.zeros(shape))
+        # check the default value for dtype and ctx
+        arr = mx.nd.sparse.csr_matrix(shape)
+        assert(arr.dtype == np.float32)
+        assert(arr.context == Context.default_ctx)
 
-    def check_rsp_empty(shape):
-        arr = mx.nd.sparse.row_sparse_array(shape)
+    def check_rsp_empty(shape, dtype, ctx):
+        arr = mx.nd.sparse.row_sparse_array(shape, dtype=dtype, ctx=ctx)
         assert(arr.stype == 'row_sparse')
+        assert(arr.dtype == dtype)
+        assert(arr.context == ctx)
         assert same(arr.asnumpy(), np.zeros(shape))
+        # check the default value for dtype and ctx
+        arr = mx.nd.sparse.row_sparse_array(shape)
+        assert(arr.dtype == np.float32)
+        assert(arr.context == Context.default_ctx)
 
     stypes = ['csr', 'row_sparse']
     shape = rand_shape_2d()
     shape_3d = rand_shape_3d()
+    dtype = np.int32
+    ctx = mx.cpu(1)
     for stype in stypes:
         check_empty(shape, stype)
-    check_csr_empty(shape)
-    check_rsp_empty(shape)
-    check_rsp_empty(shape_3d)
+    check_csr_empty(shape, dtype, ctx)
+    check_rsp_empty(shape, dtype, ctx)
+    check_rsp_empty(shape_3d, dtype, ctx)
 
 def test_synthetic_dataset_generator():
     def test_powerlaw_generator(csr_arr, final_row=1):

-- 
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.