You are viewing a plain text version of this content. The canonical link for it is available on the original mailing-list archive page.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/03/28 04:04:22 UTC

spark git commit: [SPARK-10691][ML] Make LogisticRegressionModel, LinearRegressionModel evaluate() public

Repository: spark
Updated Branches:
  refs/heads/master 0f02a5c6e -> 8ef493760


[SPARK-10691][ML] Make LogisticRegressionModel, LinearRegressionModel evaluate() public

## What changes were proposed in this pull request?

Made the evaluate method of LogisticRegressionModel and LinearRegressionModel public. Fixed LogisticRegressionModel.evaluate to handle the case where probabilityCol is not specified.

## How was this patch tested?

There were already unit tests for these methods.

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #11928 from jkbradley/public-evaluate.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8ef49376
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8ef49376
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8ef49376

Branch: refs/heads/master
Commit: 8ef493760f58687df766d03ccf64039635a2609f
Parents: 0f02a5c
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Sun Mar 27 19:04:18 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Sun Mar 27 19:04:18 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/classification/LogisticRegression.scala    | 12 +++++++-----
 .../apache/spark/ml/regression/LinearRegression.scala   |  8 ++++----
 2 files changed, 11 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8ef49376/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 861b1d4..3d1d5b6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -539,13 +539,15 @@ class LogisticRegressionModel private[spark] (
   def hasSummary: Boolean = trainingSummary.isDefined
 
   /**
-   * Evaluates the model on a testset.
+   * Evaluates the model on a test dataset.
    * @param dataset Test dataset to evaluate model on.
    */
-  // TODO: decide on a good name before exposing to public API
-  private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = {
-    new BinaryLogisticRegressionSummary(
-      this.transform(dataset), $(probabilityCol), $(labelCol), $(featuresCol))
+  @Since("2.0.0")
+  def evaluate(dataset: DataFrame): LogisticRegressionSummary = {
+    // Handle possible missing or invalid prediction columns
+    val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol()
+    new BinaryLogisticRegressionSummary(summaryModel.transform(dataset),
+      probabilityColName, $(labelCol), $(featuresCol))
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/8ef49376/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index b81c588..5ec0213 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -412,15 +412,15 @@ class LinearRegressionModel private[ml] (
   def hasSummary: Boolean = trainingSummary.isDefined
 
   /**
-   * Evaluates the model on a testset.
+   * Evaluates the model on a test dataset.
    * @param dataset Test dataset to evaluate model on.
    */
-  // TODO: decide on a good name before exposing to public API
-  private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = {
+  @Since("2.0.0")
+  def evaluate(dataset: DataFrame): LinearRegressionSummary = {
     // Handle possible missing or invalid prediction columns
     val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol()
     new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName,
-      $(labelCol), this, Array(0D))
+      $(labelCol), summaryModel, Array(0D))
   }
 
   /**


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org