You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2018/03/21 16:39:17 UTC
spark git commit: [SPARK-10884][ML] Support prediction on single
instance for regression and classification related models
Repository: spark
Updated Branches:
refs/heads/master 500b21c3d -> bf09f2f71
[SPARK-10884][ML] Support prediction on single instance for regression and classification related models
## What changes were proposed in this pull request?
Support prediction on single instance for regression and classification related models (i.e., PredictionModel, ClassificationModel and their subclasses).
Add corresponding test cases.
## How was this patch tested?
Test cases added.
Author: WeichenXu <we...@databricks.com>
Closes #19381 from WeichenXu123/single_prediction.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bf09f2f7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bf09f2f7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bf09f2f7
Branch: refs/heads/master
Commit: bf09f2f71276d3b3a84a8f89109bd785a066c3e6
Parents: 500b21c
Author: WeichenXu <we...@databricks.com>
Authored: Wed Mar 21 09:39:14 2018 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Wed Mar 21 09:39:14 2018 -0700
----------------------------------------------------------------------
.../scala/org/apache/spark/ml/Predictor.scala | 5 ++--
.../spark/ml/classification/Classifier.scala | 6 ++---
.../classification/DecisionTreeClassifier.scala | 2 +-
.../spark/ml/classification/GBTClassifier.scala | 2 +-
.../spark/ml/classification/LinearSVC.scala | 2 +-
.../ml/classification/LogisticRegression.scala | 2 +-
.../MultilayerPerceptronClassifier.scala | 2 +-
.../ml/regression/DecisionTreeRegressor.scala | 2 +-
.../spark/ml/regression/GBTRegressor.scala | 2 +-
.../GeneralizedLinearRegression.scala | 2 +-
.../spark/ml/regression/LinearRegression.scala | 2 +-
.../ml/regression/RandomForestRegressor.scala | 2 +-
.../DecisionTreeClassifierSuite.scala | 17 +++++++++++++-
.../ml/classification/GBTClassifierSuite.scala | 9 ++++++++
.../ml/classification/LinearSVCSuite.scala | 6 +++++
.../LogisticRegressionSuite.scala | 9 ++++++++
.../MultilayerPerceptronClassifierSuite.scala | 12 ++++++++++
.../ml/classification/NaiveBayesSuite.scala | 22 ++++++++++++++++++
.../RandomForestClassifierSuite.scala | 16 +++++++++++++
.../regression/DecisionTreeRegressorSuite.scala | 15 ++++++++++++
.../spark/ml/regression/GBTRegressorSuite.scala | 8 +++++++
.../GeneralizedLinearRegressionSuite.scala | 8 +++++++
.../ml/regression/LinearRegressionSuite.scala | 7 ++++++
.../regression/RandomForestRegressorSuite.scala | 24 ++++++++++++++++----
.../scala/org/apache/spark/ml/util/MLTest.scala | 15 ++++++++++--
25 files changed, 176 insertions(+), 23 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 08b0cb9..d8f3dfa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -219,7 +219,8 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
/**
* Predict label for the given features.
- * This internal method is used to implement `transform()` and output [[predictionCol]].
+ * This method is used to implement `transform()` and output [[predictionCol]].
*/
- protected def predict(features: FeaturesType): Double
+ @Since("2.4.0")
+ def predict(features: FeaturesType): Double
}
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 9d1d5aa..7e5790a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.classification
import org.apache.spark.SparkException
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
@@ -192,12 +192,12 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
/**
* Predict label for the given features.
- * This internal method is used to implement `transform()` and output [[predictionCol]].
+ * This method is used to implement `transform()` and output [[predictionCol]].
*
* This default implementation for classification predicts the index of the maximum value
* from `predictRaw()`.
*/
- override protected def predict(features: FeaturesType): Double = {
+ override def predict(features: FeaturesType): Double = {
raw2prediction(predictRaw(features))
}
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 9f60f08..65cce69 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -181,7 +181,7 @@ class DecisionTreeClassificationModel private[ml] (
private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)
- override protected def predict(features: Vector): Double = {
+ override def predict(features: Vector): Double = {
rootNode.predictImpl(features).prediction
}
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index f11bc1d..cd44489 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -267,7 +267,7 @@ class GBTClassificationModel private[ml](
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
- override protected def predict(features: Vector): Double = {
+ override def predict(features: Vector): Double = {
// If thresholds defined, use predictRaw to get probabilities, otherwise use optimization
if (isDefined(thresholds)) {
super.predict(features)
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index ce400f4..8f950cd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -316,7 +316,7 @@ class LinearSVCModel private[classification] (
BLAS.dot(features, coefficients) + intercept
}
- override protected def predict(features: Vector): Double = {
+ override def predict(features: Vector): Double = {
if (margin(features) > $(threshold)) 1.0 else 0.0
}
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index fa19160..3ae4db3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -1090,7 +1090,7 @@ class LogisticRegressionModel private[spark] (
* Predict label for the given feature vector.
* The behavior of this can be adjusted using `thresholds`.
*/
- override protected def predict(features: Vector): Double = if (isMultinomial) {
+ override def predict(features: Vector): Double = if (isMultinomial) {
super.predict(features)
} else {
// Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index fd4c98f..af2e469 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -322,7 +322,7 @@ class MultilayerPerceptronClassificationModel private[ml] (
* Predict label for the given features.
* This internal method is used to implement `transform()` and output [[predictionCol]].
*/
- override protected def predict(features: Vector): Double = {
+ override def predict(features: Vector): Double = {
LabelConverter.decodeLabel(mlpModel.predict(features))
}
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 0291a57..ad154fc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -178,7 +178,7 @@ class DecisionTreeRegressionModel private[ml] (
private[ml] def this(rootNode: Node, numFeatures: Int) =
this(Identifiable.randomUID("dtr"), rootNode, numFeatures)
- override protected def predict(features: Vector): Double = {
+ override def predict(features: Vector): Double = {
rootNode.predictImpl(features).prediction
}
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index f41d15b..6569ff2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -230,7 +230,7 @@ class GBTRegressionModel private[ml](
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
- override protected def predict(features: Vector): Double = {
+ override def predict(features: Vector): Double = {
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Classifies by thresholding sum of weighted tree predictions
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 917a4d2..9f1f240 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -1010,7 +1010,7 @@ class GeneralizedLinearRegressionModel private[ml] (
private lazy val familyAndLink = FamilyAndLink(this)
- override protected def predict(features: Vector): Double = {
+ override def predict(features: Vector): Double = {
predict(features, 0.0)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 6d3fe7a..9251015 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -699,7 +699,7 @@ class LinearRegressionModel private[ml] (
}
- override protected def predict(features: Vector): Double = {
+ override def predict(features: Vector): Double = {
dot(features, coefficients) + intercept
}
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 200b234..2d59446 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -199,7 +199,7 @@ class RandomForestRegressionModel private[ml] (
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
- override protected def predict(features: Vector): Double = {
+ override def predict(features: Vector): Double = {
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
// Predict average of tree predictions.
// Ignore the weights since all are 1.0 for now.
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index eeb0324..2930f49 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
+import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
@@ -264,6 +264,21 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
Vector, DecisionTreeClassificationModel](this, newTree, newData)
}
+ test("prediction on single instance") {
+ val rdd = continuousDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 3)
+ val numClasses = 3
+
+ val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
+ val newTree = dt.fit(newData)
+
+ testPredictionModelSinglePrediction(newTree, newData)
+ }
+
test("training with 1-category categorical feature") {
val data = sc.parallelize(Seq(
LabeledPoint(0, Vectors.dense(0, 2, 3)),
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 092b4a0..5779606 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -197,6 +197,15 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
Vector, GBTClassificationModel](this, gbtModel, validationDataset)
}
+ test("prediction on single instance") {
+
+ val gbt = new GBTClassifier().setSeed(123)
+ val trainingDataset = trainData.toDF("label", "features")
+ val gbtModel = gbt.fit(trainingDataset)
+
+ testPredictionModelSinglePrediction(gbtModel, trainingDataset)
+ }
+
test("GBT parameter stepSize should be in interval (0, 1]") {
withClue("GBT parameter stepSize should be in interval (0, 1]") {
intercept[IllegalArgumentException] {
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index a93825b..c05c896 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -201,6 +201,12 @@ class LinearSVCSuite extends MLTest with DefaultReadWriteTest {
dataset.as[LabeledPoint], estimator, modelEquals, 42L)
}
+ test("prediction on single instance") {
+ val trainer = new LinearSVC()
+ val model = trainer.fit(smallBinaryDataset)
+ testPredictionModelSinglePrediction(model, smallBinaryDataset)
+ }
+
test("linearSVC comparison with R e1071 and scikit-learn") {
val trainer1 = new LinearSVC()
.setRegParam(0.00002) // set regParam = 2.0 / datasize / c
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 9987cbf..36b7e51 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -499,6 +499,15 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
Vector, LogisticRegressionModel](this, model, smallBinaryDataset)
}
+ test("prediction on single instance") {
+ val blor = new LogisticRegression().setFamily("binomial")
+ val blorModel = blor.fit(smallBinaryDataset)
+ testPredictionModelSinglePrediction(blorModel, smallBinaryDataset)
+ val mlor = new LogisticRegression().setFamily("multinomial")
+ val mlorModel = mlor.fit(smallMultinomialDataset)
+ testPredictionModelSinglePrediction(mlorModel, smallMultinomialDataset)
+ }
+
test("coefficients and intercept methods") {
val mlr = new LogisticRegression().setMaxIter(1).setFamily("multinomial")
val mlrModel = mlr.fit(smallMultinomialDataset)
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index daa58a5..6b5fe6e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -76,6 +76,18 @@ class MultilayerPerceptronClassifierSuite extends MLTest with DefaultReadWriteTe
}
}
+ test("prediction on single instance") {
+ val layers = Array[Int](2, 5, 2)
+ val trainer = new MultilayerPerceptronClassifier()
+ .setLayers(layers)
+ .setBlockSize(1)
+ .setSeed(123L)
+ .setMaxIter(100)
+ .setSolver("l-bfgs")
+ val model = trainer.fit(dataset)
+ testPredictionModelSinglePrediction(model, dataset)
+ }
+
test("Predicted class probabilities: calibration on toy dataset") {
val layers = Array[Int](4, 5, 2)
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 49115c8..5f9ab98 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -167,6 +167,28 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
Vector, NaiveBayesModel](this, model, testDataset)
}
+ test("prediction on single instance") {
+ val nPoints = 1000
+ val piArray = Array(0.5, 0.1, 0.4).map(math.log)
+ val thetaArray = Array(
+ Array(0.70, 0.10, 0.10, 0.10), // label 0
+ Array(0.10, 0.70, 0.10, 0.10), // label 1
+ Array(0.10, 0.10, 0.70, 0.10) // label 2
+ ).map(_.map(math.log))
+ val pi = Vectors.dense(piArray)
+ val theta = new DenseMatrix(3, 4, thetaArray.flatten, true)
+
+ val trainDataset =
+ generateNaiveBayesInput(piArray, thetaArray, nPoints, seed, "multinomial").toDF()
+ val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial")
+ val model = nb.fit(trainDataset)
+
+ val validationDataset =
+ generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF()
+
+ testPredictionModelSinglePrediction(model, validationDataset)
+ }
+
test("Naive Bayes with weighted samples") {
val numClasses = 3
def modelEquals(m1: NaiveBayesModel, m2: NaiveBayesModel): Unit = {
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 02a9d5c..ba4a9cf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -155,6 +155,22 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
Vector, RandomForestClassificationModel](this, model, df)
}
+ test("prediction on single instance") {
+ val rdd = orderedLabeledPoints5_20
+ val rf = new RandomForestClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(3)
+ .setNumTrees(3)
+ .setSeed(123)
+ val categoricalFeatures = Map.empty[Int, Int]
+ val numClasses = 2
+
+ val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
+ val model = rf.fit(df)
+
+ testPredictionModelSinglePrediction(model, df)
+ }
+
test("Fitting without numClasses in metadata") {
val df: DataFrame = TreeTests.featureImportanceData(sc).toDF()
val rf = new RandomForestClassifier().setMaxDepth(1).setNumTrees(1)
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 68a1218..29a4383 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -136,6 +136,21 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
assert(importances.toArray.forall(_ >= 0.0))
}
+ test("prediction on single instance") {
+ val dt = new DecisionTreeRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(3)
+ .setSeed(123)
+
+ // In this data, feature 1 is very important.
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+ val categoricalFeatures = Map.empty[Int, Int]
+ val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+
+ val model = dt.fit(df)
+ testPredictionModelSinglePrediction(model, df)
+ }
+
test("should support all NumericType labels and not support other types") {
val dt = new DecisionTreeRegressor().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor](
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 11c593b..fad11d0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -99,6 +99,14 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
}
}
+ test("prediction on single instance") {
+ val gbt = new GBTRegressor()
+ .setMaxDepth(2)
+ .setMaxIter(2)
+ val model = gbt.fit(trainData.toDF())
+ testPredictionModelSinglePrediction(model, validationData.toDF)
+ }
+
test("Checkpointing") {
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index ef2ff94..d5bcbb2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -211,6 +211,14 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest
assert(model.getLink === "identity")
}
+ test("prediction on single instance") {
+ val glr = new GeneralizedLinearRegression
+ val model = glr.setFamily("gaussian").setLink("identity")
+ .fit(datasetGaussianIdentity)
+
+ testPredictionModelSinglePrediction(model, datasetGaussianIdentity)
+ }
+
test("generalized linear regression: gaussian family against glm") {
/*
R code:
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index d42cb17..9b19f63 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -636,6 +636,13 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
}
}
+ test("prediction on single instance") {
+ val trainer = new LinearRegression
+ val model = trainer.fit(datasetWithDenseFeature)
+
+ testPredictionModelSinglePrediction(model, datasetWithDenseFeature)
+ }
+
test("linear regression model with constant label") {
/*
R code:
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 8b8e8a6..e83c49f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -19,22 +19,22 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
/**
* Test suite for [[RandomForestRegressor]].
*/
-class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
- with DefaultReadWriteTest{
+class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{
import RandomForestRegressorSuite.compareAPIs
+ import testImplicits._
private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _
@@ -74,6 +74,20 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
regressionTestWithContinuousFeatures(rf)
}
+ test("prediction on single instance") {
+ val rf = new RandomForestRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setNumTrees(1)
+ .setFeatureSubsetStrategy("auto")
+ .setSeed(123)
+
+ val df = orderedLabeledPoints50_1000.toDF()
+ val model = rf.fit(df)
+ testPredictionModelSinglePrediction(model, df)
+ }
+
test("Feature importance with toy data") {
val rf = new RandomForestRegressor()
.setImpurity("variance")
http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
index 795fd0e..76d41f9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
@@ -22,8 +22,9 @@ import java.io.File
import org.scalatest.Suite
import org.apache.spark.SparkContext
-import org.apache.spark.ml.Transformer
-import org.apache.spark.sql.{DataFrame, Encoder, Row}
+import org.apache.spark.ml.{PredictionModel, Transformer}
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.streaming.StreamTest
@@ -136,4 +137,14 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
assert(hasExpectedMessage(exceptionOnStreamData))
}
}
+
+ def testPredictionModelSinglePrediction(model: PredictionModel[Vector, _],
+ dataset: Dataset[_]): Unit = {
+
+ model.transform(dataset).select(model.getFeaturesCol, model.getPredictionCol)
+ .collect().foreach {
+ case Row(features: Vector, prediction: Double) =>
+ assert(prediction === model.predict(features))
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org