Posted to commits@spark.apache.org by ru...@apache.org on 2023/02/22 11:02:23 UTC

[spark] branch master updated: [SPARK-42526][ML] Add Classifier.getNumClasses back

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a6098beade0 [SPARK-42526][ML] Add Classifier.getNumClasses back
a6098beade0 is described below

commit a6098beade01eac5cf92727e69b3537fcac31b2d
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Wed Feb 22 19:02:01 2023 +0800

    [SPARK-42526][ML] Add Classifier.getNumClasses back
    
    ### What changes were proposed in this pull request?
    Add Classifier.getNumClasses back
    
    ### Why are the changes needed?
    Some widely used libraries such as `xgboost` depend on this method even though it is not a public API,
    so adding it back makes the xgboost integration smoother (see the usage sketch after the commit message).
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Updated the MiMa exclusions in `project/MimaExcludes.scala`.
    
    Closes #40119 from zhengruifeng/ml_add_classifier_get_num_class.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
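
For context, a minimal sketch (not part of the commit) of the call pattern this restores for
third-party estimators. The class names, package, and training logic below are hypothetical;
real integrations such as xgboost's Spark layer are more involved, and this only illustrates
how a Classifier subclass can reach the restored protected helper from `train()`.

package org.example.ml

import org.apache.spark.ml.classification.{ClassificationModel, Classifier}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.Dataset

// Hypothetical estimator: only the getNumClasses call pattern matters here.
class MyClassifier(override val uid: String)
  extends Classifier[Vector, MyClassifier, MyClassificationModel] {

  def this() = this(Identifiable.randomUID("myClassifier"))

  override protected def train(dataset: Dataset[_]): MyClassificationModel = {
    // Resolves the class count from the label column metadata, or, if that is
    // missing, from the maximum label value (bounded by maxNumClasses = 100).
    val numClasses = getNumClasses(dataset)
    // ... real training would happen here ...
    new MyClassificationModel(uid, numClasses)
  }

  override def copy(extra: ParamMap): MyClassifier = defaultCopy(extra)
}

// Hypothetical model: uniform raw scores, just enough to keep the sketch compilable.
class MyClassificationModel(override val uid: String, override val numClasses: Int)
  extends ClassificationModel[Vector, MyClassificationModel] {

  override def predictRaw(features: Vector): Vector =
    Vectors.dense(Array.fill(numClasses)(1.0 / numClasses))

  override def copy(extra: ParamMap): MyClassificationModel =
    copyValues(new MyClassificationModel(uid, numClasses), extra).setParent(parent)
}
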
---
 .../apache/spark/ml/classification/Classifier.scala   | 19 +++++++++++++++++++
 .../ml/classification/DecisionTreeClassifier.scala    |  2 +-
 .../ml/classification/RandomForestClassifier.scala    |  2 +-
 project/MimaExcludes.scala                            |  2 --
 4 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 2d7719a29ca..c46be175cb2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -56,6 +56,25 @@ abstract class Classifier[
     M <: ClassificationModel[FeaturesType, M]]
   extends Predictor[FeaturesType, E, M] with ClassifierParams {
 
+  /**
+   * Get the number of classes.  This looks in column metadata first, and if that is missing,
+   * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
+   * by finding the maximum label value.
+   *
+   * Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere,
+   * such as in `extractLabeledPoints()`.
+   *
+   * @param dataset       Dataset which contains a column [[labelCol]]
+   * @param maxNumClasses Maximum number of classes allowed when inferred from data.  If numClasses
+   *                      is specified in the metadata, then maxNumClasses is ignored.
+   * @return number of classes
+   * @throws IllegalArgumentException if metadata does not specify numClasses, and the
+   *                                  actual numClasses exceeds maxNumClasses
+   */
+  protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): Int = {
+    DatasetUtils.getNumClasses(dataset, $(labelCol), maxNumClasses)
+  }
+
   /** @group setParam */
   def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 688d2d18f48..7deefda2eea 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -117,7 +117,7 @@ class DecisionTreeClassifier @Since("1.4.0") (
     instr.logPipelineStage(this)
     instr.logDataset(dataset)
     val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses = getNumClasses(dataset, $(labelCol))
+    val numClasses = getNumClasses(dataset)
 
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 048e5949e1c..9295425f9d6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -141,7 +141,7 @@ class RandomForestClassifier @Since("1.4.0") (
     instr.logDataset(dataset)
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses = getNumClasses(dataset, $(labelCol))
+    val numClasses = getNumClasses(dataset)
 
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 70a7c29b8dc..9741e53452a 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -55,8 +55,6 @@ object MimaExcludes {
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.extractLabeledPoints"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.validateNumClasses"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.validateLabel"),
-    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.getNumClasses"),
-    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.getNumClasses$default$2"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRest.extractInstances"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.extractInstances"),
 


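For reference, a small sketch (not part of the commit) of the two resolution paths described in the
new Scaladoc: numClasses is read from the label column's attribute metadata when present, otherwise
it is inferred from the maximum label value and must not exceed maxNumClasses. It assumes a
spark-shell style session where `spark` and its implicits are in scope; the column names are arbitrary.

import org.apache.spark.ml.attribute.NominalAttribute
import spark.implicits._

// Path 1: no label metadata. numClasses is inferred as maxLabel + 1 (here 3);
// an IllegalArgumentException is thrown if that exceeds maxNumClasses (100 by default).
val df = Seq((0.0, 1.0), (1.0, 2.0), (2.0, 3.0)).toDF("label", "feature")

// Path 2: numClasses carried in the label column metadata via a NominalAttribute.
// In this case the metadata value is used and maxNumClasses is ignored.
val meta = NominalAttribute.defaultAttr.withName("label").withNumValues(3).toMetadata()
val dfWithMeta = df.withColumn("label", df("label").as("label", meta))

DecisionTreeClassifier and RandomForestClassifier now resolve numClasses through the same
Classifier.getNumClasses helper for both paths, as the two one-line hunks above show.
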
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org