You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2014/04/10 20:17:48 UTC

git commit: SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining

Repository: spark
Updated Branches:
  refs/heads/master 79820fe82 -> 3bd312940


SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining

Author: Sandeep <sa...@techaddict.me>

Closes #356 from techaddict/1428 and squashes the following commits:

3bdf5f6 [Sandeep] SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3bd31294
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3bd31294
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3bd31294

Branch: refs/heads/master
Commit: 3bd312940e2f5250edaf3e88d6c23de25bb1d0a9
Parents: 79820fe
Author: Sandeep <sa...@techaddict.me>
Authored: Thu Apr 10 11:17:41 2014 -0700
Committer: Matei Zaharia <ma...@databricks.com>
Committed: Thu Apr 10 11:17:41 2014 -0700

----------------------------------------------------------------------
 python/pyspark/mllib/_common.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3bd31294/python/pyspark/mllib/_common.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index 20a0e30..7ef251d 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -15,8 +15,9 @@
 # limitations under the License.
 #
 
-from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
+from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape, complex, issubdtype
 from pyspark import SparkContext, RDD
+import numpy as np
 
 from pyspark.serializers import Serializer
 import struct
@@ -47,13 +48,22 @@ def _deserialize_byte_array(shape, ba, offset):
     return ar.copy()
 
 def _serialize_double_vector(v):
-    """Serialize a double vector into a mutually understood format."""
+    """Serialize a double vector into a mutually understood format.
+
+    >>> x = array([1,2,3])
+    >>> y = _deserialize_double_vector(_serialize_double_vector(x))
+    >>> array_equal(y, array([1.0, 2.0, 3.0]))
+    True
+    """
     if type(v) != ndarray:
         raise TypeError("_serialize_double_vector called on a %s; "
                 "wanted ndarray" % type(v))
+    """complex is only datatype that can't be converted to float64"""
+    if issubdtype(v.dtype, complex):
+        raise TypeError("_serialize_double_vector called on a %s; "
+                "wanted ndarray" % type(v))
     if v.dtype != float64:
-        raise TypeError("_serialize_double_vector called on an ndarray of %s; "
-                "wanted ndarray of float64" % v.dtype)
+        v = v.astype(float64)
     if v.ndim != 1:
         raise TypeError("_serialize_double_vector called on a %ddarray; "
                 "wanted a 1darray" % v.ndim)