You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2019/02/22 04:21:57 UTC
[spark] branch master updated: [SPARK-25097][ML] Support prediction
on single instance in KMeans/BiKMeans/GMM
This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 89d42dc [SPARK-25097][ML] Support prediction on single instance in KMeans/BiKMeans/GMM
89d42dc is described below
commit 89d42dc6d38c9508b7009652323d6b343742c5b8
Author: zhengruifeng <ru...@foxmail.com>
AuthorDate: Thu Feb 21 22:21:28 2019 -0600
[SPARK-25097][ML] Support prediction on single instance in KMeans/BiKMeans/GMM
## What changes were proposed in this pull request?
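Make the single-instance `predict` method public (`@Since("3.0.0")`) on KMeansModel, BisectingKMeansModel and GaussianMixtureModel. Also make `predictProbability` public on GaussianMixtureModel.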
## How was this patch tested?
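Added test suites. These check that calling `predict` (and, for GMM, `predictProbability`) on each row's features gives the same result as the corresponding output column of `transform`.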
Closes #22087 from zhengruifeng/clu_pre_instance.
Authored-by: zhengruifeng <ru...@foxmail.com>
Signed-off-by: Sean Owen <se...@databricks.com>
---
.../spark/ml/clustering/BisectingKMeans.scala | 6 ++---
.../spark/ml/clustering/GaussianMixture.scala | 6 +++--
.../org/apache/spark/ml/clustering/KMeans.scala | 7 +++---
.../spark/ml/clustering/BisectingKMeansSuite.scala | 7 ++++++
.../spark/ml/clustering/GaussianMixtureSuite.scala | 10 ++++++++
.../apache/spark/ml/clustering/KMeansSuite.scala | 7 ++++++
.../scala/org/apache/spark/ml/util/MLTest.scala | 28 +++++++++++++++++++++-
7 files changed, 61 insertions(+), 10 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index d846f17..03afdbe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.clustering
import org.apache.hadoop.fs.Path
-import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.Vector
@@ -30,7 +29,7 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans,
BisectingKMeansModel => MLlibBisectingKMeansModel}
import org.apache.spark.mllib.linalg.VectorImplicits._
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -118,7 +117,8 @@ class BisectingKMeansModel private[ml] (
validateAndTransformSchema(schema)
}
- private[clustering] def predict(features: Vector): Int = parentModel.predict(features)
+ @Since("3.0.0")
+ def predict(features: Vector): Int = parentModel.predict(features)
@Since("2.0.0")
def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index c27ba55..3d6d1e3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -121,12 +121,14 @@ class GaussianMixtureModel private[ml] (
validateAndTransformSchema(schema)
}
- private[clustering] def predict(features: Vector): Int = {
+ @Since("3.0.0")
+ def predict(features: Vector): Int = {
val r = predictProbability(features)
r.argmax
}
- private[clustering] def predictProbability(features: Vector): Vector = {
+ @Since("3.0.0")
+ def predictProbability(features: Vector): Vector = {
val probs: Array[Double] =
GaussianMixtureModel.computeProbabilities(features.asBreeze.toDenseVector, gaussians, weights)
Vectors.dense(probs)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 319747d..b48a966 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -21,7 +21,6 @@ import scala.collection.mutable
import org.apache.hadoop.fs.Path
-import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model, PipelineStage}
import org.apache.spark.ml.linalg.Vector
@@ -32,8 +31,7 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.storage.StorageLevel
@@ -139,7 +137,8 @@ class KMeansModel private[ml] (
validateAndTransformSchema(schema)
}
- private[clustering] def predict(features: Vector): Int = parentModel.predict(features)
+ @Since("3.0.0")
+ def predict(features: Vector): Int = parentModel.predict(features)
@Since("2.0.0")
def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index 461f8b8..5708097 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -205,6 +205,13 @@ class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest {
assert(trueCost ~== doubleArrayCost absTol 1e-6)
assert(trueCost ~== floatArrayCost absTol 1e-6)
}
+
+ test("prediction on single instance") {
+ val bikm = new BisectingKMeans().setSeed(123L)
+ val model = bikm.fit(dataset)
+ testClusteringModelSinglePrediction(model, model.predict, dataset,
+ model.getFeaturesCol, model.getPredictionCol)
+ }
}
object BisectingKMeansSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index 13bed9d..11fdd3a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -268,6 +268,16 @@ class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest {
assert(trueLikelihood ~== doubleLikelihood absTol 1e-6)
assert(trueLikelihood ~== floatLikelihood absTol 1e-6)
}
+
+ test("prediction on single instance") {
+ val gmm = new GaussianMixture().setSeed(123L)
+ val model = gmm.fit(dataset)
+ testClusteringModelSinglePrediction(model, model.predict, dataset,
+ model.getFeaturesCol, model.getPredictionCol)
+
+ testClusteringModelSingleProbabilisticPrediction(model, model.predictProbability, dataset,
+ model.getFeaturesCol, model.getProbabilityCol)
+ }
}
object GaussianMixtureSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 4f47d91..b377582 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -244,6 +244,13 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes
}
testPMMLWrite(sc, kmeansModel, checkModel)
}
+
+ test("prediction on single instance") {
+ val kmeans = new KMeans().setSeed(123L)
+ val model = kmeans.fit(dataset)
+ testClusteringModelSinglePrediction(model, model.predict, dataset,
+ model.getFeaturesCol, model.getPredictionCol)
+ }
}
object KMeansSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
index 514fa7f..c23b6d8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
@@ -23,7 +23,7 @@ import org.scalatest.Suite
import org.apache.spark.{DebugFilesystem, SparkConf, SparkContext}
import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK
-import org.apache.spark.ml.{PredictionModel, Transformer}
+import org.apache.spark.ml.{Model, PredictionModel, Transformer}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row}
import org.apache.spark.sql.execution.streaming.MemoryStream
@@ -156,4 +156,30 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
assert(prediction === model.predict(features))
}
}
+
+ def testClusteringModelSinglePrediction(
+ model: Model[_],
+ transform: Vector => Int,
+ dataset: Dataset[_],
+ input: String,
+ output: String): Unit = {
+ model.transform(dataset).select(input, output)
+ .collect().foreach {
+ case Row(features: Vector, prediction: Int) =>
+ assert(prediction === transform(features))
+ }
+ }
+
+ def testClusteringModelSingleProbabilisticPrediction(
+ model: Model[_],
+ transform: Vector => Vector,
+ dataset: Dataset[_],
+ input: String,
+ output: String): Unit = {
+ model.transform(dataset).select(input, output)
+ .collect().foreach {
+ case Row(features: Vector, prediction: Vector) =>
+ assert(prediction === transform(features))
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org