You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2020/04/20 09:49:00 UTC

[spark] 01/02: init

This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch fix_pipeline_tuning
in repository https://gitbox.apache.org/repos/asf/spark.git

commit c834fe8f335dc74db6346d82b5ce4cf742cba9bb
Author: Weichen Xu <we...@databricks.com>
AuthorDate: Mon Apr 20 17:04:12 2020 +0800

    init
---
 python/pyspark/ml/pipeline.py | 46 +++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 44 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 09e0748..0004b64 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -25,8 +25,8 @@ from pyspark import since, keyword_only, SparkContext
 from pyspark.ml.base import Estimator, Model, Transformer
 from pyspark.ml.param import Param, Params
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaParams
-from pyspark.ml.common import inherit_doc
+from pyspark.ml.wrapper import JavaParams, JavaWrapper
+from pyspark.ml.common import inherit_doc, _java2py, _py2java
 
 
 @inherit_doc
@@ -174,6 +174,48 @@ class Pipeline(Estimator, MLReadable, MLWritable):
 
         return _java_obj
 
+    def _make_java_param_pair(self, param, value):
+        """
+        Makes a Java param pair.
+        """
+        sc = SparkContext._active_spark_context
+        param = self._resolveParam(param)
+        java_param = sc._jvm.org.apache.spark.ml.param.Param(param.parent, param.name, param.doc)
+        if isinstance(value, JavaParams):
+            # used in the case of an estimator having another estimator as a parameter
+            # the reason why this is not in _py2java in common.py is that importing
+            # Estimator and Model in common.py results in a circular import with inherit_doc
+            java_value = value._to_java()
+        else:
+            java_value = _py2java(sc, value)
+        return java_param.w(java_value)
+
+    def _transfer_param_map_to_java(self, pyParamMap):
+        """
+        Transforms a Python ParamMap into a Java ParamMap.
+        """
+        paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
+        for param in self.params:
+            if param in pyParamMap:
+                pair = self._make_java_param_pair(param, pyParamMap[param])
+                paramMap.put([pair])
+        return paramMap
+
+    def _transfer_param_map_from_java(self, javaParamMap):
+        """
+        Transforms a Java ParamMap into a Python ParamMap.
+        """
+        sc = SparkContext._active_spark_context
+        paramMap = dict()
+        for pair in javaParamMap.toList():
+            param = pair.param()
+            if self.hasParam(str(param.name())):
+                if param.name() == "classifier":
+                    paramMap[self.getParam(param.name())] = JavaParams._from_java(pair.value())
+                else:
+                    paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
+        return paramMap
+
 
 @inherit_doc
 class PipelineWriter(MLWriter):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org