Posted to commits@spark.apache.org by we...@apache.org on 2020/12/07 03:46:01 UTC

[spark] branch branch-3.0 updated: [SPARK-33592][ML][PYTHON][3.0] Backport Fix: Pyspark ML Validator params in estimatorParamMaps may be lost after saving and reloading

This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 8acbe5b  [SPARK-33592][ML][PYTHON][3.0] Backport Fix: Pyspark ML Validator params in estimatorParamMaps may be lost after saving and reloading
8acbe5b is described below

commit 8acbe5b822b7b0ad49079aa223ad52afe70b5afa
Author: Weichen Xu <we...@databricks.com>
AuthorDate: Mon Dec 7 11:42:18 2020 +0800

    [SPARK-33592][ML][PYTHON][3.0] Backport Fix: Pyspark ML Validator params in estimatorParamMaps may be lost after saving and reloading
    
    ### What changes were proposed in this pull request?
    
    Fix: PySpark ML Validator params in estimatorParamMaps may be lost after saving and reloading.
    
    When saving a validator's estimatorParamMaps, check all nested stages in the tuned estimator so that the correct param parent can be resolved.
    
    Two typical cases to manually test:
    ~~~python
    tokenizer = Tokenizer(inputCol="text", outputCol="words")
    hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
    lr = LogisticRegression()
    pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
    
    paramGrid = ParamGridBuilder() \
        .addGrid(hashingTF.numFeatures, [10, 100]) \
        .addGrid(lr.maxIter, [100, 200]) \
        .build()
    tvs = TrainValidationSplit(estimator=pipeline,
                               estimatorParamMaps=paramGrid,
                               evaluator=MulticlassClassificationEvaluator())
    
    tvs.save(tvsPath)
    loadedTvs = TrainValidationSplit.load(tvsPath)
    
    ~~~
    
    ~~~python
    lr = LogisticRegression()
    ova = OneVsRest(classifier=lr)
    grid = ParamGridBuilder().addGrid(lr.maxIter, [100, 200]).build()
    evaluator = MulticlassClassificationEvaluator()
    tvs = TrainValidationSplit(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator)
    
    tvs.save(tvsPath)
    loadedTvs = TrainValidationSplit.load(tvsPath)
    
    ~~~
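    
    The reloaded grid can then be checked against the original one. The sketch below is illustrative only: it assumes either snippet above has already run and mirrors the `assert_param_maps_equal` helper added to the test suite by this patch; it is not part of the committed tests.
    ~~~python
    from pyspark.ml.param import Params
    
    # After the fix, params owned by nested stages (e.g. hashingTF.numFeatures,
    # lr.maxIter) keep their correct parent, so each reloaded param map should
    # match the original map key-for-key and value-for-value.
    for original, loaded in zip(tvs.getEstimatorParamMaps(),
                                loadedTvs.getEstimatorParamMaps()):
        assert set(original.keys()) == set(loaded.keys())
        for param in original:
            if isinstance(original[param], Params):
                # Estimator-valued params (e.g. a classifier) are compared by uid.
                assert original[param].uid == loaded[param].uid
            else:
                assert original[param] == loaded[param]
    ~~~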
    
    ### Why are the changes needed?
    Bug fix.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Unit test.
    
    Closes #30539 from WeichenXu123/fix_tuning_param_maps_io.
    
    Authored-by: Weichen Xu <weichen.xu@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruifengz@foxmail.com>
    (cherry picked from commit 80161238fe9393aabd5fcd56752ff1e43f6989b1)
    Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
    
    Closes #30590 from WeichenXu123/SPARK-33592-bp-3.0.
    
    Authored-by: Weichen Xu <we...@databricks.com>
    Signed-off-by: Weichen Xu <we...@databricks.com>
---
 dev/sparktestsupport/modules.py        |  1 +
 python/pyspark/ml/classification.py    | 48 +-----------------
 python/pyspark/ml/param/__init__.py    |  6 +++
 python/pyspark/ml/pipeline.py          | 53 +-------------------
 python/pyspark/ml/tests/test_tuning.py | 47 +++++++++++++++--
 python/pyspark/ml/tests/test_util.py   | 84 +++++++++++++++++++++++++++++++
 python/pyspark/ml/tuning.py            | 92 +++++++++++++++++++++++++++++++---
 python/pyspark/ml/util.py              | 38 ++++++++++++++
 8 files changed, 261 insertions(+), 108 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 75bdec0..8705d52 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -548,6 +548,7 @@ pyspark_ml = Module(
         "pyspark.ml.tests.test_stat",
         "pyspark.ml.tests.test_training_summary",
         "pyspark.ml.tests.test_tuning",
+        "pyspark.ml.tests.test_util",
         "pyspark.ml.tests.test_wrapper",
     ],
     blacklisted_python_implementations=[
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index c8e15ca..1392bc7 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -27,9 +27,9 @@ from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, \
     _HasVarianceImpurity, _TreeClassifierParams, _TreeEnsembleParams
 from pyspark.ml.regression import _FactorizationMachinesParams, DecisionTreeRegressionModel
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
+from pyspark.ml.wrapper import JavaParams, \
     JavaPredictor, _JavaPredictorParams, JavaPredictionModel, JavaWrapper
-from pyspark.ml.common import inherit_doc, _java2py, _py2java
+from pyspark.ml.common import inherit_doc
 from pyspark.ml.linalg import Vectors
 from pyspark.sql import DataFrame
 from pyspark.sql.functions import udf, when
@@ -2635,50 +2635,6 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, JavaMLReadable, Jav
         _java_obj.setRawPredictionCol(self.getRawPredictionCol())
         return _java_obj
 
-    def _make_java_param_pair(self, param, value):
-        """
-        Makes a Java param pair.
-        """
-        sc = SparkContext._active_spark_context
-        param = self._resolveParam(param)
-        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest",
-                                             self.uid)
-        java_param = _java_obj.getParam(param.name)
-        if isinstance(value, JavaParams):
-            # used in the case of an estimator having another estimator as a parameter
-            # the reason why this is not in _py2java in common.py is that importing
-            # Estimator and Model in common.py results in a circular import with inherit_doc
-            java_value = value._to_java()
-        else:
-            java_value = _py2java(sc, value)
-        return java_param.w(java_value)
-
-    def _transfer_param_map_to_java(self, pyParamMap):
-        """
-        Transforms a Python ParamMap into a Java ParamMap.
-        """
-        paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
-        for param in self.params:
-            if param in pyParamMap:
-                pair = self._make_java_param_pair(param, pyParamMap[param])
-                paramMap.put([pair])
-        return paramMap
-
-    def _transfer_param_map_from_java(self, javaParamMap):
-        """
-        Transforms a Java ParamMap into a Python ParamMap.
-        """
-        sc = SparkContext._active_spark_context
-        paramMap = dict()
-        for pair in javaParamMap.toList():
-            param = pair.param()
-            if self.hasParam(str(param.name())):
-                if param.name() == "classifier":
-                    paramMap[self.getParam(param.name())] = JavaParams._from_java(pair.value())
-                else:
-                    paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
-        return paramMap
-
 
 class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
     """
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 1be8755..f838757 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -426,6 +426,12 @@ class Params(Identifiable):
         else:
             raise ValueError("Cannot resolve %r as a param." % param)
 
+    def _testOwnParam(self, param_parent, param_name):
+        """
+        Test the ownership. Return True or False
+        """
+        return self.uid == param_parent and self.hasParam(param_name)
+
     @staticmethod
     def _dummy():
         """
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 53d07ec..09e0748 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -25,8 +25,8 @@ from pyspark import since, keyword_only, SparkContext
 from pyspark.ml.base import Estimator, Model, Transformer
 from pyspark.ml.param import Param, Params
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaParams, JavaWrapper
-from pyspark.ml.common import inherit_doc, _java2py, _py2java
+from pyspark.ml.wrapper import JavaParams
+from pyspark.ml.common import inherit_doc
 
 
 @inherit_doc
@@ -174,55 +174,6 @@ class Pipeline(Estimator, MLReadable, MLWritable):
 
         return _java_obj
 
-    def _make_java_param_pair(self, param, value):
-        """
-        Makes a Java param pair.
-        """
-        sc = SparkContext._active_spark_context
-        param = self._resolveParam(param)
-        java_param = sc._jvm.org.apache.spark.ml.param.Param(param.parent, param.name, param.doc)
-        if isinstance(value, Params) and hasattr(value, "_to_java"):
-            # Convert JavaEstimator/JavaTransformer object or Estimator/Transformer object which
-            # implements `_to_java` method (such as OneVsRest, Pipeline object) to java object.
-            # used in the case of an estimator having another estimator as a parameter
-            # the reason why this is not in _py2java in common.py is that importing
-            # Estimator and Model in common.py results in a circular import with inherit_doc
-            java_value = value._to_java()
-        else:
-            java_value = _py2java(sc, value)
-        return java_param.w(java_value)
-
-    def _transfer_param_map_to_java(self, pyParamMap):
-        """
-        Transforms a Python ParamMap into a Java ParamMap.
-        """
-        paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
-        for param in self.params:
-            if param in pyParamMap:
-                pair = self._make_java_param_pair(param, pyParamMap[param])
-                paramMap.put([pair])
-        return paramMap
-
-    def _transfer_param_map_from_java(self, javaParamMap):
-        """
-        Transforms a Java ParamMap into a Python ParamMap.
-        """
-        sc = SparkContext._active_spark_context
-        paramMap = dict()
-        for pair in javaParamMap.toList():
-            param = pair.param()
-            if self.hasParam(str(param.name())):
-                java_obj = pair.value()
-                if sc._jvm.Class.forName("org.apache.spark.ml.PipelineStage").isInstance(java_obj):
-                    # Note: JavaParams._from_java support both JavaEstimator/JavaTransformer class
-                    # and Estimator/Transformer class which implements `_from_java` static method
-                    # (such as OneVsRest, Pipeline class).
-                    py_obj = JavaParams._from_java(java_obj)
-                else:
-                    py_obj = _java2py(sc, java_obj)
-                paramMap[self.getParam(param.name())] = py_obj
-        return paramMap
-
 
 @inherit_doc
 class PipelineWriter(MLWriter):
diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py
index b1acaf6..b266ced 100644
--- a/python/pyspark/ml/tests/test_tuning.py
+++ b/python/pyspark/ml/tests/test_tuning.py
@@ -73,7 +73,21 @@ class ParamGridBuilderTests(SparkSessionTestCase):
                     .build())
 
 
-class CrossValidatorTests(SparkSessionTestCase):
+class ValidatorTestUtilsMixin:
+    def assert_param_maps_equal(self, paramMaps1, paramMaps2):
+        self.assertEqual(len(paramMaps1), len(paramMaps2))
+        for paramMap1, paramMap2 in zip(paramMaps1, paramMaps2):
+            self.assertEqual(set(paramMap1.keys()), set(paramMap2.keys()))
+            for param in paramMap1.keys():
+                v1 = paramMap1[param]
+                v2 = paramMap2[param]
+                if isinstance(v1, Params):
+                    self.assertEqual(v1.uid, v2.uid)
+                else:
+                    self.assertEqual(v1, v2)
+
+
+class CrossValidatorTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
 
     def test_copy(self):
         dataset = self.spark.createDataFrame([
@@ -253,7 +267,7 @@ class CrossValidatorTests(SparkSessionTestCase):
         loadedCV = CrossValidator.load(cvPath)
         self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
         self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
-        self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
+        self.assert_param_maps_equal(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
 
         # test save/load of CrossValidatorModel
         cvModelPath = temp_path + "/cvModel"
@@ -348,6 +362,7 @@ class CrossValidatorTests(SparkSessionTestCase):
         cvPath = temp_path + "/cv"
         cv.save(cvPath)
         loadedCV = CrossValidator.load(cvPath)
+        self.assert_param_maps_equal(loadedCV.getEstimatorParamMaps(), grid)
         self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
         self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
 
@@ -364,6 +379,7 @@ class CrossValidatorTests(SparkSessionTestCase):
         cvModelPath = temp_path + "/cvModel"
         cvModel.save(cvModelPath)
         loadedModel = CrossValidatorModel.load(cvModelPath)
+        self.assert_param_maps_equal(loadedModel.getEstimatorParamMaps(), grid)
         self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
 
     def test_save_load_pipeline_estimator(self):
@@ -398,6 +414,11 @@ class CrossValidatorTests(SparkSessionTestCase):
                                   estimatorParamMaps=paramGrid,
                                   evaluator=MulticlassClassificationEvaluator(),
                                   numFolds=2)  # use 3+ folds in practice
+        cvPath = temp_path + "/cv"
+        crossval.save(cvPath)
+        loadedCV = CrossValidator.load(cvPath)
+        self.assert_param_maps_equal(loadedCV.getEstimatorParamMaps(), paramGrid)
+        self.assertEqual(loadedCV.getEstimator().uid, crossval.getEstimator().uid)
 
         # Run cross-validation, and choose the best set of parameters.
         cvModel = crossval.fit(training)
@@ -418,6 +439,11 @@ class CrossValidatorTests(SparkSessionTestCase):
                                    estimatorParamMaps=paramGrid,
                                    evaluator=MulticlassClassificationEvaluator(),
                                    numFolds=2)  # use 3+ folds in practice
+        cv2Path = temp_path + "/cv2"
+        crossval2.save(cv2Path)
+        loadedCV2 = CrossValidator.load(cv2Path)
+        self.assert_param_maps_equal(loadedCV2.getEstimatorParamMaps(), paramGrid)
+        self.assertEqual(loadedCV2.getEstimator().uid, crossval2.getEstimator().uid)
 
         # Run cross-validation, and choose the best set of parameters.
         cvModel2 = crossval2.fit(training)
@@ -436,7 +462,7 @@ class CrossValidatorTests(SparkSessionTestCase):
             self.assertEqual(loadedStage.uid, originalStage.uid)
 
 
-class TrainValidationSplitTests(SparkSessionTestCase):
+class TrainValidationSplitTests(SparkSessionTestCase, ValidatorTestUtilsMixin):
 
     def test_fit_minimize_metric(self):
         dataset = self.spark.createDataFrame([
@@ -557,7 +583,8 @@ class TrainValidationSplitTests(SparkSessionTestCase):
         loadedTvs = TrainValidationSplit.load(tvsPath)
         self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
         self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
-        self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
+        self.assert_param_maps_equal(
+            loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
 
         tvsModelPath = temp_path + "/tvsModel"
         tvsModel.save(tvsModelPath)
@@ -638,6 +665,7 @@ class TrainValidationSplitTests(SparkSessionTestCase):
         tvsPath = temp_path + "/tvs"
         tvs.save(tvsPath)
         loadedTvs = TrainValidationSplit.load(tvsPath)
+        self.assert_param_maps_equal(loadedTvs.getEstimatorParamMaps(), grid)
         self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
         self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
 
@@ -653,6 +681,7 @@ class TrainValidationSplitTests(SparkSessionTestCase):
         tvsModelPath = temp_path + "/tvsModel"
         tvsModel.save(tvsModelPath)
         loadedModel = TrainValidationSplitModel.load(tvsModelPath)
+        self.assert_param_maps_equal(loadedModel.getEstimatorParamMaps(), grid)
         self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
 
     def test_save_load_pipeline_estimator(self):
@@ -686,6 +715,11 @@ class TrainValidationSplitTests(SparkSessionTestCase):
         tvs = TrainValidationSplit(estimator=pipeline,
                                    estimatorParamMaps=paramGrid,
                                    evaluator=MulticlassClassificationEvaluator())
+        tvsPath = temp_path + "/tvs"
+        tvs.save(tvsPath)
+        loadedTvs = TrainValidationSplit.load(tvsPath)
+        self.assert_param_maps_equal(loadedTvs.getEstimatorParamMaps(), paramGrid)
+        self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
 
         # Run train validation split, and choose the best set of parameters.
         tvsModel = tvs.fit(training)
@@ -705,6 +739,11 @@ class TrainValidationSplitTests(SparkSessionTestCase):
         tvs2 = TrainValidationSplit(estimator=nested_pipeline,
                                     estimatorParamMaps=paramGrid,
                                     evaluator=MulticlassClassificationEvaluator())
+        tvs2Path = temp_path + "/tvs2"
+        tvs2.save(tvs2Path)
+        loadedTvs2 = TrainValidationSplit.load(tvs2Path)
+        self.assert_param_maps_equal(loadedTvs2.getEstimatorParamMaps(), paramGrid)
+        self.assertEqual(loadedTvs2.getEstimator().uid, tvs2.getEstimator().uid)
 
         # Run train validation split, and choose the best set of parameters.
         tvsModel2 = tvs2.fit(training)
diff --git a/python/pyspark/ml/tests/test_util.py b/python/pyspark/ml/tests/test_util.py
new file mode 100644
index 0000000..498a649
--- /dev/null
+++ b/python/pyspark/ml/tests/test_util.py
@@ -0,0 +1,84 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.ml import Pipeline
+from pyspark.ml.classification import LogisticRegression, OneVsRest
+from pyspark.ml.feature import VectorAssembler
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.util import MetaAlgorithmReadWrite
+from pyspark.testing.mlutils import SparkSessionTestCase
+
+
+class MetaAlgorithmReadWriteTests(SparkSessionTestCase):
+
+    def test_getAllNestedStages(self):
+        def _check_uid_set_equal(stages, expected_stages):
+            uids = set(map(lambda x: x.uid, stages))
+            expected_uids = set(map(lambda x: x.uid, expected_stages))
+            self.assertEqual(uids, expected_uids)
+
+        df1 = self.spark.createDataFrame([
+            (Vectors.dense([1., 2.]), 1.0),
+            (Vectors.dense([-1., -2.]), 0.0),
+        ], ['features', 'label'])
+        df2 = self.spark.createDataFrame([
+            (1., 2., 1.0),
+            (1., 2., 0.0),
+        ], ['a', 'b', 'label'])
+        vs = VectorAssembler(inputCols=['a', 'b'], outputCol='features')
+        lr = LogisticRegression()
+        pipeline = Pipeline(stages=[vs, lr])
+        pipelineModel = pipeline.fit(df2)
+        ova = OneVsRest(classifier=lr)
+        ovaModel = ova.fit(df1)
+
+        ova_pipeline = Pipeline(stages=[vs, ova])
+        nested_pipeline = Pipeline(stages=[ova_pipeline])
+
+        _check_uid_set_equal(
+            MetaAlgorithmReadWrite.getAllNestedStages(pipeline),
+            [pipeline, vs, lr]
+        )
+        _check_uid_set_equal(
+            MetaAlgorithmReadWrite.getAllNestedStages(pipelineModel),
+            [pipelineModel] + pipelineModel.stages
+        )
+        _check_uid_set_equal(
+            MetaAlgorithmReadWrite.getAllNestedStages(ova),
+            [ova, lr]
+        )
+        _check_uid_set_equal(
+            MetaAlgorithmReadWrite.getAllNestedStages(ovaModel),
+            [ovaModel, lr] + ovaModel.models
+        )
+        _check_uid_set_equal(
+            MetaAlgorithmReadWrite.getAllNestedStages(nested_pipeline),
+            [nested_pipeline, ova_pipeline, vs, ova, lr]
+        )
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_util import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 6283c8b..e16ea57 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -26,8 +26,9 @@ from pyspark.ml.common import _py2java, _java2py
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaParams
-from pyspark.sql.functions import rand
+from pyspark.ml.wrapper import JavaParams, JavaEstimator, JavaWrapper
+from pyspark.sql.functions import col, lit, rand, UserDefinedFunction
+from pyspark.sql.types import BooleanType
 
 __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
            'TrainValidationSplitModel']
@@ -50,6 +51,10 @@ def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel):
 
     def singleTask():
         index, model = next(modelIter)
+        # TODO: duplicate evaluator to take extra params from input
+        #  Note: Supporting tuning params in evaluator need update method
+        #  `MetaAlgorithmReadWrite.getAllNestedStages`, make it return
+        #  all nested stages and evaluators
         metric = eva.evaluate(model.transform(validation, epm[index]))
         return index, metric, model if collectSubModel else None
 
@@ -169,8 +174,16 @@ class _ValidatorParams(HasSeed):
         # Load information from java_stage to the instance.
         estimator = JavaParams._from_java(java_stage.getEstimator())
         evaluator = JavaParams._from_java(java_stage.getEvaluator())
-        epms = [estimator._transfer_param_map_from_java(epm)
-                for epm in java_stage.getEstimatorParamMaps()]
+        if isinstance(estimator, JavaEstimator):
+            epms = [estimator._transfer_param_map_from_java(epm)
+                    for epm in java_stage.getEstimatorParamMaps()]
+        elif MetaAlgorithmReadWrite.isMetaEstimator(estimator):
+            # Meta estimator such as Pipeline, OneVsRest
+            epms = _ValidatorSharedReadWrite.meta_estimator_transfer_param_maps_from_java(
+                estimator, java_stage.getEstimatorParamMaps())
+        else:
+            raise ValueError('Unsupported estimator used in tuning: ' + str(estimator))
+
         return estimator, epms, evaluator
 
     def _to_java_impl(self):
@@ -181,15 +194,80 @@ class _ValidatorParams(HasSeed):
         gateway = SparkContext._gateway
         cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
 
-        java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps()))
-        for idx, epm in enumerate(self.getEstimatorParamMaps()):
-            java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm)
+        estimator = self.getEstimator()
+        if isinstance(estimator, JavaEstimator):
+            java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps()))
+            for idx, epm in enumerate(self.getEstimatorParamMaps()):
+                java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm)
+        elif MetaAlgorithmReadWrite.isMetaEstimator(estimator):
+            # Meta estimator such as Pipeline, OneVsRest
+            java_epms = _ValidatorSharedReadWrite.meta_estimator_transfer_param_maps_to_java(
+                estimator, self.getEstimatorParamMaps())
+        else:
+            raise ValueError('Unsupported estimator used in tuning: ' + str(estimator))
 
         java_estimator = self.getEstimator()._to_java()
         java_evaluator = self.getEvaluator()._to_java()
         return java_estimator, java_epms, java_evaluator
 
 
+class _ValidatorSharedReadWrite:
+    @staticmethod
+    def meta_estimator_transfer_param_maps_to_java(pyEstimator, pyParamMaps):
+        pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator)
+        stagePairs = list(map(lambda stage: (stage, stage._to_java()), pyStages))
+        sc = SparkContext._active_spark_context
+
+        paramMapCls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
+        javaParamMaps = SparkContext._gateway.new_array(paramMapCls, len(pyParamMaps))
+
+        for idx, pyParamMap in enumerate(pyParamMaps):
+            javaParamMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
+            for pyParam, pyValue in pyParamMap.items():
+                javaParam = None
+                for pyStage, javaStage in stagePairs:
+                    if pyStage._testOwnParam(pyParam.parent, pyParam.name):
+                        javaParam = javaStage.getParam(pyParam.name)
+                        break
+                if javaParam is None:
+                    raise ValueError('Resolve param in estimatorParamMaps failed: ' + str(pyParam))
+                if isinstance(pyValue, Params) and hasattr(pyValue, "_to_java"):
+                    javaValue = pyValue._to_java()
+                else:
+                    javaValue = _py2java(sc, pyValue)
+                pair = javaParam.w(javaValue)
+                javaParamMap.put([pair])
+            javaParamMaps[idx] = javaParamMap
+        return javaParamMaps
+
+    @staticmethod
+    def meta_estimator_transfer_param_maps_from_java(pyEstimator, javaParamMaps):
+        pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator)
+        stagePairs = list(map(lambda stage: (stage, stage._to_java()), pyStages))
+        sc = SparkContext._active_spark_context
+        pyParamMaps = []
+        for javaParamMap in javaParamMaps:
+            pyParamMap = dict()
+            for javaPair in javaParamMap.toList():
+                javaParam = javaPair.param()
+                pyParam = None
+                for pyStage, javaStage in stagePairs:
+                    if pyStage._testOwnParam(javaParam.parent(), javaParam.name()):
+                        pyParam = pyStage.getParam(javaParam.name())
+                if pyParam is None:
+                    raise ValueError('Resolve param in estimatorParamMaps failed: ' +
+                                     javaParam.parent() + '.' + javaParam.name())
+                javaValue = javaPair.value()
+                if sc._jvm.Class.forName("org.apache.spark.ml.util.DefaultParamsWritable") \
+                        .isInstance(javaValue):
+                    pyValue = JavaParams._from_java(javaValue)
+                else:
+                    pyValue = _java2py(sc, javaValue)
+                pyParamMap[pyParam] = pyValue
+            pyParamMaps.append(pyParamMap)
+        return pyParamMaps
+
+
 class _CrossValidatorParams(_ValidatorParams):
     """
     Params for :py:class:`CrossValidator` and :py:class:`CrossValidatorModel`.
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index aac2b38..0f9d8e7 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -584,3 +584,41 @@ class HasTrainingSummary(object):
         no summary exists.
         """
         return (self._call_java("summary"))
+
+
+class MetaAlgorithmReadWrite:
+
+    @staticmethod
+    def isMetaEstimator(pyInstance):
+        from pyspark.ml import Estimator, Pipeline
+        from pyspark.ml.tuning import _ValidatorParams
+        from pyspark.ml.classification import OneVsRest
+        return isinstance(pyInstance, Pipeline) or isinstance(pyInstance, OneVsRest) or \
+            (isinstance(pyInstance, Estimator) and isinstance(pyInstance, _ValidatorParams))
+
+    @staticmethod
+    def getAllNestedStages(pyInstance):
+        from pyspark.ml import Pipeline, PipelineModel
+        from pyspark.ml.tuning import _ValidatorParams
+        from pyspark.ml.classification import OneVsRest, OneVsRestModel
+
+        # TODO: We need to handle `RFormulaModel.pipelineModel` here after Pyspark RFormulaModel
+        #  support pipelineModel property.
+        if isinstance(pyInstance, Pipeline):
+            pySubStages = pyInstance.getStages()
+        elif isinstance(pyInstance, PipelineModel):
+            pySubStages = pyInstance.stages
+        elif isinstance(pyInstance, _ValidatorParams):
+            raise ValueError('PySpark does not support nested validator.')
+        elif isinstance(pyInstance, OneVsRest):
+            pySubStages = [pyInstance.getClassifier()]
+        elif isinstance(pyInstance, OneVsRestModel):
+            pySubStages = [pyInstance.getClassifier()] + pyInstance.models
+        else:
+            pySubStages = []
+
+        nestedStages = []
+        for pySubStage in pySubStages:
+            nestedStages.extend(MetaAlgorithmReadWrite.getAllNestedStages(pySubStage))
+
+        return [pyInstance] + nestedStages


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org