You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by jkbradley <gi...@git.apache.org> on 2017/08/17 00:24:48 UTC

[GitHub] spark pull request #15435: [SPARK-17139][ML] Add model summary for Multinomi...

Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/15435#discussion_r133598714
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala ---
    @@ -1324,90 +1350,130 @@ private[ml] class MultiClassSummarizer extends Serializable {
     }
     
     /**
    - * Abstraction for multinomial Logistic Regression Training results.
    - * Currently, the training summary ignores the training weights except
    - * for the objective trace.
    - */
    -sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary {
    -
    -  /** objective function (scaled loss + regularization) at each iteration. */
    -  def objectiveHistory: Array[Double]
    -
    -  /** Number of training iterations until termination */
    -  def totalIterations: Int = objectiveHistory.length
    -
    -}
    -
    -/**
    - * Abstraction for Logistic Regression Results for a given model.
    + * Abstraction for logistic regression results for a given model.
      */
     sealed trait LogisticRegressionSummary extends Serializable {
     
       /**
        * Dataframe output by the model's `transform` method.
        */
    +  @Since("2.3.0")
       def predictions: DataFrame
     
       /** Field in "predictions" which gives the probability of each class as a vector. */
    +  @Since("2.3.0")
       def probabilityCol: String
     
    +  /** Field in "predictions" which gives the prediction of each class. */
    +  @Since("2.3.0")
    +  def predictionCol: String
    +
       /** Field in "predictions" which gives the true label of each instance (if available). */
    +  @Since("2.3.0")
       def labelCol: String
     
       /** Field in "predictions" which gives the features of each instance as a vector. */
    +  @Since("2.3.0")
       def featuresCol: String
     
    +  @transient private val multiclassMetrics = {
    +    new MulticlassMetrics(
    +      predictions.select(
    +        col(predictionCol),
    +        col(labelCol).cast(DoubleType))
    +        .rdd.map { case Row(prediction: Double, label: Double) => (prediction, label) })
    +  }
    +
    +  /** Returns true positive rate for each label. */
    +  @Since("2.3.0")
    +  def truePositiveRateByLabel: Array[Double] = recallByLabel
    +
    +  /** Returns false positive rate for each label. */
    +  @Since("2.3.0")
    +  def falsePositiveRateByLabel: Array[Double] = {
    +    multiclassMetrics.labels.map(label => multiclassMetrics.falsePositiveRate(label))
    +  }
    +
    +  /** Returns precision for each label. */
    +  @Since("2.3.0")
    +  def precisionByLabel: Array[Double] = {
    +    multiclassMetrics.labels.map(label => multiclassMetrics.precision(label))
    +  }
    +
    +  /** Returns recall for each label. */
    +  @Since("2.3.0")
    +  def recallByLabel: Array[Double] = {
    +    multiclassMetrics.labels.map(label => multiclassMetrics.recall(label))
    +  }
    +
    +  /**
    +   * Returns f-measure for each label.
    +   */
    +  @Since("2.3.0")
    +  def fMeasureByLabel(beta: Double): Array[Double] = {
    +    multiclassMetrics.labels.map(label => multiclassMetrics.fMeasure(label, beta))
    +  }
    +
    +  /** Returns f1-measure for each label. */
    +  @Since("2.3.0")
    +  def fMeasureByLabel: Array[Double] = fMeasureByLabel(1.0)
    +
    +  /** Returns accuracy. */
    +  @Since("2.3.0")
    +  def accuracy: Double = multiclassMetrics.accuracy
    +
    +  /** Returns weighted true positive rate. */
    +  @Since("2.3.0")
    +  def weightedTruePositiveRate: Double = weightedRecall
    +
    +  /** Returns weighted false positive rate. */
    +  @Since("2.3.0")
    +  def weightedFalsePositiveRate: Double = multiclassMetrics.weightedFalsePositiveRate
    +
    +  /** Returns weighted averaged recall. */
    +  @Since("2.3.0")
    +  def weightedRecall: Double = multiclassMetrics.weightedRecall
    --- End diff --
    
    For all of these, please make sure to copy all of the information from the MulticlassMetrics Scaladoc. Some details are missing, and users cannot be expected to find them on their own.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org