You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2014/04/10 20:17:48 UTC
git commit: SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining
Repository: spark
Updated Branches:
refs/heads/master 79820fe82 -> 3bd312940
SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining
Author: Sandeep <sa...@techaddict.me>
Closes #356 from techaddict/1428 and squashes the following commits:
3bdf5f6 [Sandeep] SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3bd31294
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3bd31294
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3bd31294
Branch: refs/heads/master
Commit: 3bd312940e2f5250edaf3e88d6c23de25bb1d0a9
Parents: 79820fe
Author: Sandeep <sa...@techaddict.me>
Authored: Thu Apr 10 11:17:41 2014 -0700
Committer: Matei Zaharia <ma...@databricks.com>
Committed: Thu Apr 10 11:17:41 2014 -0700
----------------------------------------------------------------------
python/pyspark/mllib/_common.py | 18 ++++++++++++++----
1 file changed, 14 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/3bd31294/python/pyspark/mllib/_common.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index 20a0e30..7ef251d 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -15,8 +15,9 @@
# limitations under the License.
#
-from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
+from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape, complex, issubdtype
from pyspark import SparkContext, RDD
+import numpy as np
from pyspark.serializers import Serializer
import struct
@@ -47,13 +48,22 @@ def _deserialize_byte_array(shape, ba, offset):
return ar.copy()
def _serialize_double_vector(v):
- """Serialize a double vector into a mutually understood format."""
+ """Serialize a double vector into a mutually understood format.
+
+ >>> x = array([1,2,3])
+ >>> y = _deserialize_double_vector(_serialize_double_vector(x))
+ >>> array_equal(y, array([1.0, 2.0, 3.0]))
+ True
+ """
if type(v) != ndarray:
raise TypeError("_serialize_double_vector called on a %s; "
"wanted ndarray" % type(v))
+ """complex is only datatype that can't be converted to float64"""
+ if issubdtype(v.dtype, complex):
+ raise TypeError("_serialize_double_vector called on a %s; "
+ "wanted ndarray" % type(v))
if v.dtype != float64:
- raise TypeError("_serialize_double_vector called on an ndarray of %s; "
- "wanted ndarray of float64" % v.dtype)
+ v = v.astype(float64)
if v.ndim != 1:
raise TypeError("_serialize_double_vector called on a %ddarray; "
"wanted a 1darray" % v.ndim)