You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/09/29 23:57:09 UTC
spark git commit: Updated the following PR with minor changes to
allow cherry-pick to branch-2.0
Repository: spark
Updated Branches:
refs/heads/branch-2.0 0cdd7370a -> a99ea4c9e
Updated the following PR with minor changes to allow cherry-pick to branch-2.0
[SPARK-17697][ML] Fixed bug in summary calculations that pattern match against label without casting
In calling LogisticRegression.evaluate and GeneralizedLinearRegression.evaluate using a Dataset where the Label is not of a double type, calculations pattern match against a double and throw a MatchError. This fix casts the Label column to a DoubleType to ensure there is no MatchError.
Added unit tests to call evaluate with a dataset that has Label as other numeric types.
Author: Bryan Cutler <cu...@gmail.com>
Closes #15288 from BryanCutler/binaryLOR-numericCheck-SPARK-17697.
(cherry picked from commit 2f739567080d804a942cfcca0e22f91ab7cbea36)
Signed-off-by: Joseph K. Bradley <jo...@databricks.com>
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a99ea4c9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a99ea4c9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a99ea4c9
Branch: refs/heads/branch-2.0
Commit: a99ea4c9e0e2f91e4b524987788f0acee88e564d
Parents: 0cdd737
Author: Bryan Cutler <cu...@gmail.com>
Authored: Thu Sep 29 16:31:30 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu Sep 29 16:56:34 2016 -0700
----------------------------------------------------------------------
.../ml/classification/LogisticRegression.scala | 2 +-
.../GeneralizedLinearRegression.scala | 11 +++++----
.../LogisticRegressionSuite.scala | 18 +++++++++++++-
.../GeneralizedLinearRegressionSuite.scala | 25 ++++++++++++++++++++
4 files changed, 49 insertions(+), 7 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/a99ea4c9/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index cca3374..c50ee5d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -860,7 +860,7 @@ class BinaryLogisticRegressionSummary private[classification] (
// TODO: Allow the user to vary the number of bins using a setBins method in
// BinaryClassificationMetrics. For now the default is set to 100.
@transient private val binaryMetrics = new BinaryClassificationMetrics(
- predictions.select(probabilityCol, labelCol).rdd.map {
+ predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType)).rdd.map {
case Row(score: Vector, label: Double) => (score(1), label)
}, 100
)
http://git-wip-us.apache.org/repos/asf/spark/blob/a99ea4c9/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 2bdc09e..7f88c12 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -988,7 +988,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
} else {
link.unlink(0.0)
}
- predictions.select(col(model.getLabelCol), w).rdd.map {
+ predictions.select(col(model.getLabelCol).cast(DoubleType), w).rdd.map {
case Row(y: Double, weight: Double) =>
family.deviance(y, wtdmu, weight)
}.sum()
@@ -1000,7 +1000,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
@Since("2.0.0")
lazy val deviance: Double = {
val w = weightCol
- predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map {
+ predictions.select(col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map {
case Row(label: Double, pred: Double, weight: Double) =>
family.deviance(label, pred, weight)
}.sum()
@@ -1026,9 +1026,10 @@ class GeneralizedLinearRegressionSummary private[regression] (
lazy val aic: Double = {
val w = weightCol
val weightSum = predictions.select(w).agg(sum(w)).first().getDouble(0)
- val t = predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map {
- case Row(label: Double, pred: Double, weight: Double) =>
- (label, pred, weight)
+ val t = predictions.select(
+ col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map {
+ case Row(label: Double, pred: Double, weight: Double) =>
+ (label, pred, weight)
}
family.aic(t, deviance, numInstances, weightSum) + 2 * rank
}
http://git-wip-us.apache.org/repos/asf/spark/blob/a99ea4c9/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index a1b4853..27c872a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -31,7 +31,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, Row}
-import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.LongType
class LogisticRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -812,6 +813,21 @@ class LogisticRegressionSuite
summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect())
}
+ test("evaluate with labels that are not doubles") {
+ // Evaluate a test set with Label that is a numeric type other than Double
+ val lr = new LogisticRegression()
+ .setMaxIter(1)
+ .setRegParam(1.0)
+ val model = lr.fit(dataset)
+ val summary = model.evaluate(dataset).asInstanceOf[BinaryLogisticRegressionSummary]
+
+ val longLabelData = dataset.select(col(model.getLabelCol).cast(LongType),
+ col(model.getFeaturesCol))
+ val longSummary = model.evaluate(longLabelData).asInstanceOf[BinaryLogisticRegressionSummary]
+
+ assert(summary.areaUnderROC ~== longSummary.areaUnderROC absTol 1E-10)
+ }
+
test("statistics on training data") {
// Test that loss is monotonically decreasing.
val lr = new LogisticRegression()
http://git-wip-us.apache.org/repos/asf/spark/blob/a99ea4c9/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index a4568e83..9d10215 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.mllib.random._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.FloatType
class GeneralizedLinearRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -1034,6 +1035,30 @@ class GeneralizedLinearRegressionSuite
.setFamily("gaussian")
.fit(datasetGaussianIdentity.as[LabeledPoint])
}
+
+ test("evaluate with labels that are not doubles") {
+ // Evaluate with a dataset that contains Labels not as doubles to verify correct casting
+ val dataset = spark.createDataFrame(sc.parallelize(Seq(
+ Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(19.0, 1.0, Vectors.dense(1.0, 7.0)),
+ Instance(23.0, 1.0, Vectors.dense(2.0, 11.0)),
+ Instance(29.0, 1.0, Vectors.dense(3.0, 13.0))
+ ), 2))
+
+ val trainer = new GeneralizedLinearRegression()
+ .setMaxIter(1)
+ val model = trainer.fit(dataset)
+ assert(model.hasSummary)
+ val summary = model.summary
+
+ val longLabelDataset = dataset.select(col(model.getLabelCol).cast(FloatType),
+ col(model.getFeaturesCol))
+ val evalSummary = model.evaluate(longLabelDataset)
+ // The calculations below involve pattern matching with Label as a double
+ assert(evalSummary.nullDeviance === summary.nullDeviance)
+ assert(evalSummary.deviance === summary.deviance)
+ assert(evalSummary.aic === summary.aic)
+ }
}
object GeneralizedLinearRegressionSuite {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org