You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/09/29 23:57:09 UTC

spark git commit: Updated the following PR with minor changes to allow cherry-pick to branch-2.0

Repository: spark
Updated Branches:
  refs/heads/branch-2.0 0cdd7370a -> a99ea4c9e


Updated the following PR with minor changes to allow cherry-pick to branch-2.0

[SPARK-17697][ML] Fixed bug in summary calculations that pattern match against label without casting

When LogisticRegression.evaluate or GeneralizedLinearRegression.evaluate is called with a Dataset whose label column is not of DoubleType, the summary calculations pattern-match the label against Double and throw a MatchError. This fix casts the label column to DoubleType so that no MatchError is thrown.

Added unit tests that call evaluate with datasets whose label column has a numeric type other than Double.

Author: Bryan Cutler <cu...@gmail.com>

Closes #15288 from BryanCutler/binaryLOR-numericCheck-SPARK-17697.

(cherry picked from commit 2f739567080d804a942cfcca0e22f91ab7cbea36)
Signed-off-by: Joseph K. Bradley <jo...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a99ea4c9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a99ea4c9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a99ea4c9

Branch: refs/heads/branch-2.0
Commit: a99ea4c9e0e2f91e4b524987788f0acee88e564d
Parents: 0cdd737
Author: Bryan Cutler <cu...@gmail.com>
Authored: Thu Sep 29 16:31:30 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu Sep 29 16:56:34 2016 -0700

----------------------------------------------------------------------
 .../ml/classification/LogisticRegression.scala  |  2 +-
 .../GeneralizedLinearRegression.scala           | 11 +++++----
 .../LogisticRegressionSuite.scala               | 18 +++++++++++++-
 .../GeneralizedLinearRegressionSuite.scala      | 25 ++++++++++++++++++++
 4 files changed, 49 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a99ea4c9/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index cca3374..c50ee5d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -860,7 +860,7 @@ class BinaryLogisticRegressionSummary private[classification] (
   // TODO: Allow the user to vary the number of bins using a setBins method in
   // BinaryClassificationMetrics. For now the default is set to 100.
   @transient private val binaryMetrics = new BinaryClassificationMetrics(
-    predictions.select(probabilityCol, labelCol).rdd.map {
+    predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType)).rdd.map {
       case Row(score: Vector, label: Double) => (score(1), label)
     }, 100
   )

http://git-wip-us.apache.org/repos/asf/spark/blob/a99ea4c9/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 2bdc09e..7f88c12 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -988,7 +988,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
     } else {
       link.unlink(0.0)
     }
-    predictions.select(col(model.getLabelCol), w).rdd.map {
+    predictions.select(col(model.getLabelCol).cast(DoubleType), w).rdd.map {
       case Row(y: Double, weight: Double) =>
         family.deviance(y, wtdmu, weight)
     }.sum()
@@ -1000,7 +1000,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
   @Since("2.0.0")
   lazy val deviance: Double = {
     val w = weightCol
-    predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map {
+    predictions.select(col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map {
       case Row(label: Double, pred: Double, weight: Double) =>
         family.deviance(label, pred, weight)
     }.sum()
@@ -1026,9 +1026,10 @@ class GeneralizedLinearRegressionSummary private[regression] (
   lazy val aic: Double = {
     val w = weightCol
     val weightSum = predictions.select(w).agg(sum(w)).first().getDouble(0)
-    val t = predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map {
-      case Row(label: Double, pred: Double, weight: Double) =>
-        (label, pred, weight)
+    val t = predictions.select(
+      col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map {
+        case Row(label: Double, pred: Double, weight: Double) =>
+          (label, pred, weight)
     }
     family.aic(t, deviance, numInstances, weightSum) + 2 * rank
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/a99ea4c9/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index a1b4853..27c872a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -31,7 +31,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
-import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.LongType
 
 class LogisticRegressionSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -812,6 +813,21 @@ class LogisticRegressionSuite
       summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect())
   }
 
+  test("evaluate with labels that are not doubles") {
+    // Evaluate a test set with Label that is a numeric type other than Double
+    val lr = new LogisticRegression()
+      .setMaxIter(1)
+      .setRegParam(1.0)
+    val model = lr.fit(dataset)
+    val summary = model.evaluate(dataset).asInstanceOf[BinaryLogisticRegressionSummary]
+
+    val longLabelData = dataset.select(col(model.getLabelCol).cast(LongType),
+      col(model.getFeaturesCol))
+    val longSummary = model.evaluate(longLabelData).asInstanceOf[BinaryLogisticRegressionSummary]
+
+    assert(summary.areaUnderROC ~== longSummary.areaUnderROC absTol 1E-10)
+  }
+
   test("statistics on training data") {
     // Test that loss is monotonically decreasing.
     val lr = new LogisticRegression()

http://git-wip-us.apache.org/repos/asf/spark/blob/a99ea4c9/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index a4568e83..9d10215 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.mllib.random._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.FloatType
 
 class GeneralizedLinearRegressionSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -1034,6 +1035,30 @@ class GeneralizedLinearRegressionSuite
       .setFamily("gaussian")
       .fit(datasetGaussianIdentity.as[LabeledPoint])
   }
+
+  test("evaluate with labels that are not doubles") {
+    // Evaulate with a dataset that contains Labels not as doubles to verify correct casting
+    val dataset = spark.createDataFrame(sc.parallelize(Seq(
+      Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+      Instance(19.0, 1.0, Vectors.dense(1.0, 7.0)),
+      Instance(23.0, 1.0, Vectors.dense(2.0, 11.0)),
+      Instance(29.0, 1.0, Vectors.dense(3.0, 13.0))
+    ), 2))
+
+    val trainer = new GeneralizedLinearRegression()
+      .setMaxIter(1)
+    val model = trainer.fit(dataset)
+    assert(model.hasSummary)
+    val summary = model.summary
+
+    val longLabelDataset = dataset.select(col(model.getLabelCol).cast(FloatType),
+      col(model.getFeaturesCol))
+    val evalSummary = model.evaluate(longLabelDataset)
+    // The calculations below involve pattern matching with Label as a double
+    assert(evalSummary.nullDeviance === summary.nullDeviance)
+    assert(evalSummary.deviance === summary.deviance)
+    assert(evalSummary.aic === summary.aic)
+  }
 }
 
 object GeneralizedLinearRegressionSuite {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org