Posted to commits@spark.apache.org by sr...@apache.org on 2023/03/11 14:46:07 UTC

[spark] branch master updated: [SPARK-42747][ML] Fix incorrect internal status of LoR and AFT

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5a702f22f49 [SPARK-42747][ML] Fix incorrect internal status of LoR and AFT
5a702f22f49 is described below

commit 5a702f22f49ca6a1b6220ac645e3fce70ec5189d
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Sat Mar 11 08:45:54 2023 -0600

    [SPARK-42747][ML] Fix incorrect internal status of LoR and AFT
    
    ### What changes were proposed in this pull request?
    Add a hook `onParamChange` in `Params.{set, setDefault, clear}`, so that subclasses can update their internal status within it.
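
    A minimal, self-contained sketch of the hook pattern (simplified names and a plain mutable map stand in for the real `Params`/`ParamMap` machinery; this is an illustration, not the actual Spark classes):
    ```
    // Simplified illustration only, not the actual Spark classes.
    trait SimpleParams {
      private val values = scala.collection.mutable.Map.empty[String, Any]

      // Hook: subclasses override this to refresh derived internal state.
      protected def onParamChange(name: String): Unit = {}

      def set(name: String, value: Any): this.type = {
        values(name) = value
        onParamChange(name)        // fired on every mutation
        this
      }

      def clear(name: String): this.type = {
        values.remove(name)
        onParamChange(name)        // also fired when a param is cleared
        this
      }

      def get(name: String): Option[Any] = values.get(name)
    }

    class SimpleBinaryModel extends SimpleParams {
      // Cached raw threshold, kept in sync by the hook; 0.0 is the logit of
      // the default probability threshold 0.5.
      private var rawThreshold: Double = 0.0

      override protected def onParamChange(name: String): Unit = {
        if (name == "threshold") {
          val t = get("threshold").map(_.asInstanceOf[Double]).getOrElse(0.5)
          rawThreshold = math.log(t / (1.0 - t))
        }
      }
    }
    ```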
    
    ### Why are the changes needed?
    In 3.1, we added internal auxiliary variables in LoR and AFT to optimize prediction/transformation.
    In LoR, when users call `model.{setThreshold, setThresholds}`, the internal status is correctly updated.
    
    But users can still call `model.set(model.threshold, value)`, in which case the status is not updated.
    Likewise, when users call `model.clear(model.threshold)`, the status should be reset to reflect the default threshold value 0.5, but it is not.
    
    For example:
    ```
    import org.apache.spark.ml.linalg._
    import org.apache.spark.ml.classification._
    
    val df = Seq((1.0, 1.0, Vectors.dense(0.0, 5.0)), (0.0, 2.0, Vectors.dense(1.0, 2.0)), (1.0, 3.0, Vectors.dense(2.0, 1.0)), (0.0, 4.0, Vectors.dense(3.0, 3.0))).toDF("label", "weight", "features")
    
    val lor = new LogisticRegression().setWeightCol("weight")
    val model = lor.fit(df)
    
    val vec = Vectors.dense(0.0, 5.0)
    
    val p0 = model.predict(vec)                              // returns 0.0
    model.setThreshold(0.05)                                 // changes the internal status
    val p1 = model.set(model.threshold, 0.5).predict(vec)    // returns 1.0, but should be 0.0
    val p2 = model.clear(model.threshold).predict(vec)       // returns 1.0, but should be 0.0
    ```
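
    The stale internal status here is the cached raw-prediction threshold, which is simply the logit of the probability threshold (see `_binaryThresholds` in the patch below). A small sketch of the cached pair, assuming 0 < t < 1 (the helper name is just for illustration):
    ```
    def binaryThresholds(t: Double): Array[Double] =
      Array(t, math.log(t / (1.0 - t)))   // Array(threshold, rawThreshold)

    binaryThresholds(0.5)    // Array(0.5, 0.0)        -> the defaults
    binaryThresholds(0.05)   // Array(0.05, -2.944...)
    ```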
    
    What makes it even worse is that `pyspark.ml` always sets params via `model.set(model.threshold, value)`, so the internal status easily gets out of sync; see the example in [SPARK-42747](https://issues.apache.org/jira/browse/SPARK-42747).
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Added unit tests.
    
    Closes #40367 from zhengruifeng/ml_param_hook.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Sean Owen <sr...@gmail.com>
---
 .../ml/classification/LogisticRegression.scala     | 54 +++++++++-------------
 .../scala/org/apache/spark/ml/param/params.scala   | 16 +++----
 .../ml/regression/AFTSurvivalRegression.scala      | 26 ++++++-----
 .../scala/org/apache/spark/ml/util/ReadWrite.scala |  2 +-
 .../classification/LogisticRegressionSuite.scala   | 21 +++++++++
 .../ml/regression/AFTSurvivalRegressionSuite.scala | 13 ++++++
 python/pyspark/ml/tests/test_algorithms.py         | 35 ++++++++++++++
 7 files changed, 113 insertions(+), 54 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 3ad1e2c17db..adf77eb6113 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -1112,46 +1112,36 @@ class LogisticRegressionModel private[spark] (
     _intercept
   }
 
-  private lazy val _intercept = interceptVector(0)
-  private lazy val _interceptVector = interceptVector.toDense
-  private lazy val _binaryThresholdArray = {
-    val array = Array(Double.NaN, Double.NaN)
-    updateBinaryThresholds(array)
-    array
-  }
-  private def _threshold: Double = _binaryThresholdArray(0)
-  private def _rawThreshold: Double = _binaryThresholdArray(1)
-
-  private def updateBinaryThresholds(array: Array[Double]): Unit = {
-    if (!isMultinomial) {
-      val _threshold = getThreshold
-      array(0) = _threshold
-      if (_threshold == 0.0) {
-        array(1) = Double.NegativeInfinity
-      } else if (_threshold == 1.0) {
-        array(1) = Double.PositiveInfinity
+  private val _interceptVector = if (isMultinomial) interceptVector.toDense else null
+  private val _intercept = if (!isMultinomial) interceptVector(0) else Double.NaN
+  // Array(0.5, 0.0) is the value for default threshold (0.5) and thresholds (unset)
+  private var _binaryThresholds: Array[Double] = if (!isMultinomial) Array(0.5, 0.0) else null
+
+  private[ml] override def onParamChange(param: Param[_]): Unit = {
+    if (!isMultinomial && (param.name == "threshold" || param.name == "thresholds")) {
+      if (isDefined(threshold) || isDefined(thresholds)) {
+        val _threshold = getThreshold
+        if (_threshold == 0.0) {
+          _binaryThresholds = Array(_threshold, Double.NegativeInfinity)
+        } else if (_threshold == 1.0) {
+          _binaryThresholds = Array(_threshold, Double.PositiveInfinity)
+        } else {
+          _binaryThresholds = Array(_threshold, math.log(_threshold / (1.0 - _threshold)))
+        }
       } else {
-        array(1) = math.log(_threshold / (1.0 - _threshold))
+        _binaryThresholds = null
       }
     }
   }
 
   @Since("1.5.0")
-  override def setThreshold(value: Double): this.type = {
-    super.setThreshold(value)
-    updateBinaryThresholds(_binaryThresholdArray)
-    this
-  }
+  override def setThreshold(value: Double): this.type = super.setThreshold(value)
 
   @Since("1.5.0")
   override def getThreshold: Double = super.getThreshold
 
   @Since("1.5.0")
-  override def setThresholds(value: Array[Double]): this.type = {
-    super.setThresholds(value)
-    updateBinaryThresholds(_binaryThresholdArray)
-    this
-  }
+  override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
 
   @Since("1.5.0")
   override def getThresholds: Array[Double] = super.getThresholds
@@ -1223,7 +1213,7 @@ class LogisticRegressionModel private[spark] (
     super.predict(features)
   } else {
     // Note: We should use _threshold instead of $(threshold) since getThreshold is overridden.
-    if (score(features) > _threshold) 1 else 0
+    if (score(features) > _binaryThresholds(0)) 1 else 0
   }
 
   override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
@@ -1265,7 +1255,7 @@ class LogisticRegressionModel private[spark] (
       super.raw2prediction(rawPrediction)
     } else {
       // Note: We should use _threshold instead of $(threshold) since getThreshold is overridden.
-      if (rawPrediction(1) > _rawThreshold) 1.0 else 0.0
+      if (rawPrediction(1) > _binaryThresholds(1)) 1.0 else 0.0
     }
   }
 
@@ -1274,7 +1264,7 @@ class LogisticRegressionModel private[spark] (
       super.probability2prediction(probability)
     } else {
       // Note: We should use _threshold instead of $(threshold) since getThreshold is overridden.
-      if (probability(1) > _threshold) 1.0 else 0.0
+      if (probability(1) > _binaryThresholds(0)) 1.0 else 0.0
     }
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index f12c1f995b7..52840e04eae 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -726,6 +726,7 @@ trait Params extends Identifiable with Serializable {
   protected final def set(paramPair: ParamPair[_]): this.type = {
     shouldOwn(paramPair.param)
     paramMap.put(paramPair)
+    onParamChange(paramPair.param)
     this
   }
 
@@ -743,6 +744,7 @@ trait Params extends Identifiable with Serializable {
   final def clear(param: Param[_]): this.type = {
     shouldOwn(param)
     paramMap.remove(param)
+    onParamChange(param)
     this
   }
 
@@ -767,8 +769,9 @@ trait Params extends Identifiable with Serializable {
    *               this method gets called.
    * @param value  the default value
    */
-  protected final def setDefault[T](param: Param[T], value: T): this.type = {
+  protected[ml] final def setDefault[T](param: Param[T], value: T): this.type = {
     defaultParamMap.put(param -> value)
+    onParamChange(param)
     this
   }
 
@@ -870,7 +873,7 @@ trait Params extends Identifiable with Serializable {
     params.foreach { param =>
       // copy default Params
       if (defaultParamMap.contains(param) && to.hasParam(param.name)) {
-        to.defaultParamMap.put(to.getParam(param.name), defaultParamMap(param))
+        to.setDefault(to.getParam(param.name), defaultParamMap(param))
       }
       // copy explicitly set Params
       if (map.contains(param) && to.hasParam(param.name)) {
@@ -879,15 +882,8 @@ trait Params extends Identifiable with Serializable {
     }
     to
   }
-}
 
-private[ml] object Params {
-  /**
-   * Sets a default param value for a `Params`.
-   */
-  private[ml] final def setDefault[T](params: Params, param: Param[T], value: T): Unit = {
-    params.defaultParamMap.put(param -> value)
-  }
+  private[ml] def onParamChange(param: Param[_]): Unit = {}
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index c48fe680e80..5ac58431f17 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -379,25 +379,29 @@ class AFTSurvivalRegressionModel private[ml] (
 
   /** @group setParam */
   @Since("1.6.0")
-  def setQuantileProbabilities(value: Array[Double]): this.type = {
-    set(quantileProbabilities, value)
-    _quantiles(0) = $(quantileProbabilities).map(q => math.exp(math.log(-math.log1p(-q)) * scale))
-    this
-  }
+  def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value)
 
   /** @group setParam */
   @Since("1.6.0")
   def setQuantilesCol(value: String): this.type = set(quantilesCol, value)
 
-  private lazy val _quantiles = {
-    Array($(quantileProbabilities).map(q => math.exp(math.log(-math.log1p(-q)) * scale)))
+  private var _quantiles: Vector = _
+
+  private[ml] override def onParamChange(param: Param[_]): Unit = {
+    if (param.name == "quantileProbabilities") {
+      if (isDefined(quantileProbabilities)) {
+        _quantiles = Vectors.dense(
+          $(quantileProbabilities).map(q => math.exp(math.log(-math.log1p(-q)) * scale)))
+      } else {
+        _quantiles = null
+      }
+    }
   }
 
   private def lambda2Quantiles(lambda: Double): Vector = {
-    val quantiles = _quantiles(0).clone()
-    var i = 0
-    while (i < quantiles.length) { quantiles(i) *= lambda; i += 1 }
-    Vectors.dense(quantiles)
+    val quantiles = _quantiles.copy
+    BLAS.scal(lambda, quantiles)
+    quantiles
   }
 
   @Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index fec05ccf15c..5e38b0aba95 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -563,7 +563,7 @@ private[ml] object DefaultParamsReader {
               val param = instance.getParam(paramName)
               val value = param.jsonDecode(compact(render(jsonValue)))
               if (isDefault) {
-                Params.setDefault(instance, param, value)
+                instance.setDefault(param, value)
               } else {
                 instance.set(param, value)
               }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 15405371a27..15f2e63bc85 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -2994,6 +2994,27 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
     val expected = "LogisticRegressionModel: uid=logReg, numClasses=2, numFeatures=3"
     assert(model.toString === expected)
   }
+
+  test("test internal thresholds") {
+    val df = Seq(
+      (1.0, 1.0, Vectors.dense(0.0, 5.0)),
+      (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+      (1.0, 3.0, Vectors.dense(2.0, 1.0)),
+      (0.0, 4.0, Vectors.dense(3.0, 3.0))
+    ).toDF("label", "weight", "features")
+
+    val lor = new LogisticRegression().setWeightCol("weight")
+    val model = lor.fit(df)
+    val vec = Vectors.dense(0.0, 5.0)
+
+    val p0 = model.predict(vec)
+    model.setThreshold(0.05)
+    val p1 = model.set(model.threshold, 0.5).predict(vec)
+    val p2 = model.clear(model.threshold).predict(vec)
+
+    assert(p0 === p1)
+    assert(p0 === p2)
+  }
 }
 
 object LogisticRegressionSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index c8f692654e4..c91f9dea705 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -481,6 +481,19 @@ class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest {
       }
     }
   }
+
+  test("test internal quantiles") {
+    val quantileProbabilities = Array(0.1, 0.5, 0.9)
+    val aft = new AFTSurvivalRegression().setQuantilesCol("quantiles")
+    val model = aft.fit(datasetUnivariate)
+    val vec = Vectors.dense(6.559282795753792)
+
+    val p1 = model.setQuantileProbabilities(quantileProbabilities).predictQuantiles(vec)
+    model.setQuantileProbabilities(Array(0.2, 0.3, 0.9))
+    val p2 = model.set(model.quantileProbabilities, quantileProbabilities).predictQuantiles(vec)
+
+    assert(p1 === p2)
+  }
 }
 
 object AFTSurvivalRegressionSuite {
diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py
index accdddb29c0..fb2507fe085 100644
--- a/python/pyspark/ml/tests/test_algorithms.py
+++ b/python/pyspark/ml/tests/test_algorithms.py
@@ -83,6 +83,41 @@ class LogisticRegressionTest(SparkSessionTestCase):
             np.allclose(model.interceptVector.toArray(), [-0.9057, -1.1392, -0.0033], atol=1e-4)
         )
 
+    def test_logistic_regression_with_threshold(self):
+
+        df = self.spark.createDataFrame(
+            [
+                (1.0, 1.0, Vectors.dense(0.0, 5.0)),
+                (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+                (1.0, 3.0, Vectors.dense(2.0, 1.0)),
+                (0.0, 4.0, Vectors.dense(3.0, 3.0)),
+            ],
+            ["label", "weight", "features"],
+        )
+
+        lor = LogisticRegression(weightCol="weight")
+        model = lor.fit(df)
+
+        # status changes 1
+        for t in [0.0, 0.1, 0.2, 0.5, 1.0]:
+            model.setThreshold(t).transform(df)
+
+        # status changes 2
+        [model.setThreshold(t).predict(Vectors.dense(0.0, 5.0)) for t in [0.0, 0.1, 0.2, 0.5, 1.0]]
+
+        self.assertEqual(
+            [row.prediction for row in model.setThreshold(0.0).transform(df).collect()],
+            [1.0, 1.0, 1.0, 1.0],
+        )
+        self.assertEqual(
+            [row.prediction for row in model.setThreshold(0.5).transform(df).collect()],
+            [0.0, 1.0, 1.0, 0.0],
+        )
+        self.assertEqual(
+            [row.prediction for row in model.setThreshold(1.0).transform(df).collect()],
+            [0.0, 0.0, 0.0, 0.0],
+        )
+
 
 class MultilayerPerceptronClassifierTest(SparkSessionTestCase):
     def test_raw_and_probability_prediction(self):

