Posted to issues@spark.apache.org by "Apache Spark (Jira)" <ji...@apache.org> on 2023/03/10 13:31:00 UTC
[jira] [Assigned] (SPARK-42747) Fix incorrect internal status of LoR and AFT
[ https://issues.apache.org/jira/browse/SPARK-42747?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Apache Spark reassigned SPARK-42747:
------------------------------------
Assignee: Apache Spark
> Fix incorrect internal status of LoR and AFT
> --------------------------------------------
>
> Key: SPARK-42747
> URL: https://issues.apache.org/jira/browse/SPARK-42747
> Project: Spark
> Issue Type: Bug
> Components: ML, PySpark
> Affects Versions: 3.1.0, 3.2.0, 3.3.0, 3.4.0
> Reporter: Ruifeng Zheng
> Assignee: Apache Spark
> Priority: Major
>
> LoR (LogisticRegression) and AFT (AFTSurvivalRegression) keep internal status to optimize prediction/transform, but this status is not correctly updated in some cases:
> {code:python}
> from pyspark.sql import SparkSession
> from pyspark.ml.classification import LogisticRegression
> from pyspark.ml.linalg import Vectors
>
> spark = SparkSession.builder.getOrCreate()
> df = spark.createDataFrame(
>     [
>         (1.0, 1.0, Vectors.dense(0.0, 5.0)),
>         (0.0, 2.0, Vectors.dense(1.0, 2.0)),
>         (1.0, 3.0, Vectors.dense(2.0, 1.0)),
>         (0.0, 4.0, Vectors.dense(3.0, 3.0)),
>     ],
>     ["label", "weight", "features"],
> )
> lor = LogisticRegression(weightCol="weight")
> model = lor.fit(df)
> # Status change 1: set each threshold, then call transform
> for t in [0.0, 0.1, 0.2, 0.5, 1.0]:
>     model.setThreshold(t).transform(df)
> # Status change 2: set each threshold, then call predict on a single vector
> [model.setThreshold(t).predict(Vectors.dense(0.0, 5.0)) for t in [0.0, 0.1, 0.2, 0.5, 1.0]]
> for t in [0.0, 0.1, 0.2, 0.5, 1.0]:
>     print(t)
>     model.setThreshold(t).transform(df).show()  # <- incorrect results
> {code}
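> Results (the prediction column is 0.0 for every row at every threshold, although the threshold should change the predictions; with threshold 0.0, for example, every row should be predicted 1.0):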
> {code:java}
> 0.0
> +-----+------+---------+--------------------+--------------------+----------+
> |label|weight| features| rawPrediction| probability|prediction|
> +-----+------+---------+--------------------+--------------------+----------+
> | 1.0| 1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...| 0.0|
> | 0.0| 2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...| 0.0|
> | 1.0| 3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...| 0.0|
> | 0.0| 4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...| 0.0|
> +-----+------+---------+--------------------+--------------------+----------+
> 0.1
> +-----+------+---------+--------------------+--------------------+----------+
> |label|weight| features| rawPrediction| probability|prediction|
> +-----+------+---------+--------------------+--------------------+----------+
> | 1.0| 1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...| 0.0|
> | 0.0| 2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...| 0.0|
> | 1.0| 3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...| 0.0|
> | 0.0| 4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...| 0.0|
> +-----+------+---------+--------------------+--------------------+----------+
> 0.2
> +-----+------+---------+--------------------+--------------------+----------+
> |label|weight| features| rawPrediction| probability|prediction|
> +-----+------+---------+--------------------+--------------------+----------+
> | 1.0| 1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...| 0.0|
> | 0.0| 2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...| 0.0|
> | 1.0| 3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...| 0.0|
> | 0.0| 4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...| 0.0|
> +-----+------+---------+--------------------+--------------------+----------+
> 0.5
> +-----+------+---------+--------------------+--------------------+----------+
> |label|weight| features| rawPrediction| probability|prediction|
> +-----+------+---------+--------------------+--------------------+----------+
> | 1.0| 1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...| 0.0|
> | 0.0| 2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...| 0.0|
> | 1.0| 3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...| 0.0|
> | 0.0| 4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...| 0.0|
> +-----+------+---------+--------------------+--------------------+----------+
> 1.0
> +-----+------+---------+--------------------+--------------------+----------+
> |label|weight| features| rawPrediction| probability|prediction|
> +-----+------+---------+--------------------+--------------------+----------+
> | 1.0| 1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...| 0.0|
> | 0.0| 2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...| 0.0|
> | 1.0| 3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...| 0.0|
> | 0.0| 4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...| 0.0|
> +-----+------+---------+--------------------+--------------------+----------+
> {code}
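To show the expected behavior, here is a minimal Python sketch. It is not Spark code. It assumes, based on our understanding of binary LoR in Spark ML, that rawPrediction is [-margin, margin]. It also assumes a row is predicted 1.0 when its margin exceeds the logit of the threshold, which is equivalent to probability[1] exceeding the threshold. `ThresholdModel` and `raw_threshold` are hypothetical names. The margins come from the truncated rawPrediction column in the results above. The model caches the derived raw threshold and recomputes it on every `set_threshold` call. If some path that changes the threshold did not refresh such a cache, that would be one way the stale internal status described here could arise.

```python
import math

# Margins for the four rows, from the truncated rawPrediction column above.
# Assumption: binary LoR rawPrediction is [-margin, margin], so margin = -rawPrediction[0].
MARGINS = [-0.10932013376341, 0.8619624039359, 0.3634508721860, -2.33975176373760]


def raw_threshold(t: float) -> float:
    """Map a probability threshold to the equivalent margin threshold (logit)."""
    if t <= 0.0:
        return -math.inf  # every margin exceeds it, so everything is predicted 1.0
    if t >= 1.0:
        return math.inf  # no margin exceeds it, so everything is predicted 0.0
    return math.log(t / (1.0 - t))


class ThresholdModel:
    """Hypothetical model that caches state derived from its threshold.

    The cached raw threshold must be recomputed whenever the threshold changes.
    A path that skipped this step would leave predictions tied to an old threshold.
    """

    def __init__(self, threshold: float = 0.5) -> None:
        self.set_threshold(threshold)

    def set_threshold(self, t: float) -> "ThresholdModel":
        self.threshold = t
        self._raw_threshold = raw_threshold(t)  # keep the cached status in sync
        return self

    def predict(self, margin: float) -> float:
        # Compare the margin with the cached raw threshold instead of computing a probability.
        return 1.0 if margin > self._raw_threshold else 0.0


model = ThresholdModel()
for t in [0.0, 0.1, 0.2, 0.5, 1.0]:
    print(t, [model.set_threshold(t).predict(m) for m in MARGINS])
```

Under these assumptions, the loop prints the following predictions:

- [1.0, 1.0, 1.0, 1.0] for threshold 0.0
- [1.0, 1.0, 1.0, 0.0] for 0.1 and 0.2
- [0.0, 1.0, 1.0, 0.0] for 0.5
- [0.0, 0.0, 0.0, 0.0] for 1.0

The results above instead show 0.0 for every row at every threshold.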
--
This message was sent by Atlassian Jira
(v8.20.10#820010)