Posted to commits@spark.apache.org by do...@apache.org on 2022/06/19 04:52:01 UTC
[spark] branch master updated: [SPARK-38775][ML] cleanup validation functions
This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 27ed89b7be5 [SPARK-38775][ML] cleanup validation functions
27ed89b7be5 is described below
commit 27ed89b7be5ebb91e4a0b106b1669a7867a6012d
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Sat Jun 18 21:51:50 2022 -0700
[SPARK-38775][ML] cleanup validation functions
### What changes were proposed in this pull request?
1. Remove unused `extractInstances` and `extractLabeledPoints` in `Predictor`;
2. Remove unused `checkNonNegativeWeight` in `functions`;
3. Move `getNumClasses` from `Classifier` to `DatasetUtils`;
4. Move `getNumFeatures` from `MetadataUtils` to `DatasetUtils`;
### Why are the changes needed?
To unify these validation and extraction methods in a single place (`DatasetUtils`).
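As a sketch of the consolidated call pattern after this change (signatures as added in the diff below; the surrounding `train()` body is hypothetical):

    // Hypothetical excerpt from a Classifier.train(dataset) implementation.
    // DatasetUtils members are private[ml], so this only compiles inside org.apache.spark.ml.
    import org.apache.spark.ml.util.DatasetUtils.{extractInstances, getNumClasses, getNumFeatures}

    val numClasses = getNumClasses(dataset, $(labelCol))      // was Classifier.getNumClasses(dataset)
    val numFeatures = getNumFeatures(dataset, $(featuresCol)) // was MetadataUtils.getNumFeatures(dataset, ...)
    // Unified label/weight/feature extraction; Some(numClasses) turns on
    // classification-label validation.
    val instances = extractInstances(this, dataset, Some(numClasses))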
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing test suites.
Closes #36049 from zhengruifeng/validate_cleanup.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
.../spark/examples/ml/DeveloperApiExample.scala | 7 +-
.../main/scala/org/apache/spark/ml/Predictor.scala | 51 +---------
.../spark/ml/classification/Classifier.scala | 106 +--------------------
.../ml/classification/DecisionTreeClassifier.scala | 3 +-
.../spark/ml/classification/FMClassifier.scala | 2 +-
.../spark/ml/classification/GBTClassifier.scala | 20 +---
.../ml/classification/RandomForestClassifier.scala | 2 +-
.../spark/ml/clustering/GaussianMixture.scala | 2 +-
.../evaluation/BinaryClassificationEvaluator.scala | 7 +-
.../spark/ml/evaluation/ClusteringEvaluator.scala | 21 ++--
.../spark/ml/evaluation/ClusteringMetrics.scala | 6 +-
.../MulticlassClassificationEvaluator.scala | 8 +-
.../spark/ml/evaluation/RegressionEvaluator.scala | 16 ++--
.../scala/org/apache/spark/ml/feature/LSH.scala | 2 +-
.../org/apache/spark/ml/feature/RobustScaler.scala | 2 +-
.../org/apache/spark/ml/feature/Selector.scala | 2 +-
.../ml/feature/UnivariateFeatureSelector.scala | 2 +-
.../apache/spark/ml/feature/VectorIndexer.scala | 2 +-
.../main/scala/org/apache/spark/ml/functions.scala | 6 --
.../apache/spark/ml/regression/FMRegressor.scala | 2 +-
.../apache/spark/ml/regression/GBTRegressor.scala | 20 +---
.../regression/GeneralizedLinearRegression.scala | 2 +-
.../spark/ml/regression/LinearRegression.scala | 2 +-
.../org/apache/spark/ml/util/DatasetUtils.scala | 82 +++++++++++++++-
.../org/apache/spark/ml/util/MetadataUtils.scala | 14 +--
.../spark/ml/classification/ClassifierSuite.scala | 44 +--------
project/MimaExcludes.scala | 16 +++-
27 files changed, 152 insertions(+), 297 deletions(-)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index 487cb27b93f..bfee3301f8e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -24,6 +24,7 @@ import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{Dataset, Row, SparkSession}
+import org.apache.spark.sql.functions.col
/**
* A simple example demonstrating how to write your own learning algorithm using Estimator,
@@ -120,8 +121,10 @@ private class MyLogisticRegression(override val uid: String)
// This method is used by fit()
override protected def train(dataset: Dataset[_]): MyLogisticRegressionModel = {
- // Extract columns from data using helper method.
- val oldDataset = extractLabeledPoints(dataset)
+ // Extract columns from data.
+ val oldDataset = dataset.select(col($(labelCol)).cast("double"), col($(featuresCol)))
+ .rdd
+ .map { case Row(l: Double, f: Vector) => LabeledPoint(l, f) }
// Do learning to estimate the coefficients vector.
val numFeatures = oldDataset.take(1)(0).features.size
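With `extractLabeledPoints` removed from `Predictor`, external `Predictor` subclasses that relied on it can inline the extraction as above or keep a small private helper. A minimal sketch, assuming the example file's existing imports plus `org.apache.spark.rdd.RDD` (the helper name is hypothetical):

    // Hypothetical stand-in for the removed Predictor.extractLabeledPoints.
    private def toLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
      dataset.select(col($(labelCol)).cast("double"), col($(featuresCol)))
        .rdd
        .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) }
    }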
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index e0b128e3698..9c6eb880c80 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -18,14 +18,11 @@
package org.apache.spark.ml
import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.{Instance, LabeledPoint}
-import org.apache.spark.ml.functions.checkNonNegativeWeight
-import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
@@ -63,40 +60,6 @@ private[ml] trait PredictorParams extends Params
}
SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
}
-
- /**
- * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
- * and put it in an RDD with strong types.
- */
- protected def extractInstances(dataset: Dataset[_]): RDD[Instance] = {
- val w = this match {
- case p: HasWeightCol =>
- if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
- checkNonNegativeWeight((col($(p.weightCol)).cast(DoubleType)))
- } else {
- lit(1.0)
- }
- }
-
- dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
- case Row(label: Double, weight: Double, features: Vector) =>
- Instance(label, weight, features)
- }
- }
-
- /**
- * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
- * and put it in an RDD with strong types.
- * Validate the output instances with the given function.
- */
- protected def extractInstances(
- dataset: Dataset[_],
- validateInstance: Instance => Unit): RDD[Instance] = {
- extractInstances(dataset).map { instance =>
- validateInstance(instance)
- instance
- }
- }
}
/**
@@ -176,16 +139,6 @@ abstract class Predictor[
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = true, featuresDataType)
}
-
- /**
- * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
- * and put it in an RDD with strong types.
- */
- protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
- dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
- case Row(label: Double, features: Vector) => LabeledPoint(label, features)
- }
- }
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 09324e2087d..2d7719a29ca 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -17,17 +17,13 @@
package org.apache.spark.ml.classification
-import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
-import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util._
-import org.apache.spark.ml.util.DatasetUtils._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, StructType}
@@ -44,23 +40,6 @@ private[spark] trait ClassifierParams
val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
SchemaUtils.appendColumn(parentSchema, $(rawPredictionCol), new VectorUDT)
}
-
- /**
- * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
- * and put it in an RDD with strong types.
- * Validates the label on the classifier is a valid integer in the range [0, numClasses).
- */
- protected def extractInstances(
- dataset: Dataset[_],
- numClasses: Int): RDD[Instance] = {
- val validateInstance = (instance: Instance) => {
- val label = instance.label
- require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
- s" dataset with invalid label $label. Labels must be integers in range" +
- s" [0, $numClasses).")
- }
- extractInstances(dataset, validateInstance)
- }
}
/**
@@ -81,89 +60,6 @@ abstract class Classifier[
def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]
// TODO: defaultEvaluator (follow-up PR)
-
- /**
- * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
- * and put it in an RDD with strong types.
- *
- * @param dataset DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]])
- * and features (`Vector`).
- * @param numClasses Number of classes label can take. Labels must be integers in the range
- * [0, numClasses).
- * @note Throws `SparkException` if any label is a non-integer or is negative
- */
- protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
- validateNumClasses(numClasses)
- dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
- case Row(label: Double, features: Vector) =>
- validateLabel(label, numClasses)
- LabeledPoint(label, features)
- }
- }
-
- /**
- * Validates that number of classes is greater than zero.
- *
- * @param numClasses Number of classes label can take.
- */
- protected def validateNumClasses(numClasses: Int): Unit = {
- require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
- s" $numClasses, but requires numClasses > 0.")
- }
-
- /**
- * Validates the label on the classifier is a valid integer in the range [0, numClasses).
- *
- * @param label The label to validate.
- * @param numClasses Number of classes label can take. Labels must be integers in the range
- * [0, numClasses).
- */
- protected def validateLabel(label: Double, numClasses: Int): Unit = {
- require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
- s" dataset with invalid label $label. Labels must be integers in range" +
- s" [0, $numClasses).")
- }
-
- /**
- * Get the number of classes. This looks in column metadata first, and if that is missing,
- * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
- * by finding the maximum label value.
- *
- * Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere,
- * such as in `extractLabeledPoints()`.
- *
- * @param dataset Dataset which contains a column [[labelCol]]
- * @param maxNumClasses Maximum number of classes allowed when inferred from data. If numClasses
- * is specified in the metadata, then maxNumClasses is ignored.
- * @return number of classes
- * @throws IllegalArgumentException if metadata does not specify numClasses, and the
- * actual numClasses exceeds maxNumClasses
- */
- protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): Int = {
- MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
- case Some(n: Int) => n
- case None =>
- // Get number of classes from dataset itself.
- val maxLabelRow: Array[Row] = dataset
- .select(max(checkClassificationLabels($(labelCol), Some(maxNumClasses))))
- .take(1)
- if (maxLabelRow.isEmpty || maxLabelRow(0).get(0) == null) {
- throw new SparkException("ML algorithm was given empty dataset.")
- }
- val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0)
- require((maxDoubleLabel + 1).isValidInt, s"Classifier found max label value =" +
- s" $maxDoubleLabel but requires integers in range [0, ... ${Int.MaxValue})")
- val numClasses = maxDoubleLabel.toInt + 1
- require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses from label values" +
- s" in column $labelCol, but this exceeded the max numClasses ($maxNumClasses) allowed" +
- s" to be inferred from values. To avoid this error for labels with > $maxNumClasses" +
- s" classes, specify numClasses explicitly in the metadata; this can be done by applying" +
- s" StringIndexer to the label column.")
- logInfo(this.getClass.getCanonicalName + s" inferred $numClasses classes for" +
- s" labelCol=$labelCol since numClasses was not specified in the column metadata.")
- numClasses
- }
- }
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index ec9e779709d..688d2d18f48 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -117,14 +117,13 @@ class DecisionTreeClassifier @Since("1.4.0") (
instr.logPipelineStage(this)
instr.logDataset(dataset)
val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val numClasses = getNumClasses(dataset)
+ val numClasses = getNumClasses(dataset, $(labelCol))
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
- validateNumClasses(numClasses)
val instances = dataset.select(
checkClassificationLabels($(labelCol), Some(numClasses)),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
index a2e6f0c49ee..51f312cf183 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
@@ -190,7 +190,7 @@ class FMClassifier @Since("3.0.0") (
miniBatchFraction, initStd, maxIter, stepSize, tol, solver, thresholds)
instr.logNumClasses(numClasses)
- val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+ val numFeatures = getNumFeatures(dataset, $(featuresCol))
instr.logNumFeatures(numFeatures)
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index a767bc01445..3910beda3d0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -22,14 +22,13 @@ import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
import org.apache.spark.ml.util._
-import org.apache.spark.ml.util.DatasetUtils._
+import org.apache.spark.ml.util.DatasetUtils.extractInstances
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -169,21 +168,12 @@ class GBTClassifier @Since("1.4.0") (
override protected def train(
dataset: Dataset[_]): GBTClassificationModel = instrumented { instr =>
-
- def extractInstances(df: Dataset[_]) = {
- df.select(
- checkClassificationLabels($(labelCol), Some(2)),
- checkNonNegativeWeights(get(weightCol)),
- checkNonNanVectors($(featuresCol))
- ).rdd.map { case Row(l: Double, w: Double, v: Vector) => Instance(l, w, v) }
- }
-
val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
val (trainDataset, validationDataset) = if (withValidation) {
- (extractInstances(dataset.filter(not(col($(validationIndicatorCol))))),
- extractInstances(dataset.filter(col($(validationIndicatorCol)))))
+ (extractInstances(this, dataset.filter(not(col($(validationIndicatorCol)))), Some(2)),
+ extractInstances(this, dataset.filter(col($(validationIndicatorCol))), Some(2)))
} else {
- (extractInstances(dataset), null)
+ (extractInstances(this, dataset, Some(2)), null)
}
val numClasses = 2
@@ -390,7 +380,7 @@ class GBTClassificationModel private[ml](
*/
@Since("2.4.0")
def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = {
- val data = extractInstances(dataset)
+ val data = extractInstances(this, dataset, Some(2))
GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, loss,
OldAlgo.Classification)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 25f4e103ac7..048e5949e1c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -141,7 +141,7 @@ class RandomForestClassifier @Since("1.4.0") (
instr.logDataset(dataset)
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val numClasses: Int = getNumClasses(dataset)
+ val numClasses = getNumClasses(dataset, $(labelCol))
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index bc2fcc03768..03315554b81 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -381,7 +381,7 @@ class GaussianMixture @Since("2.0.0") (
val spark = dataset.sparkSession
import spark.implicits._
- val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+ val numFeatures = getNumFeatures(dataset, $(featuresCol))
require(numFeatures < GaussianMixture.MAX_NUM_FEATURES, s"GaussianMixture cannot handle more " +
s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" +
s" matrix is quadratic in the number of features.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 93b66f3ab70..1a97eb29100 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -18,11 +18,10 @@
package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.Since
-import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, MetadataUtils, SchemaUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions._
@@ -129,8 +128,8 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
dataset.select(
col($(rawPredictionCol)),
col($(labelCol)).cast(DoubleType),
- if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
- else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))).rdd.map {
+ DatasetUtils.checkNonNegativeWeights(get(weightCol))
+ ).rdd.map {
case Row(rawPrediction: Vector, label: Double, weight: Double) =>
(rawPrediction(1), label, weight)
case Row(rawPrediction: Double, label: Double, weight: Double) =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
index fa2c25a5912..143e26f2f74 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
@@ -18,13 +18,11 @@
package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.Since
-import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol, HasWeightCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
/**
* Evaluator for clustering results.
@@ -130,18 +128,13 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str
SchemaUtils.checkNumericType(schema, $(weightCol))
}
- val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
-
- val vectorCol = DatasetUtils.columnToVector(dataset, $(featuresCol))
- val df = if (!isDefined(weightCol) || $(weightCol).isEmpty) {
- dataset.select(col($(predictionCol)),
- vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata),
- lit(1.0).as(weightColName))
- } else {
- dataset.select(col($(predictionCol)),
- vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata),
- checkNonNegativeWeight(col(weightColName).cast(DoubleType)))
- }
+ val df = dataset.select(
+ col($(predictionCol)),
+ DatasetUtils.columnToVector(dataset, $(featuresCol))
+ .as($(featuresCol), dataset.schema($(featuresCol)).metadata),
+ DatasetUtils.checkNonNegativeWeights(get(weightCol))
+ .as(if (!isDefined(weightCol)) "weightCol" else $(weightCol))
+ )
val metrics = new ClusteringMetrics(df)
metrics.setDistanceMeasure($(distanceMeasure))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala
index ffeb9492777..0106c872297 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors}
-import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.ml.util.DatasetUtils._
import org.apache.spark.sql.{Column, DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
@@ -293,7 +293,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
predictionCol: String,
featuresCol: String,
weightCol: String): Map[Double, ClusterStats] = {
- val numFeatures = MetadataUtils.getNumFeatures(df, featuresCol)
+ val numFeatures = getNumFeatures(df, featuresCol)
val clustersStatsRDD = df.select(
col(predictionCol).cast(DoubleType), col(featuresCol), col("squaredNorm"), col(weightCol))
.rdd
@@ -509,7 +509,7 @@ private[evaluation] object CosineSilhouette extends Silhouette {
featuresCol: String,
predictionCol: String,
weightCol: String): Map[Double, (Vector, Double)] = {
- val numFeatures = MetadataUtils.getNumFeatures(df, featuresCol)
+ val numFeatures = getNumFeatures(df, featuresCol)
val clustersStatsRDD = df.select(
col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName), col(weightCol))
.rdd
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index beeefde8c5f..023987d09ba 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -18,7 +18,6 @@
package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.Since
-import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -180,18 +179,13 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
SchemaUtils.checkNumericType(schema, $(labelCol))
- val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
- checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
- } else {
- lit(1.0)
- }
-
if ($(metricName) == "logLoss") {
// probabilityCol is only needed to compute logloss
require(schema.fieldNames.contains($(probabilityCol)),
"probabilityCol is needed to compute logloss")
}
+ val w = DatasetUtils.checkNonNegativeWeights(get(weightCol))
val rdd = if (schema.fieldNames.contains($(probabilityCol))) {
val p = DatasetUtils.columnToVector(dataset, $(probabilityCol))
dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType), w, p)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 902869cc681..9503e9ea11b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -18,10 +18,9 @@
package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.Since
-import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
-import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions._
@@ -120,12 +119,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
SchemaUtils.checkNumericType(schema, $(labelCol))
val predictionAndLabelsWithWeights = dataset
- .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType),
- if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
- else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)))
- .rdd
- .map { case Row(prediction: Double, label: Double, weight: Double) =>
- (prediction, label, weight) }
+ .select(
+ col($(predictionCol)).cast(DoubleType),
+ col($(labelCol)).cast(DoubleType),
+ DatasetUtils.checkNonNegativeWeights(get(weightCol))
+ ).rdd.map { case Row(prediction: Double, label: Double, weight: Double) =>
+ (prediction, label, weight)
+ }
new RegressionMetrics(predictionAndLabelsWithWeights, $(throughOrigin))
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
index 7963fc88697..5254762d210 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -346,7 +346,7 @@ private[ml] abstract class LSH[T <: LSHModel[T]]
override def fit(dataset: Dataset[_]): T = {
transformSchema(dataset.schema, logging = true)
- val inputDim = MetadataUtils.getNumFeatures(dataset, $(inputCol))
+ val inputDim = DatasetUtils.getNumFeatures(dataset, $(inputCol))
val model = createRawLSHModel(inputDim).setParent(this)
copyValues(model)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
index e8f325ec584..85352d6bcbd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
@@ -145,7 +145,7 @@ class RobustScaler @Since("3.0.0") (@Since("3.0.0") override val uid: String)
override def fit(dataset: Dataset[_]): RobustScalerModel = {
transformSchema(dataset.schema, logging = true)
- val numFeatures = MetadataUtils.getNumFeatures(dataset, $(inputCol))
+ val numFeatures = DatasetUtils.getNumFeatures(dataset, $(inputCol))
val vectors = dataset.select($(inputCol)).rdd.map {
case Row(vec: Vector) =>
require(vec.size == numFeatures,
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala
index e24593a01b6..1afab326dd7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala
@@ -206,7 +206,7 @@ private[ml] abstract class Selector[T <: SelectorModel[T]]
val spark = dataset.sparkSession
import spark.implicits._
- val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+ val numFeatures = DatasetUtils.getNumFeatures(dataset, $(featuresCol))
val resultDF = getSelectionTestResult(dataset.toDF)
def getTopIndices(k: Int): Array[Int] = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
index 7412c42986f..3b43404072d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
@@ -164,7 +164,7 @@ final class UnivariateFeatureSelector @Since("3.1.1")(@Since("3.1.1") override v
@Since("3.1.1")
override def fit(dataset: Dataset[_]): UnivariateFeatureSelectorModel = {
transformSchema(dataset.schema, logging = true)
- val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+ val numFeatures = DatasetUtils.getNumFeatures(dataset, $(featuresCol))
var threshold = Double.NaN
if (isSet(selectionThreshold)) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 874b4213872..0e571ad508f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -140,7 +140,7 @@ class VectorIndexer @Since("1.4.0") (
@Since("2.0.0")
override def fit(dataset: Dataset[_]): VectorIndexerModel = {
transformSchema(dataset.schema, logging = true)
- val numFeatures = MetadataUtils.getNumFeatures(dataset, $(inputCol))
+ val numFeatures = DatasetUtils.getNumFeatures(dataset, $(inputCol))
val vectorDataset = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
val maxCats = $(maxCategories)
val categoryStats: VectorIndexer.CategoryStats = vectorDataset.mapPartitions { iter =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/functions.scala b/mllib/src/main/scala/org/apache/spark/ml/functions.scala
index 43622a4f3ed..2bd7233f3ac 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/functions.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/functions.scala
@@ -85,10 +85,4 @@ object functions {
def array_to_vector(v: Column): Column = {
arrayToVectorUdf(v)
}
-
- private[ml] def checkNonNegativeWeight = udf {
- value: Double =>
- require(value >= 0, s"illegal weight value: $value. weight must be >= 0.0.")
- value
- }
}
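The removed `checkNonNegativeWeight` UDF is superseded at every call site by `DatasetUtils.checkNonNegativeWeights(get(weightCol))`, which, as the evaluator hunks above show, also folds in the `lit(1.0)` default when no weight column is set. The resulting call-site pattern, sketched with evaluator boilerplate elided:

    // One call now covers both the missing-weight-column default (lit(1.0))
    // and the weight >= 0 validation.
    val w = DatasetUtils.checkNonNegativeWeights(get(weightCol))
    dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType), w)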
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
index c0178ac6c76..e6e8c2f1fa4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
@@ -413,7 +413,7 @@ class FMRegressor @Since("3.0.0") (
instr.logParams(this, factorSize, fitIntercept, fitLinear, regParam,
miniBatchFraction, initStd, maxIter, stepSize, tol, solver)
- val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+ val numFeatures = getNumFeatures(dataset, $(featuresCol))
instr.logNumFeatures(numFeatures)
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 10a203e9ee6..0c58cc2449b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -22,13 +22,12 @@ import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{BLAS, Vector}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
import org.apache.spark.ml.util._
-import org.apache.spark.ml.util.DatasetUtils._
+import org.apache.spark.ml.util.DatasetUtils.extractInstances
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -166,21 +165,12 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
def setWeightCol(value: String): this.type = set(weightCol, value)
override protected def train(dataset: Dataset[_]): GBTRegressionModel = instrumented { instr =>
-
- def extractInstances(df: Dataset[_]) = {
- df.select(
- checkRegressionLabels($(labelCol)),
- checkNonNegativeWeights(get(weightCol)),
- checkNonNanVectors($(featuresCol))
- ).rdd.map { case Row(l: Double, w: Double, v: Vector) => Instance(l, w, v) }
- }
-
val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
val (trainDataset, validationDataset) = if (withValidation) {
- (extractInstances(dataset.filter(not(col($(validationIndicatorCol))))),
- extractInstances(dataset.filter(col($(validationIndicatorCol)))))
+ (extractInstances(this, dataset.filter(not(col($(validationIndicatorCol))))),
+ extractInstances(this, dataset.filter(col($(validationIndicatorCol)))))
} else {
- (extractInstances(dataset), null)
+ (extractInstances(this, dataset), null)
}
instr.logPipelineStage(this)
@@ -349,7 +339,7 @@ class GBTRegressionModel private[ml](
*/
@Since("2.4.0")
def evaluateEachIteration(dataset: Dataset[_], loss: String): Array[Double] = {
- val data = extractInstances(dataset)
+ val data = extractInstances(this, dataset)
GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights,
convertToOldLossType(loss), OldAlgo.Regression)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 88581d03084..6d8507239eb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -384,7 +384,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
instr.logParams(this, labelCol, featuresCol, weightCol, offsetCol, predictionCol,
linkPredictionCol, family, solver, fitIntercept, link, maxIter, regParam, tol,
aggregationDepth)
- val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+ val numFeatures = getNumFeatures(dataset, $(featuresCol))
instr.logNumFeatures(numFeatures)
if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index a53ef8c79b4..46986249e0b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -338,7 +338,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
}
// Extract the number of features before deciding optimization solver.
- val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+ val numFeatures = getNumFeatures(dataset, $(featuresCol))
instr.logNumFeatures(numFeatures)
val instances = dataset.select(
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
index c32e901e5cd..130790ac909 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
@@ -17,7 +17,13 @@
package org.apache.spark.ml.util
+import org.apache.spark.SparkException
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.classification.ClassifierParams
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
@@ -25,7 +31,7 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
-private[spark] object DatasetUtils {
+private[spark] object DatasetUtils extends Logging {
private[ml] def checkNonNanValues(colName: String, displayed: String): Column = {
val casted = col(colName).cast(DoubleType)
@@ -96,6 +102,26 @@ private[spark] object DatasetUtils {
}
}
+ private[ml] def extractInstances(
+ p: PredictorParams,
+ df: Dataset[_],
+ numClasses: Option[Int] = None): RDD[Instance] = {
+ val labelCol = p match {
+ case c: ClassifierParams =>
+ checkClassificationLabels(c.getLabelCol, numClasses)
+ case _ => // TODO: there is no RegressorParams, maybe add it in the future?
+ checkRegressionLabels(p.getLabelCol)
+ }
+
+ val weightCol = p match {
+ case w: HasWeightCol => checkNonNegativeWeights(w.get(w.weightCol))
+ case _ => lit(1.0)
+ }
+
+ df.select(labelCol, weightCol, checkNonNanVectors(p.getFeaturesCol))
+ .rdd.map { case Row(l: Double, w: Double, v: Vector) => Instance(l, w, v) }
+ }
+
/**
* Cast a column in a Dataset to Vector type.
*
@@ -138,4 +164,58 @@ private[spark] object DatasetUtils {
case Row(point: Vector) => OldVectors.fromML(point)
}
}
+
+ /**
+ * Get the number of classes. This looks in column metadata first, and if that is missing,
+ * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
+ * by finding the maximum label value.
+ *
+ * Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere,
+ * such as in `extractLabeledPoints()`.
+ *
+ * @param dataset Dataset which contains a column [[labelCol]]
+ * @param maxNumClasses Maximum number of classes allowed when inferred from data. If numClasses
+ * is specified in the metadata, then maxNumClasses is ignored.
+ * @return number of classes
+ * @throws IllegalArgumentException if metadata does not specify numClasses, and the
+ * actual numClasses exceeds maxNumClasses
+ */
+ private[ml] def getNumClasses(
+ dataset: Dataset[_],
+ labelCol: String,
+ maxNumClasses: Int = 100): Int = {
+ MetadataUtils.getNumClasses(dataset.schema(labelCol)) match {
+ case Some(n: Int) => n
+ case None =>
+ // Get number of classes from dataset itself.
+ val maxLabelRow: Array[Row] = dataset
+ .select(max(checkClassificationLabels(labelCol, Some(maxNumClasses))))
+ .take(1)
+ if (maxLabelRow.isEmpty || maxLabelRow(0).get(0) == null) {
+ throw new SparkException("ML algorithm was given empty dataset.")
+ }
+ val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0)
+ require((maxDoubleLabel + 1).isValidInt, s"Classifier found max label value =" +
+ s" $maxDoubleLabel but requires integers in range [0, ... ${Int.MaxValue})")
+ val numClasses = maxDoubleLabel.toInt + 1
+ require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses from label values" +
+ s" in column $labelCol, but this exceeded the max numClasses ($maxNumClasses) allowed" +
+ s" to be inferred from values. To avoid this error for labels with > $maxNumClasses" +
+ s" classes, specify numClasses explicitly in the metadata; this can be done by applying" +
+ s" StringIndexer to the label column.")
+ logInfo(this.getClass.getCanonicalName + s" inferred $numClasses classes for" +
+ s" labelCol=$labelCol since numClasses was not specified in the column metadata.")
+ numClasses
+ }
+ }
+
+ /**
+ * Obtain the number of features in a vector column.
+ * If no metadata is available, extract it from the dataset.
+ */
+ private[ml] def getNumFeatures(dataset: Dataset[_], vectorCol: String): Int = {
+ MetadataUtils.getNumFeatures(dataset.schema(vectorCol)).getOrElse {
+ dataset.select(columnToVector(dataset, vectorCol)).head.getAs[Vector](0).size
+ }
+ }
}
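Per the scaladoc carried over above, `getNumClasses` prefers the label column's metadata and scans the data only when it is absent. A sketch of the inference path (illustrative data; `DatasetUtils` is `private[ml]`, so this runs only from ML-internal code or tests, and assumes an active `spark` session):

    import org.apache.spark.ml.linalg.Vectors
    import spark.implicits._

    // No numClasses in the "label" column metadata, so the max label is
    // scanned from the data and numClasses = maxLabel + 1.
    val df = Seq((0.0, Vectors.dense(0.0)), (2.0, Vectors.dense(1.0))).toDF("label", "features")
    val k = DatasetUtils.getNumClasses(df, "label") // infers 3 and logs the inference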
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
index 6db0408e8d2..631261af249 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
@@ -20,8 +20,7 @@ package org.apache.spark.ml.util
import scala.collection.immutable.HashMap
import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.Dataset
+import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.sql.types.StructField
@@ -42,17 +41,6 @@ private[spark] object MetadataUtils {
}
}
- /**
- * Obtain the number of features in a vector column.
- * If no metadata is available, extract it from the dataset.
- */
- def getNumFeatures(dataset: Dataset[_], vectorCol: String): Int = {
- getNumFeatures(dataset.schema(vectorCol)).getOrElse {
- dataset.select(DatasetUtils.columnToVector(dataset, vectorCol))
- .head.getAs[Vector](0).size
- }
- }
-
/**
* Examine a schema to identify the number of features in a vector column.
* Returns None if the number of features is not specified.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
index 1aea4b47cd8..57cd99ecced 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
@@ -22,9 +22,8 @@ import org.apache.spark.ml.classification.ClassifierSuite.MockClassifier
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -35,41 +34,6 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }.toDF()
}
- test("extractLabeledPoints") {
- val c = new MockClassifier
- // Valid dataset
- val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0))
- c.extractLabeledPoints(df0, 6).count()
- // Invalid datasets
- val df1 = getTestData(Seq(0.0, -2.0, 1.0, 5.0))
- withClue("Classifier should fail if label is negative") {
- val e: SparkException = intercept[SparkException] {
- c.extractLabeledPoints(df1, 6).count()
- }
- assert(e.getMessage.contains("given dataset with invalid label"))
- }
- val df2 = getTestData(Seq(0.0, 2.1, 1.0, 5.0))
- withClue("Classifier should fail if label is not an integer") {
- val e: SparkException = intercept[SparkException] {
- c.extractLabeledPoints(df2, 6).count()
- }
- assert(e.getMessage.contains("given dataset with invalid label"))
- }
- // extractLabeledPoints with numClasses specified
- withClue("Classifier should fail if label is >= numClasses") {
- val e: SparkException = intercept[SparkException] {
- c.extractLabeledPoints(df0, numClasses = 5).count()
- }
- assert(e.getMessage.contains("given dataset with invalid label"))
- }
- withClue("Classifier.extractLabeledPoints should fail if numClasses <= 0") {
- val e: IllegalArgumentException = intercept[IllegalArgumentException] {
- c.extractLabeledPoints(df0, numClasses = 0).count()
- }
- assert(e.getMessage.contains("but requires numClasses > 0"))
- }
- }
-
test("getNumClasses") {
val c = new MockClassifier
// Valid dataset
@@ -122,10 +86,8 @@ object ClassifierSuite {
override def train(dataset: Dataset[_]): MockClassificationModel =
throw new UnsupportedOperationException()
- // Make methods public
- override def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] =
- super.extractLabeledPoints(dataset, numClasses)
- def getNumClasses(dataset: Dataset[_]): Int = super.getNumClasses(dataset)
+ def getNumClasses(dataset: Dataset[_]): Int =
+ DatasetUtils.getNumClasses(dataset, $(labelCol))
}
class MockClassificationModel(override val uid: String)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 927384d4f1e..01fc5d65c03 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -40,7 +40,21 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.recommendation.ALSModel.checkedCast"),
// [SPARK-39110] Show metrics properties in HistoryServer environment tab
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationEnvironmentInfo.this")
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationEnvironmentInfo.this"),
+
+ // [SPARK-38775][ML] Cleanup validation functions
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.PredictionModel.extractInstances"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.Predictor.extractInstances"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.Predictor.extractLabeledPoints"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.ClassificationModel.extractInstances"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.extractInstances"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.extractLabeledPoints"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.validateNumClasses"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.validateLabel"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.getNumClasses"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.getNumClasses$default$2"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRest.extractInstances"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.extractInstances")
)
// Exclude rules for 3.3.x from 3.2.0