You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2021/04/21 08:30:10 UTC
[spark] branch branch-3.1 updated: [SPARK-35142][PYTHON][ML] Fix
incorrect return type for `rawPredictionUDF` in `OneVsRestModel`
This is an automated email from the ASF dual-hosted git repository.
weichenxu123 pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.1 by this push:
new 0208810 [SPARK-35142][PYTHON][ML] Fix incorrect return type for `rawPredictionUDF` in `OneVsRestModel`
0208810 is described below
commit 0208810b93e234b822bc972f4236bf04bf521e7d
Author: harupy <17...@users.noreply.github.com>
AuthorDate: Wed Apr 21 16:29:10 2021 +0800
[SPARK-35142][PYTHON][ML] Fix incorrect return type for `rawPredictionUDF` in `OneVsRestModel`
### What changes were proposed in this pull request?
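Fixes the incorrect return type of `rawPredictionUDF` in `OneVsRestModel`. The UDF was created with `udf(func)` without an explicit return type, so PySpark used its default `StringType` and the `rawPrediction` column was produced as strings. The UDF is now declared with `VectorUDT()`.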
### Why are the changes needed?
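Bug fix: `func` builds its result with `Vectors.dense`, so the `rawPrediction` column should have a vector type rather than a string type.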
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
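Added a unit test, `test_raw_prediction_column_is_of_vector_type`, which checks that the `rawPrediction` value in the transformed output is a `DenseVector`.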
Closes #32245 from harupy/SPARK-35142.
Authored-by: harupy <17...@users.noreply.github.com>
Signed-off-by: Weichen Xu <we...@databricks.com>
(cherry picked from commit b6350f5bb00f99a060953850b069a419b70c329e)
Signed-off-by: Weichen Xu <we...@databricks.com>
---
python/pyspark/ml/classification.py | 4 ++--
python/pyspark/ml/tests/test_algorithms.py | 14 +++++++++++++-
2 files changed, 15 insertions(+), 3 deletions(-)
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 0553a61..17994ed 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -40,7 +40,7 @@ from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, \
from pyspark.ml.wrapper import JavaParams, \
JavaPredictor, JavaPredictionModel, JavaWrapper
from pyspark.ml.common import inherit_doc
-from pyspark.ml.linalg import Vectors
+from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql import DataFrame
from pyspark.sql.functions import udf, when
from pyspark.sql.types import ArrayType, DoubleType
@@ -3151,7 +3151,7 @@ class OneVsRestModel(Model, _OneVsRestParams, MLReadable, MLWritable):
predArray.append(x)
return Vectors.dense(predArray)
- rawPredictionUDF = udf(func)
+ rawPredictionUDF = udf(func, VectorUDT())
aggregatedDataset = aggregatedDataset.withColumn(
self.getRawPredictionCol(), rawPredictionUDF(aggregatedDataset[accColName]))
diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py
index 5047521..35ce48b 100644
--- a/python/pyspark/ml/tests/test_algorithms.py
+++ b/python/pyspark/ml/tests/test_algorithms.py
@@ -25,7 +25,7 @@ from pyspark.ml.classification import FMClassifier, LogisticRegression, \
MultilayerPerceptronClassifier, OneVsRest
from pyspark.ml.clustering import DistributedLDAModel, KMeans, LocalLDAModel, LDA, LDAModel
from pyspark.ml.fpm import FPGrowth
-from pyspark.ml.linalg import Matrices, Vectors
+from pyspark.ml.linalg import Matrices, Vectors, DenseVector
from pyspark.ml.recommendation import ALS
from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression
from pyspark.sql import Row
@@ -116,6 +116,18 @@ class OneVsRestTests(SparkSessionTestCase):
output = model.transform(df)
self.assertEqual(output.columns, ["label", "features", "rawPrediction", "prediction"])
+ def test_raw_prediction_column_is_of_vector_type(self):
+ # SPARK-35142: `OneVsRestModel` outputs raw prediction as a string column
+ df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
+ (1.0, Vectors.sparse(2, [], [])),
+ (2.0, Vectors.dense(0.5, 0.5))],
+ ["label", "features"])
+ lr = LogisticRegression(maxIter=5, regParam=0.01)
+ ovr = OneVsRest(classifier=lr, parallelism=1)
+ model = ovr.fit(df)
+ row = model.transform(df).head()
+ self.assertIsInstance(row["rawPrediction"], DenseVector)
+
def test_parallelism_does_not_change_output(self):
df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
(1.0, Vectors.sparse(2, [], [])),
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org