You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/02/22 11:02:23 UTC
[spark] branch master updated: [SPARK-42526][ML] Add Classifier.getNumClasses back
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a6098beade0 [SPARK-42526][ML] Add Classifier.getNumClasses back
a6098beade0 is described below
commit a6098beade01eac5cf92727e69b3537fcac31b2d
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Wed Feb 22 19:02:01 2023 +0800
[SPARK-42526][ML] Add Classifier.getNumClasses back
### What changes were proposed in this pull request?
Add Classifier.getNumClasses back
### Why are the changes needed?
Some widely used libraries, such as `xgboost`, depend on this method even though it is not a public API,
so restoring it improves xgboost integration.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
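Updated the MiMa excludes: removed the `DirectMissingMethodProblem` filters for `Classifier.getNumClasses` and its default argument, since the method is restored.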
Closes #40119 from zhengruifeng/ml_add_classifier_get_num_class.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
.../apache/spark/ml/classification/Classifier.scala | 19 +++++++++++++++++++
.../ml/classification/DecisionTreeClassifier.scala | 2 +-
.../ml/classification/RandomForestClassifier.scala | 2 +-
project/MimaExcludes.scala | 2 --
4 files changed, 21 insertions(+), 4 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 2d7719a29ca..c46be175cb2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -56,6 +56,25 @@ abstract class Classifier[
M <: ClassificationModel[FeaturesType, M]]
extends Predictor[FeaturesType, E, M] with ClassifierParams {
+ /**
+ * Get the number of classes. This looks in column metadata first, and if that is missing,
+ * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
+ * by finding the maximum label value.
+ *
+ * Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere,
+ * such as in `extractLabeledPoints()`.
+ *
+ * @param dataset Dataset which contains a column [[labelCol]]
+ * @param maxNumClasses Maximum number of classes allowed when inferred from data. If numClasses
+ * is specified in the metadata, then maxNumClasses is ignored.
+ * @return number of classes
+ * @throws IllegalArgumentException if metadata does not specify numClasses, and the
+ * actual numClasses exceeds maxNumClasses
+ */
+ protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): Int = {
+ DatasetUtils.getNumClasses(dataset, $(labelCol), maxNumClasses)
+ }
+
/** @group setParam */
def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 688d2d18f48..7deefda2eea 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -117,7 +117,7 @@ class DecisionTreeClassifier @Since("1.4.0") (
instr.logPipelineStage(this)
instr.logDataset(dataset)
val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val numClasses = getNumClasses(dataset, $(labelCol))
+ val numClasses = getNumClasses(dataset)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 048e5949e1c..9295425f9d6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -141,7 +141,7 @@ class RandomForestClassifier @Since("1.4.0") (
instr.logDataset(dataset)
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val numClasses = getNumClasses(dataset, $(labelCol))
+ val numClasses = getNumClasses(dataset)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 70a7c29b8dc..9741e53452a 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -55,8 +55,6 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.extractLabeledPoints"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.validateNumClasses"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.validateLabel"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.getNumClasses"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.getNumClasses$default$2"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRest.extractInstances"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.extractInstances"),
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org