Posted to issues@spark.apache.org by "Apache Spark (Jira)" <ji...@apache.org> on 2023/03/10 13:31:00 UTC

[jira] [Commented] (SPARK-42747) Fix incorrect internal status of LoR and AFT

    [ https://issues.apache.org/jira/browse/SPARK-42747?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=17698943#comment-17698943 ] 

Apache Spark commented on SPARK-42747:
--------------------------------------

User 'zhengruifeng' has created a pull request for this issue:
https://github.com/apache/spark/pull/40367

> Fix incorrect internal status of LoR and AFT
> --------------------------------------------
>
>                 Key: SPARK-42747
>                 URL: https://issues.apache.org/jira/browse/SPARK-42747
>             Project: Spark
>          Issue Type: Bug
>          Components: ML, PySpark
>    Affects Versions: 3.1.0, 3.2.0, 3.3.0, 3.4.0
>            Reporter: Ruifeng Zheng
>            Priority: Major
>
> LoR (LogisticRegressionModel) and AFT (AFTSurvivalRegressionModel) keep internal state to optimize prediction/transform, but this state is not correctly updated in some cases. For example, with LoR:
> {code:python}
> from pyspark.sql import SparkSession
> from pyspark.ml.classification import LogisticRegression
> from pyspark.ml.linalg import Vectors
>
> spark = SparkSession.builder.getOrCreate()
> df = spark.createDataFrame(
>     [
>         (1.0, 1.0, Vectors.dense(0.0, 5.0)),
>         (0.0, 2.0, Vectors.dense(1.0, 2.0)),
>         (1.0, 3.0, Vectors.dense(2.0, 1.0)),
>         (0.0, 4.0, Vectors.dense(3.0, 3.0)),
>     ],
>     ["label", "weight", "features"],
> )
> lor = LogisticRegression(weightCol="weight")
> model = lor.fit(df)
> thresholds = [0.0, 0.1, 0.2, 0.5, 1.0]
> # internal state change 1: setThreshold followed by transform
> for t in thresholds:
>     model.setThreshold(t).transform(df)
> # internal state change 2: setThreshold followed by single-instance predict
> for t in thresholds:
>     model.setThreshold(t).predict(Vectors.dense(0.0, 5.0))
> for t in thresholds:
>     print(t)
>     model.setThreshold(t).transform(df).show()  # <- wrong predictions for every threshold
> {code}
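> Results (every threshold yields 0.0 for all rows, although threshold 0.0, for instance, should predict 1.0 for every row):
> {code}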
> 0.0
> +-----+------+---------+--------------------+--------------------+----------+
> |label|weight| features|       rawPrediction|         probability|prediction|
> +-----+------+---------+--------------------+--------------------+----------+
> |  1.0|   1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...|       0.0|
> |  0.0|   2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...|       0.0|
> |  1.0|   3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...|       0.0|
> |  0.0|   4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...|       0.0|
> +-----+------+---------+--------------------+--------------------+----------+
> 0.1
> +-----+------+---------+--------------------+--------------------+----------+
> |label|weight| features|       rawPrediction|         probability|prediction|
> +-----+------+---------+--------------------+--------------------+----------+
> |  1.0|   1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...|       0.0|
> |  0.0|   2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...|       0.0|
> |  1.0|   3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...|       0.0|
> |  0.0|   4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...|       0.0|
> +-----+------+---------+--------------------+--------------------+----------+
> 0.2
> +-----+------+---------+--------------------+--------------------+----------+
> |label|weight| features|       rawPrediction|         probability|prediction|
> +-----+------+---------+--------------------+--------------------+----------+
> |  1.0|   1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...|       0.0|
> |  0.0|   2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...|       0.0|
> |  1.0|   3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...|       0.0|
> |  0.0|   4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...|       0.0|
> +-----+------+---------+--------------------+--------------------+----------+
> 0.5
> +-----+------+---------+--------------------+--------------------+----------+
> |label|weight| features|       rawPrediction|         probability|prediction|
> +-----+------+---------+--------------------+--------------------+----------+
> |  1.0|   1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...|       0.0|
> |  0.0|   2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...|       0.0|
> |  1.0|   3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...|       0.0|
> |  0.0|   4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...|       0.0|
> +-----+------+---------+--------------------+--------------------+----------+
> 1.0
> +-----+------+---------+--------------------+--------------------+----------+
> |label|weight| features|       rawPrediction|         probability|prediction|
> +-----+------+---------+--------------------+--------------------+----------+
> |  1.0|   1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...|       0.0|
> |  0.0|   2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...|       0.0|
> |  1.0|   3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...|       0.0|
> |  0.0|   4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...|       0.0|
> +-----+------+---------+--------------------+--------------------+----------+
> {code}
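For reference, here is a minimal pure-Python sketch of the predictions one would expect, assuming Spark's binary LoR rule predicts 1.0 when P(label=1) is strictly greater than the threshold, with the probability vector laid out as [P(label=0), P(label=1)]. The P(label=0) values are the truncated ones from the output above, and `expected_predictions` is a hypothetical helper, not part of PySpark.

```python
# P(label=0) per row, copied (truncated) from the "probability" column above.
# Assumption: Spark's probability vector is [P(label=0), P(label=1)].
P0 = [0.52730284774069, 0.29692950635762, 0.41012446452385, 0.91211618852612]


def expected_predictions(p0_values, threshold):
    """Hypothetical helper: predict 1.0 iff P(label=1) > threshold (assumed binary LoR rule)."""
    return [1.0 if (1.0 - p0) > threshold else 0.0 for p0 in p0_values]


for t in [0.0, 0.1, 0.2, 0.5, 1.0]:
    print(t, expected_predictions(P0, t))
```

Under these assumptions, threshold 0.5 should give [0.0, 1.0, 1.0, 0.0] rather than the all-0.0 column shown in the 0.5 table.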



--
This message was sent by Atlassian Jira
(v8.20.10#820010)