Posted to commits@spark.apache.org by we...@apache.org on 2022/08/19 04:27:17 UTC

[spark] branch branch-3.3 updated: [SPARK-35542][ML] Fix: Bucketizer created for multiple columns with parameters splitsArray, inputCols and outputCols cannot be loaded after saving it

This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 87f957dea86 [SPARK-35542][ML] Fix: Bucketizer created for multiple columns with parameters splitsArray, inputCols and outputCols cannot be loaded after saving it
87f957dea86 is described below

commit 87f957dea86fe1b8c5979e499b5400866b235e43
Author: Weichen Xu <we...@databricks.com>
AuthorDate: Fri Aug 19 12:26:34 2022 +0800

    [SPARK-35542][ML] Fix: Bucketizer created for multiple columns with parameters splitsArray, inputCols and outputCols cannot be loaded after saving it
    
    Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
    
    ### What changes were proposed in this pull request?
    Fix: Bucketizer created for multiple columns with parameters splitsArray, inputCols and outputCols cannot be loaded after saving it
    
    ### Why are the changes needed?
    Bugfix.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Unit test
    
    Closes #37568 from WeichenXu123/SPARK-35542.
    
    Authored-by: Weichen Xu <we...@databricks.com>
    Signed-off-by: Weichen Xu <we...@databricks.com>
    (cherry picked from commit 876ce6a5df118095de51c3c4789d6db6da95eb23)
    Signed-off-by: Weichen Xu <we...@databricks.com>
---
 python/pyspark/ml/tests/test_persistence.py | 17 ++++++++++++++++-
 python/pyspark/ml/wrapper.py                |  6 +++++-
 2 files changed, 21 insertions(+), 2 deletions(-)
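
For context, a minimal reproduction distilled from the unit test in the patch below (an illustrative sketch, not part of the commit; it assumes an active SparkSession, which ML persistence requires):

    import tempfile

    from pyspark.ml.feature import Bucketizer

    # A Bucketizer over multiple columns, parameterized by a nested list of splits.
    splitsArray = [
        [-float("inf"), 0.5, 1.4, float("inf")],
        [-float("inf"), 0.1, 1.2, float("inf")],
    ]
    bucketizer = Bucketizer(
        splitsArray=splitsArray, inputCols=["values", "values"], outputCols=["b1", "b2"]
    )
    savePath = tempfile.mkdtemp() + "/bk"
    bucketizer.write().overwrite().save(savePath)
    # Before this fix, the load step failed to restore the nested splitsArray
    # param; with the fix, it round-trips as a Python list of lists.
    loadedBucketizer = Bucketizer.load(savePath)
    assert loadedBucketizer.getSplitsArray() == splitsArray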

diff --git a/python/pyspark/ml/tests/test_persistence.py b/python/pyspark/ml/tests/test_persistence.py
index 4f09a49dd04..0b54540f06d 100644
--- a/python/pyspark/ml/tests/test_persistence.py
+++ b/python/pyspark/ml/tests/test_persistence.py
@@ -32,7 +32,7 @@ from pyspark.ml.classification import (
     OneVsRestModel,
 )
 from pyspark.ml.clustering import KMeans
-from pyspark.ml.feature import Binarizer, HashingTF, PCA
+from pyspark.ml.feature import Binarizer, Bucketizer, HashingTF, PCA
 from pyspark.ml.linalg import Vectors
 from pyspark.ml.param import Params
 from pyspark.ml.pipeline import Pipeline, PipelineModel
@@ -518,6 +518,21 @@ class PersistenceTest(SparkSessionTestCase):
         )
         reader.getAndSetParams(lr, loadedMetadata)
 
+    # Test for SPARK-35542 fix.
+    def test_save_and_load_on_nested_list_params(self):
+        temp_path = tempfile.mkdtemp()
+        splitsArray = [
+            [-float("inf"), 0.5, 1.4, float("inf")],
+            [-float("inf"), 0.1, 1.2, float("inf")],
+        ]
+        bucketizer = Bucketizer(
+            splitsArray=splitsArray, inputCols=["values", "values"], outputCols=["b1", "b2"]
+        )
+        savePath = temp_path + "/bk"
+        bucketizer.write().overwrite().save(savePath)
+        loadedBucketizer = Bucketizer.load(savePath)
+        assert loadedBucketizer.getSplitsArray() == splitsArray
+
 
 if __name__ == "__main__":
     from pyspark.ml.tests.test_persistence import *  # noqa: F401
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 7853e766244..32856540d6d 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -220,7 +220,11 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
                 java_param = self._java_obj.getParam(param.name)
                 # SPARK-14931: Only check set params back to avoid default params mismatch.
                 if self._java_obj.isSet(java_param):
-                    value = _java2py(sc, self._java_obj.getOrDefault(java_param))
+                    java_value = self._java_obj.getOrDefault(java_param)
+                    if param.typeConverter.__name__.startswith("toList"):
+                        value = [_java2py(sc, x) for x in list(java_value)]
+                    else:
+                        value = _java2py(sc, java_value)
                     self._set(**{param.name: value})
                 # SPARK-10931: Temporary fix for params that have a default in Java
                 if self._java_obj.hasDefault(java_param) and not self.isDefined(param):
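
The wrapper.py hunk above is the substance of the fix: when JavaParams copies set params back from the JVM (as happens during load), any param whose type converter belongs to the toList* family (such as splitsArray's toListListFloat) is now converted element by element, so a Java Seq of Seqs comes back as a Python list of lists rather than a single opaque Java object. A standalone sketch of that rule (convert_param_value is an illustrative name, not a Spark API):

    from pyspark.ml.common import _java2py

    def convert_param_value(sc, param, java_value):
        # Mirrors the branch added above: list-typed params are converted
        # element-wise so nested sequences survive the JVM-to-Python
        # round trip as plain Python lists.
        if param.typeConverter.__name__.startswith("toList"):
            return [_java2py(sc, x) for x in list(java_value)]
        return _java2py(sc, java_value)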

