You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/01/11 23:43:30 UTC

spark git commit: [SPARK-12603][MLLIB] PySpark MLlib GaussianMixtureModel should support single instance predict/predictSoft

Repository: spark
Updated Branches:
  refs/heads/master a767ee8a0 -> ee4ee02b8


[SPARK-12603][MLLIB] PySpark MLlib GaussianMixtureModel should support single instance predict/predictSoft

PySpark MLlib `GaussianMixtureModel` should support single-instance `predict`/`predictSoft`, just as the Scala API does.

Author: Yanbo Liang <yb...@gmail.com>

Closes #10552 from yanboliang/spark-12603.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ee4ee02b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ee4ee02b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ee4ee02b

Branch: refs/heads/master
Commit: ee4ee02b86be8756a6d895a2e23e80862134a6d3
Parents: a767ee8
Author: Yanbo Liang <yb...@gmail.com>
Authored: Mon Jan 11 14:43:25 2016 -0800
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon Jan 11 14:43:25 2016 -0800

----------------------------------------------------------------------
 .../main/python/mllib/gaussian_mixture_model.py |  4 +++
 .../examples/mllib/DenseGaussianMixture.scala   |  6 ++++
 .../python/GaussianMixtureModelWrapper.scala    |  4 +++
 .../mllib/clustering/GaussianMixtureModel.scala |  2 +-
 python/pyspark/mllib/clustering.py              | 35 ++++++++++++--------
 5 files changed, 37 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ee4ee02b/examples/src/main/python/mllib/gaussian_mixture_model.py
----------------------------------------------------------------------
diff --git a/examples/src/main/python/mllib/gaussian_mixture_model.py b/examples/src/main/python/mllib/gaussian_mixture_model.py
index 2cb8010..69e836f 100644
--- a/examples/src/main/python/mllib/gaussian_mixture_model.py
+++ b/examples/src/main/python/mllib/gaussian_mixture_model.py
@@ -62,5 +62,9 @@ if __name__ == "__main__":
     for i in range(args.k):
         print(("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu,
                "sigma = ", model.gaussians[i].sigma.toArray()))
+    print("\n")
+    print(("The membership value of each vector to all mixture components (first 100): ",
+           model.predictSoft(data).take(100)))
+    print("\n")
     print(("Cluster labels (first 100): ", model.predict(data).take(100)))
     sc.stop()

http://git-wip-us.apache.org/repos/asf/spark/blob/ee4ee02b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala
index 1fce4ba..90b817b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala
@@ -58,6 +58,12 @@ object DenseGaussianMixture {
         (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma))
     }
 
+    println("The membership value of each vector to all mixture components (first <= 100):")
+    val membership = clusters.predictSoft(data)
+    membership.take(100).foreach { x =>
+      print(" " + x.mkString(","))
+    }
+    println()
     println("Cluster labels (first <= 100):")
     val clusterLabels = clusters.predict(data)
     clusterLabels.take(100).foreach { x =>

http://git-wip-us.apache.org/repos/asf/spark/blob/ee4ee02b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
index 6a3b20c..a689b09 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
@@ -40,5 +40,9 @@ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) {
     SerDe.dumps(JavaConverters.seqAsJavaListConverter(modelGaussians).asJava)
   }
 
+  def predictSoft(point: Vector): Vector = {
+    Vectors.dense(model.predictSoft(point))
+  }
+
   def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ee4ee02b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index 16bc45b..42fe270 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -75,7 +75,7 @@ class GaussianMixtureModel @Since("1.3.0") (
    */
   @Since("1.5.0")
   def predict(point: Vector): Int = {
-    val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
+    val r = predictSoft(point)
     r.indexOf(r.max)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ee4ee02b/python/pyspark/mllib/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index d22a7f4..580cb51 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -202,16 +202,25 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
 
     >>> clusterdata_1 =  sc.parallelize(array([-0.1,-0.05,-0.01,-0.1,
     ...                                         0.9,0.8,0.75,0.935,
-    ...                                        -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2))
+    ...                                        -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2), 2)
     >>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001,
     ...                                 maxIterations=50, seed=10)
     >>> labels = model.predict(clusterdata_1).collect()
     >>> labels[0]==labels[1]
     False
     >>> labels[1]==labels[2]
-    True
+    False
     >>> labels[4]==labels[5]
     True
+    >>> model.predict([-0.1,-0.05])
+    0
+    >>> softPredicted = model.predictSoft([-0.1,-0.05])
+    >>> abs(softPredicted[0] - 1.0) < 0.001
+    True
+    >>> abs(softPredicted[1] - 0.0) < 0.001
+    True
+    >>> abs(softPredicted[2] - 0.0) < 0.001
+    True
 
     >>> path = tempfile.mkdtemp()
     >>> model.save(sc, path)
@@ -277,26 +286,27 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
     @since('1.3.0')
     def predict(self, x):
         """
-        Find the cluster to which the points in 'x' has maximum membership
-        in this model.
+        Find the cluster to which the point 'x' or each point in RDD 'x'
+        has maximum membership in this model.
 
-        :param x:    RDD of data points.
-        :return:     cluster_labels. RDD of cluster labels.
+        :param x:    vector or RDD of vector represents data points.
+        :return:     cluster label or RDD of cluster labels.
         """
         if isinstance(x, RDD):
             cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z)))
             return cluster_labels
         else:
-            raise TypeError("x should be represented by an RDD, "
-                            "but got %s." % type(x))
+            z = self.predictSoft(x)
+            return z.argmax()
 
     @since('1.3.0')
     def predictSoft(self, x):
         """
-        Find the membership of each point in 'x' to all mixture components.
+        Find the membership of point 'x' or each point in RDD 'x' to all mixture components.
 
-        :param x:    RDD of data points.
-        :return:     membership_matrix. RDD of array of double values.
+        :param x:    vector or RDD of vector represents data points.
+        :return:     the membership value to all mixture components for vector 'x'
+                     or each vector in RDD 'x'.
         """
         if isinstance(x, RDD):
             means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians])
@@ -304,8 +314,7 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
                                               _convert_to_vector(self.weights), means, sigmas)
             return membership_matrix.map(lambda x: pyarray.array('d', x))
         else:
-            raise TypeError("x should be represented by an RDD, "
-                            "but got %s." % type(x))
+            return self.call("predictSoft", _convert_to_vector(x)).toArray()
 
     @classmethod
     @since('1.5.0')


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org