Posted to commits@spark.apache.org by jk...@apache.org on 2015/05/15 19:43:21 UTC

spark git commit: [SPARK-7651] [MLLIB] [PYSPARK] GMM predict, predictSoft should raise error on bad input

Repository: spark
Updated Branches:
  refs/heads/master f96b85ab4 -> 8f4aaba0e


[SPARK-7651] [MLLIB] [PYSPARK] GMM predict, predictSoft should raise error on bad input

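In the Python API for the Gaussian mixture model, the predict() and predictSoft() methods should raise an error when the input argument is not an RDD. Before this change, both methods fell through the isinstance check and returned None for such input.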
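
The effect of the change can be sketched without a Spark installation. The example below is an illustration only, not the PySpark code: FakeRDD and SketchModel are hypothetical stand-ins for pyspark.RDD and GaussianMixtureModel, and the membership weights are made-up inputs rather than the output of a fitted model. Only the isinstance check and the TypeError message mirror the patch.

```python
class FakeRDD(object):
    """Hypothetical stand-in for pyspark.RDD; wraps a plain Python list."""

    def __init__(self, rows):
        self.rows = list(rows)

    def map(self, f):
        # Apply f to every row, as RDD.map would (eagerly, for simplicity).
        return FakeRDD(f(r) for r in self.rows)

    def collect(self):
        return list(self.rows)


class SketchModel(object):
    """Hypothetical stand-in for GaussianMixtureModel (no real GMM math)."""

    def predictSoft(self, x):
        if isinstance(x, FakeRDD):
            # Assumption: each row already holds per-cluster membership weights.
            return x.map(lambda row: list(row))
        else:
            raise TypeError("x should be represented by an RDD, "
                            "but got %s." % type(x))

    def predict(self, x):
        if isinstance(x, FakeRDD):
            # Pick the index of the largest membership weight for each row.
            return self.predictSoft(x).map(lambda z: z.index(max(z)))
        else:
            raise TypeError("x should be represented by an RDD, "
                            "but got %s." % type(x))


model = SketchModel()
labels = model.predict(FakeRDD([[0.1, 0.9], [0.7, 0.3]])).collect()

try:
    model.predict([[0.1, 0.9]])  # a plain list, not an RDD
    error_message = None
except TypeError as exc:
    error_message = str(exc)
```

With the else branch in place, passing a plain list fails immediately with a descriptive message instead of handing None back to the caller.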

Author: FlytxtRnD <me...@flytxt.com>

Closes #6180 from FlytxtRnD/GmmPredictException and squashes the following commits:

4b6aa11 [FlytxtRnD] Raise error if the input to predict()/predictSoft() is not an RDD


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8f4aaba0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8f4aaba0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8f4aaba0

Branch: refs/heads/master
Commit: 8f4aaba0e4e3350ab152a476d08ff60e9495c6d2
Parents: f96b85a
Author: FlytxtRnD <me...@flytxt.com>
Authored: Fri May 15 10:43:18 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Fri May 15 10:43:18 2015 -0700

----------------------------------------------------------------------
 python/pyspark/mllib/clustering.py | 6 ++++++
 1 file changed, 6 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8f4aaba0/python/pyspark/mllib/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index a53333d..b55583f 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -212,6 +212,9 @@ class GaussianMixtureModel(object):
         if isinstance(x, RDD):
             cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z)))
             return cluster_labels
+        else:
+            raise TypeError("x should be represented by an RDD, "
+                            "but got %s." % type(x))
 
     def predictSoft(self, x):
         """
@@ -225,6 +228,9 @@ class GaussianMixtureModel(object):
             membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
                                               _convert_to_vector(self._weights), means, sigmas)
             return membership_matrix.map(lambda x: pyarray.array('d', x))
+        else:
+            raise TypeError("x should be represented by an RDD, "
+                            "but got %s." % type(x))
 
 
 class GaussianMixture(object):

