You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2022/06/19 04:52:01 UTC

[spark] branch master updated: [SPARK-38775][ML] cleanup validation functions

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 27ed89b7be5 [SPARK-38775][ML] cleanup validation functions
27ed89b7be5 is described below

commit 27ed89b7be5ebb91e4a0b106b1669a7867a6012d
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Sat Jun 18 21:51:50 2022 -0700

    [SPARK-38775][ML] cleanup validation functions
    
    ### What changes were proposed in this pull request?
    1. remove unused `extractInstances` and `extractLabeledPoints` in `Predictor`;
    2. remove unused `checkNonNegativeWeight` in `functions`;
    3. move `getNumClasses` from `Classifier` to `DatasetUtils`;
    4. move `getNumFeatures` from `MetadataUtils` to `DatasetUtils`.
    
    ### Why are the changes needed?
    To unify these validation and extraction helpers in one place (`DatasetUtils`).
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing test suites.
    
    Closes #36049 from zhengruifeng/validate_cleanup.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
 .../spark/examples/ml/DeveloperApiExample.scala    |   7 +-
 .../main/scala/org/apache/spark/ml/Predictor.scala |  51 +---------
 .../spark/ml/classification/Classifier.scala       | 106 +--------------------
 .../ml/classification/DecisionTreeClassifier.scala |   3 +-
 .../spark/ml/classification/FMClassifier.scala     |   2 +-
 .../spark/ml/classification/GBTClassifier.scala    |  20 +---
 .../ml/classification/RandomForestClassifier.scala |   2 +-
 .../spark/ml/clustering/GaussianMixture.scala      |   2 +-
 .../evaluation/BinaryClassificationEvaluator.scala |   7 +-
 .../spark/ml/evaluation/ClusteringEvaluator.scala  |  21 ++--
 .../spark/ml/evaluation/ClusteringMetrics.scala    |   6 +-
 .../MulticlassClassificationEvaluator.scala        |   8 +-
 .../spark/ml/evaluation/RegressionEvaluator.scala  |  16 ++--
 .../scala/org/apache/spark/ml/feature/LSH.scala    |   2 +-
 .../org/apache/spark/ml/feature/RobustScaler.scala |   2 +-
 .../org/apache/spark/ml/feature/Selector.scala     |   2 +-
 .../ml/feature/UnivariateFeatureSelector.scala     |   2 +-
 .../apache/spark/ml/feature/VectorIndexer.scala    |   2 +-
 .../main/scala/org/apache/spark/ml/functions.scala |   6 --
 .../apache/spark/ml/regression/FMRegressor.scala   |   2 +-
 .../apache/spark/ml/regression/GBTRegressor.scala  |  20 +---
 .../regression/GeneralizedLinearRegression.scala   |   2 +-
 .../spark/ml/regression/LinearRegression.scala     |   2 +-
 .../org/apache/spark/ml/util/DatasetUtils.scala    |  82 +++++++++++++++-
 .../org/apache/spark/ml/util/MetadataUtils.scala   |  14 +--
 .../spark/ml/classification/ClassifierSuite.scala  |  44 +--------
 project/MimaExcludes.scala                         |  16 +++-
 27 files changed, 152 insertions(+), 297 deletions(-)

diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index 487cb27b93f..bfee3301f8e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -24,6 +24,7 @@ import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
 import org.apache.spark.ml.param.{IntParam, ParamMap}
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.{Dataset, Row, SparkSession}
+import org.apache.spark.sql.functions.col
 
 /**
  * A simple example demonstrating how to write your own learning algorithm using Estimator,
@@ -120,8 +121,10 @@ private class MyLogisticRegression(override val uid: String)
 
   // This method is used by fit()
   override protected def train(dataset: Dataset[_]): MyLogisticRegressionModel = {
-    // Extract columns from data using helper method.
-    val oldDataset = extractLabeledPoints(dataset)
+    // Extract columns from data.
+    val oldDataset = dataset.select(col($(labelCol)).cast("double"), col($(featuresCol)))
+      .rdd
+      .map { case Row(l: Double, f: Vector) => LabeledPoint(l, f) }
 
     // Do learning to estimate the coefficients vector.
     val numFeatures = oldDataset.take(1)(0).features.size
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index e0b128e3698..9c6eb880c80 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -18,14 +18,11 @@
 package org.apache.spark.ml
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.{Instance, LabeledPoint}
-import org.apache.spark.ml.functions.checkNonNegativeWeight
-import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.linalg.VectorUDT
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 
@@ -63,40 +60,6 @@ private[ml] trait PredictorParams extends Params
     }
     SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
   }
-
-  /**
-   * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
-   * and put it in an RDD with strong types.
-   */
-  protected def extractInstances(dataset: Dataset[_]): RDD[Instance] = {
-    val w = this match {
-      case p: HasWeightCol =>
-        if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
-          checkNonNegativeWeight((col($(p.weightCol)).cast(DoubleType)))
-        } else {
-          lit(1.0)
-        }
-    }
-
-    dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
-      case Row(label: Double, weight: Double, features: Vector) =>
-        Instance(label, weight, features)
-    }
-  }
-
-  /**
-   * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
-   * and put it in an RDD with strong types.
-   * Validate the output instances with the given function.
-   */
-  protected def extractInstances(
-      dataset: Dataset[_],
-      validateInstance: Instance => Unit): RDD[Instance] = {
-    extractInstances(dataset).map { instance =>
-      validateInstance(instance)
-      instance
-    }
-  }
 }
 
 /**
@@ -176,16 +139,6 @@ abstract class Predictor[
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema, fitting = true, featuresDataType)
   }
-
-  /**
-   * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
-   * and put it in an RDD with strong types.
-   */
-  protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
-    dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
-      case Row(label: Double, features: Vector) => LabeledPoint(label, features)
-    }
-  }
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 09324e2087d..2d7719a29ca 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -17,17 +17,13 @@
 
 package org.apache.spark.ml.classification
 
-import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
-import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared.HasRawPredictionCol
 import org.apache.spark.ml.util._
-import org.apache.spark.ml.util.DatasetUtils._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DataType, StructType}
 
@@ -44,23 +40,6 @@ private[spark] trait ClassifierParams
     val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
     SchemaUtils.appendColumn(parentSchema, $(rawPredictionCol), new VectorUDT)
   }
-
-  /**
-   * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
-   * and put it in an RDD with strong types.
-   * Validates the label on the classifier is a valid integer in the range [0, numClasses).
-   */
-  protected def extractInstances(
-      dataset: Dataset[_],
-      numClasses: Int): RDD[Instance] = {
-    val validateInstance = (instance: Instance) => {
-      val label = instance.label
-      require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
-        s" dataset with invalid label $label. Labels must be integers in range" +
-        s" [0, $numClasses).")
-    }
-    extractInstances(dataset, validateInstance)
-  }
 }
 
 /**
@@ -81,89 +60,6 @@ abstract class Classifier[
   def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]
 
   // TODO: defaultEvaluator (follow-up PR)
-
-  /**
-   * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
-   * and put it in an RDD with strong types.
-   *
-   * @param dataset  DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]])
-   *                 and features (`Vector`).
-   * @param numClasses  Number of classes label can take.  Labels must be integers in the range
-   *                    [0, numClasses).
-   * @note Throws `SparkException` if any label is a non-integer or is negative
-   */
-  protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
-    validateNumClasses(numClasses)
-    dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
-      case Row(label: Double, features: Vector) =>
-        validateLabel(label, numClasses)
-        LabeledPoint(label, features)
-    }
-  }
-
-  /**
-   * Validates that number of classes is greater than zero.
-   *
-   * @param numClasses Number of classes label can take.
-   */
-  protected def validateNumClasses(numClasses: Int): Unit = {
-    require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
-      s" $numClasses, but requires numClasses > 0.")
-  }
-
-  /**
-   * Validates the label on the classifier is a valid integer in the range [0, numClasses).
-   *
-   * @param label The label to validate.
-   * @param numClasses Number of classes label can take.  Labels must be integers in the range
-   *                  [0, numClasses).
-   */
-  protected def validateLabel(label: Double, numClasses: Int): Unit = {
-    require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
-      s" dataset with invalid label $label.  Labels must be integers in range" +
-      s" [0, $numClasses).")
-  }
-
-  /**
-   * Get the number of classes.  This looks in column metadata first, and if that is missing,
-   * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
-   * by finding the maximum label value.
-   *
-   * Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere,
-   * such as in `extractLabeledPoints()`.
-   *
-   * @param dataset  Dataset which contains a column [[labelCol]]
-   * @param maxNumClasses  Maximum number of classes allowed when inferred from data.  If numClasses
-   *                       is specified in the metadata, then maxNumClasses is ignored.
-   * @return  number of classes
-   * @throws IllegalArgumentException  if metadata does not specify numClasses, and the
-   *                                   actual numClasses exceeds maxNumClasses
-   */
-  protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): Int = {
-    MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
-      case Some(n: Int) => n
-      case None =>
-        // Get number of classes from dataset itself.
-        val maxLabelRow: Array[Row] = dataset
-          .select(max(checkClassificationLabels($(labelCol), Some(maxNumClasses))))
-          .take(1)
-        if (maxLabelRow.isEmpty || maxLabelRow(0).get(0) == null) {
-          throw new SparkException("ML algorithm was given empty dataset.")
-        }
-        val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0)
-        require((maxDoubleLabel + 1).isValidInt, s"Classifier found max label value =" +
-          s" $maxDoubleLabel but requires integers in range [0, ... ${Int.MaxValue})")
-        val numClasses = maxDoubleLabel.toInt + 1
-        require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses from label values" +
-          s" in column $labelCol, but this exceeded the max numClasses ($maxNumClasses) allowed" +
-          s" to be inferred from values.  To avoid this error for labels with > $maxNumClasses" +
-          s" classes, specify numClasses explicitly in the metadata; this can be done by applying" +
-          s" StringIndexer to the label column.")
-        logInfo(this.getClass.getCanonicalName + s" inferred $numClasses classes for" +
-          s" labelCol=$labelCol since numClasses was not specified in the column metadata.")
-        numClasses
-    }
-  }
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index ec9e779709d..688d2d18f48 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -117,14 +117,13 @@ class DecisionTreeClassifier @Since("1.4.0") (
     instr.logPipelineStage(this)
     instr.logDataset(dataset)
     val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses = getNumClasses(dataset)
+    val numClasses = getNumClasses(dataset, $(labelCol))
 
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
         ".train() called with non-matching numClasses and thresholds.length." +
         s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
     }
-    validateNumClasses(numClasses)
 
     val instances = dataset.select(
       checkClassificationLabels($(labelCol), Some(numClasses)),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
index a2e6f0c49ee..51f312cf183 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
@@ -190,7 +190,7 @@ class FMClassifier @Since("3.0.0") (
       miniBatchFraction, initStd, maxIter, stepSize, tol, solver, thresholds)
     instr.logNumClasses(numClasses)
 
-    val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+    val numFeatures = getNumFeatures(dataset, $(featuresCol))
     instr.logNumFeatures(numFeatures)
 
     val handlePersistence = dataset.storageLevel == StorageLevel.NONE
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index a767bc01445..3910beda3d0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -22,14 +22,13 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.impl.GradientBoostedTrees
 import org.apache.spark.ml.util._
-import org.apache.spark.ml.util.DatasetUtils._
+import org.apache.spark.ml.util.DatasetUtils.extractInstances
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -169,21 +168,12 @@ class GBTClassifier @Since("1.4.0") (
 
   override protected def train(
       dataset: Dataset[_]): GBTClassificationModel = instrumented { instr =>
-
-    def extractInstances(df: Dataset[_]) = {
-      df.select(
-        checkClassificationLabels($(labelCol), Some(2)),
-        checkNonNegativeWeights(get(weightCol)),
-        checkNonNanVectors($(featuresCol))
-      ).rdd.map { case Row(l: Double, w: Double, v: Vector) => Instance(l, w, v) }
-    }
-
     val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
     val (trainDataset, validationDataset) = if (withValidation) {
-      (extractInstances(dataset.filter(not(col($(validationIndicatorCol))))),
-        extractInstances(dataset.filter(col($(validationIndicatorCol)))))
+      (extractInstances(this, dataset.filter(not(col($(validationIndicatorCol)))), Some(2)),
+        extractInstances(this, dataset.filter(col($(validationIndicatorCol))), Some(2)))
     } else {
-      (extractInstances(dataset), null)
+      (extractInstances(this, dataset, Some(2)), null)
     }
 
     val numClasses = 2
@@ -390,7 +380,7 @@ class GBTClassificationModel private[ml](
    */
   @Since("2.4.0")
   def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = {
-    val data = extractInstances(dataset)
+    val data = extractInstances(this, dataset, Some(2))
     GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, loss,
       OldAlgo.Classification)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 25f4e103ac7..048e5949e1c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -141,7 +141,7 @@ class RandomForestClassifier @Since("1.4.0") (
     instr.logDataset(dataset)
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses: Int = getNumClasses(dataset)
+    val numClasses = getNumClasses(dataset, $(labelCol))
 
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index bc2fcc03768..03315554b81 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -381,7 +381,7 @@ class GaussianMixture @Since("2.0.0") (
     val spark = dataset.sparkSession
     import spark.implicits._
 
-    val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+    val numFeatures = getNumFeatures(dataset, $(featuresCol))
     require(numFeatures < GaussianMixture.MAX_NUM_FEATURES, s"GaussianMixture cannot handle more " +
       s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" +
       s" matrix is quadratic in the number of features.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 93b66f3ab70..1a97eb29100 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -18,11 +18,10 @@
 package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, MetadataUtils, SchemaUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
 import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.functions._
@@ -129,8 +128,8 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
       dataset.select(
         col($(rawPredictionCol)),
         col($(labelCol)).cast(DoubleType),
-        if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
-        else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))).rdd.map {
+        DatasetUtils.checkNonNegativeWeights(get(weightCol))
+      ).rdd.map {
         case Row(rawPrediction: Vector, label: Double, weight: Double) =>
           (rawPrediction(1), label, weight)
         case Row(rawPrediction: Double, label: Double, weight: Double) =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
index fa2c25a5912..143e26f2f74 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
@@ -18,13 +18,11 @@
 package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol, HasWeightCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
 
 /**
  * Evaluator for clustering results.
@@ -130,18 +128,13 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str
       SchemaUtils.checkNumericType(schema, $(weightCol))
     }
 
-    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
-
-    val vectorCol = DatasetUtils.columnToVector(dataset, $(featuresCol))
-    val df = if (!isDefined(weightCol) || $(weightCol).isEmpty) {
-      dataset.select(col($(predictionCol)),
-        vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata),
-        lit(1.0).as(weightColName))
-    } else {
-      dataset.select(col($(predictionCol)),
-        vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata),
-        checkNonNegativeWeight(col(weightColName).cast(DoubleType)))
-    }
+    val df = dataset.select(
+      col($(predictionCol)),
+      DatasetUtils.columnToVector(dataset, $(featuresCol))
+        .as($(featuresCol), dataset.schema($(featuresCol)).metadata),
+      DatasetUtils.checkNonNegativeWeights(get(weightCol))
+        .as(if (!isDefined(weightCol)) "weightCol" else $(weightCol))
+    )
 
     val metrics = new ClusteringMetrics(df)
     metrics.setDistanceMeasure($(distanceMeasure))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala
index ffeb9492777..0106c872297 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Since
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors}
-import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.ml.util.DatasetUtils._
 import org.apache.spark.sql.{Column, DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DoubleType
@@ -293,7 +293,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
       predictionCol: String,
       featuresCol: String,
       weightCol: String): Map[Double, ClusterStats] = {
-    val numFeatures = MetadataUtils.getNumFeatures(df, featuresCol)
+    val numFeatures = getNumFeatures(df, featuresCol)
     val clustersStatsRDD = df.select(
       col(predictionCol).cast(DoubleType), col(featuresCol), col("squaredNorm"), col(weightCol))
       .rdd
@@ -509,7 +509,7 @@ private[evaluation] object CosineSilhouette extends Silhouette {
       featuresCol: String,
       predictionCol: String,
       weightCol: String): Map[Double, (Vector, Double)] = {
-    val numFeatures = MetadataUtils.getNumFeatures(df, featuresCol)
+    val numFeatures = getNumFeatures(df, featuresCol)
     val clustersStatsRDD = df.select(
       col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName), col(weightCol))
       .rdd
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index beeefde8c5f..023987d09ba 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -180,18 +179,13 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
     SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
     SchemaUtils.checkNumericType(schema, $(labelCol))
 
-    val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
-      checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
-    } else {
-      lit(1.0)
-    }
-
     if ($(metricName) == "logLoss") {
       // probabilityCol is only needed to compute logloss
       require(schema.fieldNames.contains($(probabilityCol)),
         "probabilityCol is needed to compute logloss")
     }
 
+    val w = DatasetUtils.checkNonNegativeWeights(get(weightCol))
     val rdd = if (schema.fieldNames.contains($(probabilityCol))) {
       val p = DatasetUtils.columnToVector(dataset, $(probabilityCol))
       dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType), w, p)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 902869cc681..9503e9ea11b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -18,10 +18,9 @@
 package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
-import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.evaluation.RegressionMetrics
 import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.functions._
@@ -120,12 +119,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
     SchemaUtils.checkNumericType(schema, $(labelCol))
 
     val predictionAndLabelsWithWeights = dataset
-      .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType),
-        if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
-        else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)))
-      .rdd
-      .map { case Row(prediction: Double, label: Double, weight: Double) =>
-        (prediction, label, weight) }
+      .select(
+        col($(predictionCol)).cast(DoubleType),
+        col($(labelCol)).cast(DoubleType),
+        DatasetUtils.checkNonNegativeWeights(get(weightCol))
+      ).rdd.map { case Row(prediction: Double, label: Double, weight: Double) =>
+        (prediction, label, weight)
+      }
     new RegressionMetrics(predictionAndLabelsWithWeights, $(throughOrigin))
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
index 7963fc88697..5254762d210 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -346,7 +346,7 @@ private[ml] abstract class LSH[T <: LSHModel[T]]
 
   override def fit(dataset: Dataset[_]): T = {
     transformSchema(dataset.schema, logging = true)
-    val inputDim = MetadataUtils.getNumFeatures(dataset, $(inputCol))
+    val inputDim = DatasetUtils.getNumFeatures(dataset, $(inputCol))
     val model = createRawLSHModel(inputDim).setParent(this)
     copyValues(model)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
index e8f325ec584..85352d6bcbd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
@@ -145,7 +145,7 @@ class RobustScaler @Since("3.0.0") (@Since("3.0.0") override val uid: String)
   override def fit(dataset: Dataset[_]): RobustScalerModel = {
     transformSchema(dataset.schema, logging = true)
 
-    val numFeatures = MetadataUtils.getNumFeatures(dataset, $(inputCol))
+    val numFeatures = DatasetUtils.getNumFeatures(dataset, $(inputCol))
     val vectors = dataset.select($(inputCol)).rdd.map {
       case Row(vec: Vector) =>
         require(vec.size == numFeatures,
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala
index e24593a01b6..1afab326dd7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala
@@ -206,7 +206,7 @@ private[ml] abstract class Selector[T <: SelectorModel[T]]
     val spark = dataset.sparkSession
     import spark.implicits._
 
-    val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+    val numFeatures = DatasetUtils.getNumFeatures(dataset, $(featuresCol))
     val resultDF = getSelectionTestResult(dataset.toDF)
 
     def getTopIndices(k: Int): Array[Int] = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
index 7412c42986f..3b43404072d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
@@ -164,7 +164,7 @@ final class UnivariateFeatureSelector @Since("3.1.1")(@Since("3.1.1") override v
   @Since("3.1.1")
   override def fit(dataset: Dataset[_]): UnivariateFeatureSelectorModel = {
     transformSchema(dataset.schema, logging = true)
-    val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+    val numFeatures = DatasetUtils.getNumFeatures(dataset, $(featuresCol))
 
     var threshold = Double.NaN
     if (isSet(selectionThreshold)) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 874b4213872..0e571ad508f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -140,7 +140,7 @@ class VectorIndexer @Since("1.4.0") (
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): VectorIndexerModel = {
     transformSchema(dataset.schema, logging = true)
-    val numFeatures = MetadataUtils.getNumFeatures(dataset, $(inputCol))
+    val numFeatures = DatasetUtils.getNumFeatures(dataset, $(inputCol))
     val vectorDataset = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
     val maxCats = $(maxCategories)
     val categoryStats: VectorIndexer.CategoryStats = vectorDataset.mapPartitions { iter =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/functions.scala b/mllib/src/main/scala/org/apache/spark/ml/functions.scala
index 43622a4f3ed..2bd7233f3ac 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/functions.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/functions.scala
@@ -85,10 +85,4 @@ object functions {
   def array_to_vector(v: Column): Column = {
     arrayToVectorUdf(v)
   }
-
-  private[ml] def checkNonNegativeWeight = udf {
-    value: Double =>
-      require(value >= 0, s"illegal weight value: $value. weight must be >= 0.0.")
-      value
-  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
index c0178ac6c76..e6e8c2f1fa4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
@@ -413,7 +413,7 @@ class FMRegressor @Since("3.0.0") (
     instr.logParams(this, factorSize, fitIntercept, fitLinear, regParam,
       miniBatchFraction, initStd, maxIter, stepSize, tol, solver)
 
-    val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+    val numFeatures = getNumFeatures(dataset, $(featuresCol))
     instr.logNumFeatures(numFeatures)
 
     val handlePersistence = dataset.storageLevel == StorageLevel.NONE
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 10a203e9ee6..0c58cc2449b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -22,13 +22,12 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.{BLAS, Vector}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.impl.GradientBoostedTrees
 import org.apache.spark.ml.util._
-import org.apache.spark.ml.util.DatasetUtils._
+import org.apache.spark.ml.util.DatasetUtils.extractInstances
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -166,21 +165,12 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   def setWeightCol(value: String): this.type = set(weightCol, value)
 
   override protected def train(dataset: Dataset[_]): GBTRegressionModel = instrumented { instr =>
-
-    def extractInstances(df: Dataset[_]) = {
-      df.select(
-        checkRegressionLabels($(labelCol)),
-        checkNonNegativeWeights(get(weightCol)),
-        checkNonNanVectors($(featuresCol))
-      ).rdd.map { case Row(l: Double, w: Double, v: Vector) => Instance(l, w, v) }
-    }
-
     val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
     val (trainDataset, validationDataset) = if (withValidation) {
-      (extractInstances(dataset.filter(not(col($(validationIndicatorCol))))),
-        extractInstances(dataset.filter(col($(validationIndicatorCol)))))
+      (extractInstances(this, dataset.filter(not(col($(validationIndicatorCol))))),
+        extractInstances(this, dataset.filter(col($(validationIndicatorCol)))))
     } else {
-      (extractInstances(dataset), null)
+      (extractInstances(this, dataset), null)
     }
 
     instr.logPipelineStage(this)
@@ -349,7 +339,7 @@ class GBTRegressionModel private[ml](
    */
   @Since("2.4.0")
   def evaluateEachIteration(dataset: Dataset[_], loss: String): Array[Double] = {
-    val data = extractInstances(dataset)
+    val data = extractInstances(this, dataset)
     GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights,
       convertToOldLossType(loss), OldAlgo.Regression)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 88581d03084..6d8507239eb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -384,7 +384,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
     instr.logParams(this, labelCol, featuresCol, weightCol, offsetCol, predictionCol,
       linkPredictionCol, family, solver, fitIntercept, link, maxIter, regParam, tol,
       aggregationDepth)
-    val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+    val numFeatures = getNumFeatures(dataset, $(featuresCol))
     instr.logNumFeatures(numFeatures)
 
     if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index a53ef8c79b4..46986249e0b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -338,7 +338,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     }
 
     // Extract the number of features before deciding optimization solver.
-    val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+    val numFeatures = getNumFeatures(dataset, $(featuresCol))
     instr.logNumFeatures(numFeatures)
 
     val instances = dataset.select(
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
index c32e901e5cd..130790ac909 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
@@ -17,7 +17,13 @@
 
 package org.apache.spark.ml.util
 
+import org.apache.spark.SparkException
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.classification.ClassifierParams
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
@@ -25,7 +31,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
 
-private[spark] object DatasetUtils {
+private[spark] object DatasetUtils extends Logging {
 
   private[ml] def checkNonNanValues(colName: String, displayed: String): Column = {
     val casted = col(colName).cast(DoubleType)
@@ -96,6 +102,26 @@ private[spark] object DatasetUtils {
     }
   }
 
+  private[ml] def extractInstances( // builds an RDD[Instance] from p's label/weight/features
+      p: PredictorParams,
+      df: Dataset[_],
+      numClasses: Option[Int] = None): RDD[Instance] = { // only read for ClassifierParams
+    val labelCol = p match { // label Column via the matching check*Labels helper
+      case c: ClassifierParams =>
+        checkClassificationLabels(c.getLabelCol, numClasses)
+      case _ => // TODO: there is no RegressorParams, maybe add it in the future?
+        checkRegressionLabels(p.getLabelCol)
+    }
+
+    val weightCol = p match { // constant 1.0 when p has no weightCol param
+      case w: HasWeightCol => checkNonNegativeWeights(w.get(w.weightCol))
+      case _ => lit(1.0)
+    }
+
+    df.select(labelCol, weightCol, checkNonNanVectors(p.getFeaturesCol)) // order = Row below
+      .rdd.map { case Row(l: Double, w: Double, v: Vector) => Instance(l, w, v) }
+  }
+
   /**
    * Cast a column in a Dataset to Vector type.
    *
@@ -138,4 +164,58 @@ private[spark] object DatasetUtils {
       case Row(point: Vector) => OldVectors.fromML(point)
     }
   }
+
+  /**
+   * Get the number of classes.  This looks in column metadata first, and if that is missing,
+   * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
+   * by finding the maximum label value, with labels passed through `checkClassificationLabels`.
+   * A numClasses read from the metadata is returned without inspecting the label values.
+   *
+   * @param dataset  Dataset which contains the label column named by `labelCol`
+   * @param labelCol  Name of the label column
+   * @param maxNumClasses  Maximum number of classes allowed when inferred from data.  If numClasses
+   *                       is specified in the metadata, then maxNumClasses is ignored.
+   * @return  number of classes
+   * @throws SparkException  if numClasses must be inferred and no max label is found (e.g. no rows)
+   * @throws IllegalArgumentException  if metadata does not specify numClasses, and the
+   *                                   inferred numClasses is invalid or exceeds maxNumClasses
+   */
+  private[ml] def getNumClasses(
+      dataset: Dataset[_],
+      labelCol: String,
+      maxNumClasses: Int = 100): Int = {
+    MetadataUtils.getNumClasses(dataset.schema(labelCol)) match {
+      case Some(n: Int) => n
+      case None =>
+        // Get number of classes from dataset itself.
+        val maxLabelRow: Array[Row] = dataset
+          .select(max(checkClassificationLabels(labelCol, Some(maxNumClasses))))
+          .take(1)
+        if (maxLabelRow.isEmpty || maxLabelRow(0).get(0) == null) { // max() of no rows is null
+          throw new SparkException("ML algorithm was given empty dataset.")
+        }
+        val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0)
+        require((maxDoubleLabel + 1).isValidInt, s"Classifier found max label value =" +
+          s" $maxDoubleLabel but requires integers in range [0, ... ${Int.MaxValue})")
+        val numClasses = maxDoubleLabel.toInt + 1
+        require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses from label values" +
+          s" in column $labelCol, but this exceeded the max numClasses ($maxNumClasses) allowed" +
+          s" to be inferred from values.  To avoid this error for labels with > $maxNumClasses" +
+          s" classes, specify numClasses explicitly in the metadata; this can be done by applying" +
+          s" StringIndexer to the label column.")
+        logInfo(this.getClass.getCanonicalName + s" inferred $numClasses classes for" +
+          s" labelCol=$labelCol since numClasses was not specified in the column metadata.")
+        numClasses // NOTE(review): logInfo above names DatasetUtils$, not the caller
+    }
+  }
+
+  /**
+   * Obtain the number of features in vector column `vectorCol`: taken from the column metadata
+   * when present, otherwise from the first row's vector (presumably fails on an empty dataset).
+   */
+  private[ml] def getNumFeatures(dataset: Dataset[_], vectorCol: String): Int = {
+    MetadataUtils.getNumFeatures(dataset.schema(vectorCol)).getOrElse {
+      dataset.select(columnToVector(dataset, vectorCol)).head.getAs[Vector](0).size
+    }
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
index 6db0408e8d2..631261af249 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
@@ -20,8 +20,7 @@ package org.apache.spark.ml.util
 import scala.collection.immutable.HashMap
 
 import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.Dataset
+import org.apache.spark.ml.linalg.VectorUDT
 import org.apache.spark.sql.types.StructField
 
 
@@ -42,17 +41,6 @@ private[spark] object MetadataUtils {
     }
   }
 
-  /**
-   * Obtain the number of features in a vector column.
-   * If no metadata is available, extract it from the dataset.
-   */
-  def getNumFeatures(dataset: Dataset[_], vectorCol: String): Int = {
-    getNumFeatures(dataset.schema(vectorCol)).getOrElse {
-      dataset.select(DatasetUtils.columnToVector(dataset, vectorCol))
-        .head.getAs[Vector](0).size
-    }
-  }
-
   /**
    * Examine a schema to identify the number of features in a vector column.
    * Returns None if the number of features is not specified.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
index 1aea4b47cd8..57cd99ecced 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
@@ -22,9 +22,8 @@ import org.apache.spark.ml.classification.ClassifierSuite.MockClassifier
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
 
 class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -35,41 +34,6 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
     labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }.toDF()
   }
 
-  test("extractLabeledPoints") {
-    val c = new MockClassifier
-    // Valid dataset
-    val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0))
-    c.extractLabeledPoints(df0, 6).count()
-    // Invalid datasets
-    val df1 = getTestData(Seq(0.0, -2.0, 1.0, 5.0))
-    withClue("Classifier should fail if label is negative") {
-      val e: SparkException = intercept[SparkException] {
-        c.extractLabeledPoints(df1, 6).count()
-      }
-      assert(e.getMessage.contains("given dataset with invalid label"))
-    }
-    val df2 = getTestData(Seq(0.0, 2.1, 1.0, 5.0))
-    withClue("Classifier should fail if label is not an integer") {
-      val e: SparkException = intercept[SparkException] {
-        c.extractLabeledPoints(df2, 6).count()
-      }
-      assert(e.getMessage.contains("given dataset with invalid label"))
-    }
-    // extractLabeledPoints with numClasses specified
-    withClue("Classifier should fail if label is >= numClasses") {
-      val e: SparkException = intercept[SparkException] {
-        c.extractLabeledPoints(df0, numClasses = 5).count()
-      }
-      assert(e.getMessage.contains("given dataset with invalid label"))
-    }
-    withClue("Classifier.extractLabeledPoints should fail if numClasses <= 0") {
-      val e: IllegalArgumentException = intercept[IllegalArgumentException] {
-        c.extractLabeledPoints(df0, numClasses = 0).count()
-      }
-      assert(e.getMessage.contains("but requires numClasses > 0"))
-    }
-  }
-
   test("getNumClasses") {
     val c = new MockClassifier
     // Valid dataset
@@ -122,10 +86,8 @@ object ClassifierSuite {
     override def train(dataset: Dataset[_]): MockClassificationModel =
       throw new UnsupportedOperationException()
 
-    // Make methods public
-    override def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] =
-      super.extractLabeledPoints(dataset, numClasses)
-    def getNumClasses(dataset: Dataset[_]): Int = super.getNumClasses(dataset)
+    def getNumClasses(dataset: Dataset[_]): Int = // test-facing wrapper over DatasetUtils
+      DatasetUtils.getNumClasses(dataset, $(labelCol))
   }
 
   class MockClassificationModel(override val uid: String)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 927384d4f1e..01fc5d65c03 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -40,7 +40,21 @@ object MimaExcludes {
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.recommendation.ALSModel.checkedCast"),
 
     // [SPARK-39110] Show metrics properties in HistoryServer environment tab
-    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationEnvironmentInfo.this")
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationEnvironmentInfo.this"),
+
+    // [SPARK-38775][ML] Cleanup validation functions
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.PredictionModel.extractInstances"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.Predictor.extractInstances"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.Predictor.extractLabeledPoints"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.ClassificationModel.extractInstances"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.extractInstances"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.extractLabeledPoints"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.validateNumClasses"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.validateLabel"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.getNumClasses"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.getNumClasses$default$2"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRest.extractInstances"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.extractInstances")
   )
 
   // Exclude rules for 3.3.x from 3.2.0


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org