You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2020/04/20 09:49:00 UTC
[spark] 01/02: init
This is an automated email from the ASF dual-hosted git repository.
weichenxu123 pushed a commit to branch fix_pipeline_tuning
in repository https://gitbox.apache.org/repos/asf/spark.git
commit c834fe8f335dc74db6346d82b5ce4cf742cba9bb
Author: Weichen Xu <we...@databricks.com>
AuthorDate: Mon Apr 20 17:04:12 2020 +0800
init
---
python/pyspark/ml/pipeline.py | 46 +++++++++++++++++++++++++++++++++++++++++--
1 file changed, 44 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 09e0748..0004b64 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -25,8 +25,8 @@ from pyspark import since, keyword_only, SparkContext
from pyspark.ml.base import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaParams
-from pyspark.ml.common import inherit_doc
+from pyspark.ml.wrapper import JavaParams, JavaWrapper
+from pyspark.ml.common import inherit_doc, _java2py, _py2java
@inherit_doc
@@ -174,6 +174,48 @@ class Pipeline(Estimator, MLReadable, MLWritable):
return _java_obj
+ def _make_java_param_pair(self, param, value):
+ """
+ Makes a Java param pair.
+ """
+ sc = SparkContext._active_spark_context
+ param = self._resolveParam(param)
+ java_param = sc._jvm.org.apache.spark.ml.param.Param(param.parent, param.name, param.doc)
+ if isinstance(value, JavaParams):
+ # used in the case of an estimator having another estimator as a parameter
+ # the reason why this is not in _py2java in common.py is that importing
+ # Estimator and Model in common.py results in a circular import with inherit_doc
+ java_value = value._to_java()
+ else:
+ java_value = _py2java(sc, value)
+ return java_param.w(java_value)
+
+ def _transfer_param_map_to_java(self, pyParamMap):
+ """
+ Transforms a Python ParamMap into a Java ParamMap.
+ """
+ paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
+ for param in self.params:
+ if param in pyParamMap:
+ pair = self._make_java_param_pair(param, pyParamMap[param])
+ paramMap.put([pair])
+ return paramMap
+
+ def _transfer_param_map_from_java(self, javaParamMap):
+ """
+ Transforms a Java ParamMap into a Python ParamMap.
+ """
+ sc = SparkContext._active_spark_context
+ paramMap = dict()
+ for pair in javaParamMap.toList():
+ param = pair.param()
+ if self.hasParam(str(param.name())):
+ if param.name() == "classifier":
+ paramMap[self.getParam(param.name())] = JavaParams._from_java(pair.value())
+ else:
+ paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
+ return paramMap
+
@inherit_doc
class PipelineWriter(MLWriter):
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org