Posted to commits@spark.apache.org by sr...@apache.org on 2019/02/22 04:21:57 UTC

[spark] branch master updated: [SPARK-25097][ML] Support prediction on single instance in KMeans/BiKMeans/GMM

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 89d42dc  [SPARK-25097][ML] Support prediction on single instance in KMeans/BiKMeans/GMM
89d42dc is described below

commit 89d42dc6d38c9508b7009652323d6b343742c5b8
Author: zhengruifeng <ru...@foxmail.com>
AuthorDate: Thu Feb 21 22:21:28 2019 -0600

    [SPARK-25097][ML] Support prediction on single instance in KMeans/BiKMeans/GMM
    
    ## What changes were proposed in this pull request?
    Expose the method `predict` in KMeans/BiKMeans/GMM (and `predictProbability` in GMM) as public APIs, so a single instance can be scored without building a one-row DataFrame and calling `transform`.
    
    ## How was this patch tested?
    Added test suites covering prediction on a single instance for each of the three models.
    
    Closes #22087 from zhengruifeng/clu_pre_instance.
    
    Authored-by: zhengruifeng <ru...@foxmail.com>
    Signed-off-by: Sean Owen <se...@databricks.com>
---
 .../spark/ml/clustering/BisectingKMeans.scala      |  6 ++---
 .../spark/ml/clustering/GaussianMixture.scala      |  6 +++--
 .../org/apache/spark/ml/clustering/KMeans.scala    |  7 +++---
 .../spark/ml/clustering/BisectingKMeansSuite.scala |  7 ++++++
 .../spark/ml/clustering/GaussianMixtureSuite.scala | 10 ++++++++
 .../apache/spark/ml/clustering/KMeansSuite.scala   |  7 ++++++
 .../scala/org/apache/spark/ml/util/MLTest.scala    | 28 +++++++++++++++++++++-
 7 files changed, 61 insertions(+), 10 deletions(-)
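
In effect, this patch promotes the single-instance `predict` methods (and `predictProbability` on GaussianMixtureModel) from `private[clustering]` to public `@Since("3.0.0")` APIs. A minimal sketch of what that enables with KMeans; the local SparkSession, toy data, k and seed below are illustrative assumptions only, not part of this commit:

  import org.apache.spark.ml.clustering.KMeans
  import org.apache.spark.ml.linalg.Vectors
  import org.apache.spark.sql.SparkSession

  // Illustrative setup only: toy data, k and seed are assumptions, not part of this commit.
  val spark = SparkSession.builder().master("local[*]").appName("single-instance-predict").getOrCreate()
  import spark.implicits._

  // Two well-separated toy clusters in a "features" column (the default featuresCol).
  val dataset = Seq(
    Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
    Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1)
  ).map(Tuple1.apply).toDF("features")

  val model = new KMeans().setK(2).setSeed(1L).fit(dataset)

  // New with this patch: score one Vector directly, without transform() on a DataFrame.
  val cluster: Int = model.predict(Vectors.dense(0.05, 0.05))

Before this patch `predict` was `private[clustering]`, so the only supported route was a one-row DataFrame through `transform()`.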

diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index d846f17..03afdbe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.clustering
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.linalg.Vector
@@ -30,7 +29,7 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans,
   BisectingKMeansModel => MLlibBisectingKMeansModel}
 import org.apache.spark.mllib.linalg.VectorImplicits._
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.udf
 import org.apache.spark.sql.types.{IntegerType, StructType}
 
@@ -118,7 +117,8 @@ class BisectingKMeansModel private[ml] (
     validateAndTransformSchema(schema)
   }
 
-  private[clustering] def predict(features: Vector): Int = parentModel.predict(features)
+  @Since("3.0.0")
+  def predict(features: Vector): Int = parentModel.predict(features)
 
   @Since("2.0.0")
   def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index c27ba55..3d6d1e3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -121,12 +121,14 @@ class GaussianMixtureModel private[ml] (
     validateAndTransformSchema(schema)
   }
 
-  private[clustering] def predict(features: Vector): Int = {
+  @Since("3.0.0")
+  def predict(features: Vector): Int = {
     val r = predictProbability(features)
     r.argmax
   }
 
-  private[clustering] def predictProbability(features: Vector): Vector = {
+  @Since("3.0.0")
+  def predictProbability(features: Vector): Vector = {
     val probs: Array[Double] =
       GaussianMixtureModel.computeProbabilities(features.asBreeze.toDenseVector, gaussians, weights)
     Vectors.dense(probs)
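
For GaussianMixtureModel both methods become public, and as the hunk above shows, `predict` is simply the argmax of `predictProbability`. A minimal sketch of the two calls, assuming `trainingData` is any DataFrame with a "features" column of ml Vectors and enough rows per component to fit a non-degenerate model (an assumption for illustration, not part of this commit):

  import org.apache.spark.ml.clustering.{GaussianMixture, GaussianMixtureModel}
  import org.apache.spark.ml.linalg.{Vector, Vectors}

  // `trainingData` is a stand-in for a real training DataFrame (assumption, see note above).
  val gmModel: GaussianMixtureModel = new GaussianMixture().setK(2).setSeed(1L).fit(trainingData)

  // Posterior membership probabilities for a single point: a Vector of length k.
  val probs: Vector = gmModel.predictProbability(Vectors.dense(0.05, 0.05))

  // Cluster assignment for the same point; by construction it is probs.argmax.
  val cluster: Int = gmModel.predict(Vectors.dense(0.05, 0.05))
  assert(cluster == probs.argmax)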
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 319747d..b48a966 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -21,7 +21,6 @@ import scala.collection.mutable
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model, PipelineStage}
 import org.apache.spark.ml.linalg.Vector
@@ -32,8 +31,7 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.VectorImplicits._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.functions.udf
 import org.apache.spark.sql.types.{IntegerType, StructType}
 import org.apache.spark.storage.StorageLevel
@@ -139,7 +137,8 @@ class KMeansModel private[ml] (
     validateAndTransformSchema(schema)
   }
 
-  private[clustering] def predict(features: Vector): Int = parentModel.predict(features)
+  @Since("3.0.0")
+  def predict(features: Vector): Int = parentModel.predict(features)
 
   @Since("2.0.0")
   def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index 461f8b8..5708097 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -205,6 +205,13 @@ class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest {
     assert(trueCost ~== doubleArrayCost absTol 1e-6)
     assert(trueCost ~== floatArrayCost absTol 1e-6)
   }
+
+  test("prediction on single instance") {
+    val bikm = new BisectingKMeans().setSeed(123L)
+    val model = bikm.fit(dataset)
+    testClusteringModelSinglePrediction(model, model.predict, dataset,
+      model.getFeaturesCol, model.getPredictionCol)
+  }
 }
 
 object BisectingKMeansSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index 13bed9d..11fdd3a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -268,6 +268,16 @@ class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest {
     assert(trueLikelihood ~== doubleLikelihood absTol 1e-6)
     assert(trueLikelihood ~== floatLikelihood absTol 1e-6)
   }
+
+  test("prediction on single instance") {
+    val gmm = new GaussianMixture().setSeed(123L)
+    val model = gmm.fit(dataset)
+    testClusteringModelSinglePrediction(model, model.predict, dataset,
+      model.getFeaturesCol, model.getPredictionCol)
+
+    testClusteringModelSingleProbabilisticPrediction(model, model.predictProbability, dataset,
+      model.getFeaturesCol, model.getProbabilityCol)
+  }
 }
 
 object GaussianMixtureSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 4f47d91..b377582 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -244,6 +244,13 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes
     }
     testPMMLWrite(sc, kmeansModel, checkModel)
   }
+
+  test("prediction on single instance") {
+    val kmeans = new KMeans().setSeed(123L)
+    val model = kmeans.fit(dataset)
+    testClusteringModelSinglePrediction(model, model.predict, dataset,
+      model.getFeaturesCol, model.getPredictionCol)
+  }
 }
 
 object KMeansSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
index 514fa7f..c23b6d8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
@@ -23,7 +23,7 @@ import org.scalatest.Suite
 
 import org.apache.spark.{DebugFilesystem, SparkConf, SparkContext}
 import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK
-import org.apache.spark.ml.{PredictionModel, Transformer}
+import org.apache.spark.ml.{Model, PredictionModel, Transformer}
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row}
 import org.apache.spark.sql.execution.streaming.MemoryStream
@@ -156,4 +156,30 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
         assert(prediction === model.predict(features))
     }
   }
+
+  def testClusteringModelSinglePrediction(
+    model: Model[_],
+    transform: Vector => Int,
+    dataset: Dataset[_],
+    input: String,
+    output: String): Unit = {
+    model.transform(dataset).select(input, output)
+      .collect().foreach {
+      case Row(features: Vector, prediction: Int) =>
+        assert(prediction === transform(features))
+    }
+  }
+
+  def testClusteringModelSingleProbabilisticPrediction(
+    model: Model[_],
+    transform: Vector => Vector,
+    dataset: Dataset[_],
+    input: String,
+    output: String): Unit = {
+    model.transform(dataset).select(input, output)
+      .collect().foreach {
+      case Row(features: Vector, prediction: Vector) =>
+        assert(prediction === transform(features))
+    }
+  }
 }

