You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2023/03/11 14:46:14 UTC
[spark] branch branch-3.4 updated: [SPARK-42747][ML] Fix incorrect internal status of LoR and AFT
This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new cb7ae0407d4 [SPARK-42747][ML] Fix incorrect internal status of LoR and AFT
cb7ae0407d4 is described below
commit cb7ae0407d440feb6c228b1265af50c0006e21e9
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Sat Mar 11 08:45:54 2023 -0600
[SPARK-42747][ML] Fix incorrect internal status of LoR and AFT
### What changes were proposed in this pull request?
Add a hook `onParamChange` in `Params.{set, setDefault, clear}`, so that subclasses can update their internal status within it.
### Why are the changes needed?
In 3.1, we added internal auxiliary variables in LoR and AFT to optimize prediction/transformation.
In LoR, when users call `model.{setThreshold, setThresholds}`, the internal status will be correctly updated.
But users can still call `model.set(model.threshold, value)`, in which case the internal status will not be updated.
And when users call `model.clear(model.threshold)`, the status should be updated with the default threshold value 0.5.
For example:
```
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.classification._
val df = Seq((1.0, 1.0, Vectors.dense(0.0, 5.0)), (0.0, 2.0, Vectors.dense(1.0, 2.0)), (1.0, 3.0, Vectors.dense(2.0, 1.0)), (0.0, 4.0, Vectors.dense(3.0, 3.0))).toDF("label", "weight", "features")
val lor = new LogisticRegression().setWeightCol("weight")
val model = lor.fit(df)
val vec = Vectors.dense(0.0, 5.0)
val p0 = model.predict(vec) // return 0.0
model.setThreshold(0.05) // change status
val p1 = model.set(model.threshold, 0.5).predict(vec) // return 1.0; but should be 0.0
val p2 = model.clear(model.threshold).predict(vec) // return 1.0; but should be 0.0
```
What makes it even worse is that `pyspark.ml` always sets params via `model.set(model.threshold, value)`, so the internal status easily gets out of sync; see the example in [SPARK-42747](https://issues.apache.org/jira/browse/SPARK-42747).
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
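Added unit tests (in `LogisticRegressionSuite`, `AFTSurvivalRegressionSuite`, and `python/pyspark/ml/tests/test_algorithms.py`).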
Closes #40367 from zhengruifeng/ml_param_hook.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Sean Owen <sr...@gmail.com>
(cherry picked from commit 5a702f22f49ca6a1b6220ac645e3fce70ec5189d)
Signed-off-by: Sean Owen <sr...@gmail.com>
---
.../ml/classification/LogisticRegression.scala | 54 +++++++++-------------
.../scala/org/apache/spark/ml/param/params.scala | 16 +++----
.../ml/regression/AFTSurvivalRegression.scala | 26 ++++++-----
.../scala/org/apache/spark/ml/util/ReadWrite.scala | 2 +-
.../classification/LogisticRegressionSuite.scala | 21 +++++++++
.../ml/regression/AFTSurvivalRegressionSuite.scala | 13 ++++++
python/pyspark/ml/tests/test_algorithms.py | 35 ++++++++++++++
7 files changed, 113 insertions(+), 54 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 3ad1e2c17db..adf77eb6113 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -1112,46 +1112,36 @@ class LogisticRegressionModel private[spark] (
_intercept
}
- private lazy val _intercept = interceptVector(0)
- private lazy val _interceptVector = interceptVector.toDense
- private lazy val _binaryThresholdArray = {
- val array = Array(Double.NaN, Double.NaN)
- updateBinaryThresholds(array)
- array
- }
- private def _threshold: Double = _binaryThresholdArray(0)
- private def _rawThreshold: Double = _binaryThresholdArray(1)
-
- private def updateBinaryThresholds(array: Array[Double]): Unit = {
- if (!isMultinomial) {
- val _threshold = getThreshold
- array(0) = _threshold
- if (_threshold == 0.0) {
- array(1) = Double.NegativeInfinity
- } else if (_threshold == 1.0) {
- array(1) = Double.PositiveInfinity
+ private val _interceptVector = if (isMultinomial) interceptVector.toDense else null
+ private val _intercept = if (!isMultinomial) interceptVector(0) else Double.NaN
+ // Array(0.5, 0.0) is the value for default threshold (0.5) and thresholds (unset)
+ private var _binaryThresholds: Array[Double] = if (!isMultinomial) Array(0.5, 0.0) else null
+
+ private[ml] override def onParamChange(param: Param[_]): Unit = {
+ if (!isMultinomial && (param.name == "threshold" || param.name == "thresholds")) {
+ if (isDefined(threshold) || isDefined(thresholds)) {
+ val _threshold = getThreshold
+ if (_threshold == 0.0) {
+ _binaryThresholds = Array(_threshold, Double.NegativeInfinity)
+ } else if (_threshold == 1.0) {
+ _binaryThresholds = Array(_threshold, Double.PositiveInfinity)
+ } else {
+ _binaryThresholds = Array(_threshold, math.log(_threshold / (1.0 - _threshold)))
+ }
} else {
- array(1) = math.log(_threshold / (1.0 - _threshold))
+ _binaryThresholds = null
}
}
}
@Since("1.5.0")
- override def setThreshold(value: Double): this.type = {
- super.setThreshold(value)
- updateBinaryThresholds(_binaryThresholdArray)
- this
- }
+ override def setThreshold(value: Double): this.type = super.setThreshold(value)
@Since("1.5.0")
override def getThreshold: Double = super.getThreshold
@Since("1.5.0")
- override def setThresholds(value: Array[Double]): this.type = {
- super.setThresholds(value)
- updateBinaryThresholds(_binaryThresholdArray)
- this
- }
+ override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
@Since("1.5.0")
override def getThresholds: Array[Double] = super.getThresholds
@@ -1223,7 +1213,7 @@ class LogisticRegressionModel private[spark] (
super.predict(features)
} else {
// Note: We should use _threshold instead of $(threshold) since getThreshold is overridden.
- if (score(features) > _threshold) 1 else 0
+ if (score(features) > _binaryThresholds(0)) 1 else 0
}
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
@@ -1265,7 +1255,7 @@ class LogisticRegressionModel private[spark] (
super.raw2prediction(rawPrediction)
} else {
// Note: We should use _threshold instead of $(threshold) since getThreshold is overridden.
- if (rawPrediction(1) > _rawThreshold) 1.0 else 0.0
+ if (rawPrediction(1) > _binaryThresholds(1)) 1.0 else 0.0
}
}
@@ -1274,7 +1264,7 @@ class LogisticRegressionModel private[spark] (
super.probability2prediction(probability)
} else {
// Note: We should use _threshold instead of $(threshold) since getThreshold is overridden.
- if (probability(1) > _threshold) 1.0 else 0.0
+ if (probability(1) > _binaryThresholds(0)) 1.0 else 0.0
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index f12c1f995b7..52840e04eae 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -726,6 +726,7 @@ trait Params extends Identifiable with Serializable {
protected final def set(paramPair: ParamPair[_]): this.type = {
shouldOwn(paramPair.param)
paramMap.put(paramPair)
+ onParamChange(paramPair.param)
this
}
@@ -743,6 +744,7 @@ trait Params extends Identifiable with Serializable {
final def clear(param: Param[_]): this.type = {
shouldOwn(param)
paramMap.remove(param)
+ onParamChange(param)
this
}
@@ -767,8 +769,9 @@ trait Params extends Identifiable with Serializable {
* this method gets called.
* @param value the default value
*/
- protected final def setDefault[T](param: Param[T], value: T): this.type = {
+ protected[ml] final def setDefault[T](param: Param[T], value: T): this.type = {
defaultParamMap.put(param -> value)
+ onParamChange(param)
this
}
@@ -870,7 +873,7 @@ trait Params extends Identifiable with Serializable {
params.foreach { param =>
// copy default Params
if (defaultParamMap.contains(param) && to.hasParam(param.name)) {
- to.defaultParamMap.put(to.getParam(param.name), defaultParamMap(param))
+ to.setDefault(to.getParam(param.name), defaultParamMap(param))
}
// copy explicitly set Params
if (map.contains(param) && to.hasParam(param.name)) {
@@ -879,15 +882,8 @@ trait Params extends Identifiable with Serializable {
}
to
}
-}
-private[ml] object Params {
- /**
- * Sets a default param value for a `Params`.
- */
- private[ml] final def setDefault[T](params: Params, param: Param[T], value: T): Unit = {
- params.defaultParamMap.put(param -> value)
- }
+ private[ml] def onParamChange(param: Param[_]): Unit = {}
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index c48fe680e80..5ac58431f17 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -379,25 +379,29 @@ class AFTSurvivalRegressionModel private[ml] (
/** @group setParam */
@Since("1.6.0")
- def setQuantileProbabilities(value: Array[Double]): this.type = {
- set(quantileProbabilities, value)
- _quantiles(0) = $(quantileProbabilities).map(q => math.exp(math.log(-math.log1p(-q)) * scale))
- this
- }
+ def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value)
/** @group setParam */
@Since("1.6.0")
def setQuantilesCol(value: String): this.type = set(quantilesCol, value)
- private lazy val _quantiles = {
- Array($(quantileProbabilities).map(q => math.exp(math.log(-math.log1p(-q)) * scale)))
+ private var _quantiles: Vector = _
+
+ private[ml] override def onParamChange(param: Param[_]): Unit = {
+ if (param.name == "quantileProbabilities") {
+ if (isDefined(quantileProbabilities)) {
+ _quantiles = Vectors.dense(
+ $(quantileProbabilities).map(q => math.exp(math.log(-math.log1p(-q)) * scale)))
+ } else {
+ _quantiles = null
+ }
+ }
}
private def lambda2Quantiles(lambda: Double): Vector = {
- val quantiles = _quantiles(0).clone()
- var i = 0
- while (i < quantiles.length) { quantiles(i) *= lambda; i += 1 }
- Vectors.dense(quantiles)
+ val quantiles = _quantiles.copy
+ BLAS.scal(lambda, quantiles)
+ quantiles
}
@Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index fec05ccf15c..5e38b0aba95 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -563,7 +563,7 @@ private[ml] object DefaultParamsReader {
val param = instance.getParam(paramName)
val value = param.jsonDecode(compact(render(jsonValue)))
if (isDefault) {
- Params.setDefault(instance, param, value)
+ instance.setDefault(param, value)
} else {
instance.set(param, value)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 15405371a27..15f2e63bc85 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -2994,6 +2994,27 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
val expected = "LogisticRegressionModel: uid=logReg, numClasses=2, numFeatures=3"
assert(model.toString === expected)
}
+
+ test("test internal thresholds") {
+ val df = Seq(
+ (1.0, 1.0, Vectors.dense(0.0, 5.0)),
+ (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+ (1.0, 3.0, Vectors.dense(2.0, 1.0)),
+ (0.0, 4.0, Vectors.dense(3.0, 3.0))
+ ).toDF("label", "weight", "features")
+
+ val lor = new LogisticRegression().setWeightCol("weight")
+ val model = lor.fit(df)
+ val vec = Vectors.dense(0.0, 5.0)
+
+ val p0 = model.predict(vec)
+ model.setThreshold(0.05)
+ val p1 = model.set(model.threshold, 0.5).predict(vec)
+ val p2 = model.clear(model.threshold).predict(vec)
+
+ assert(p0 === p1)
+ assert(p0 === p2)
+ }
}
object LogisticRegressionSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index c8f692654e4..c91f9dea705 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -481,6 +481,19 @@ class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest {
}
}
}
+
+ test("test internal quantiles") {
+ val quantileProbabilities = Array(0.1, 0.5, 0.9)
+ val aft = new AFTSurvivalRegression().setQuantilesCol("quantiles")
+ val model = aft.fit(datasetUnivariate)
+ val vec = Vectors.dense(6.559282795753792)
+
+ val p1 = model.setQuantileProbabilities(quantileProbabilities).predictQuantiles(vec)
+ model.setQuantileProbabilities(Array(0.2, 0.3, 0.9))
+ val p2 = model.set(model.quantileProbabilities, quantileProbabilities).predictQuantiles(vec)
+
+ assert(p1 === p2)
+ }
}
object AFTSurvivalRegressionSuite {
diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py
index accdddb29c0..fb2507fe085 100644
--- a/python/pyspark/ml/tests/test_algorithms.py
+++ b/python/pyspark/ml/tests/test_algorithms.py
@@ -83,6 +83,41 @@ class LogisticRegressionTest(SparkSessionTestCase):
np.allclose(model.interceptVector.toArray(), [-0.9057, -1.1392, -0.0033], atol=1e-4)
)
+ def test_logistic_regression_with_threshold(self):
+
+ df = self.spark.createDataFrame(
+ [
+ (1.0, 1.0, Vectors.dense(0.0, 5.0)),
+ (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+ (1.0, 3.0, Vectors.dense(2.0, 1.0)),
+ (0.0, 4.0, Vectors.dense(3.0, 3.0)),
+ ],
+ ["label", "weight", "features"],
+ )
+
+ lor = LogisticRegression(weightCol="weight")
+ model = lor.fit(df)
+
+ # status changes 1
+ for t in [0.0, 0.1, 0.2, 0.5, 1.0]:
+ model.setThreshold(t).transform(df)
+
+ # status changes 2
+ [model.setThreshold(t).predict(Vectors.dense(0.0, 5.0)) for t in [0.0, 0.1, 0.2, 0.5, 1.0]]
+
+ self.assertEqual(
+ [row.prediction for row in model.setThreshold(0.0).transform(df).collect()],
+ [1.0, 1.0, 1.0, 1.0],
+ )
+ self.assertEqual(
+ [row.prediction for row in model.setThreshold(0.5).transform(df).collect()],
+ [0.0, 1.0, 1.0, 0.0],
+ )
+ self.assertEqual(
+ [row.prediction for row in model.setThreshold(1.0).transform(df).collect()],
+ [0.0, 0.0, 0.0, 0.0],
+ )
+
class MultilayerPerceptronClassifierTest(SparkSessionTestCase):
def test_raw_and_probability_prediction(self):
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org