You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ho...@apache.org on 2017/01/31 23:42:48 UTC

spark git commit: [SPARK-17161][PYSPARK][ML] Add PySpark-ML JavaWrapper convenience function to create Py4J JavaArrays

Repository: spark
Updated Branches:
  refs/heads/master ce112cec4 -> 57d70d26c


[SPARK-17161][PYSPARK][ML] Add PySpark-ML JavaWrapper convenience function to create Py4J JavaArrays

## What changes were proposed in this pull request?

Adds a convenience function to the Python `JavaWrapper` that makes it easy to create a Py4J JavaArray compatible with existing class constructors that take a Scala `Array` as input, so that a separate Java/Python-friendly constructor is no longer necessary. The function takes a Java class as input, which Py4J uses to create a Java array of that class. As an example, `OneVsRest` has been updated to use this function, and its alternate constructor has been removed.

## How was this patch tested?

Added unit tests for the new convenience function and updated the `OneVsRest` doctests to save and reload the model, which exercises the new function when the model is persisted.

Author: Bryan Cutler <cu...@gmail.com>

Closes #14725 from BryanCutler/pyspark-new_java_array-CountVectorizer-SPARK-17161.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/57d70d26
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/57d70d26
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/57d70d26

Branch: refs/heads/master
Commit: 57d70d26c88819360cdc806e7124aa2cc1b9e4c5
Parents: ce112ce
Author: Bryan Cutler <cu...@gmail.com>
Authored: Tue Jan 31 15:42:36 2017 -0800
Committer: Holden Karau <ho...@us.ibm.com>
Committed: Tue Jan 31 15:42:36 2017 -0800

----------------------------------------------------------------------
 .../spark/ml/classification/OneVsRest.scala     |  5 ---
 project/MimaExcludes.scala                      |  5 ++-
 python/pyspark/ml/classification.py             | 11 +++++-
 python/pyspark/ml/tests.py                      | 40 +++++++++++++++++++-
 python/pyspark/ml/wrapper.py                    | 29 ++++++++++++++
 5 files changed, 81 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/57d70d26/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index cbd508a..7cbcccf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -135,11 +135,6 @@ final class OneVsRestModel private[ml] (
     @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
   extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
 
-  /** A Python-friendly auxiliary constructor. */
-  private[ml] def this(uid: String, models: JList[_ <: ClassificationModel[_, _]]) = {
-    this(uid, Metadata.empty, models.asScala.toArray)
-  }
-
   /** @group setParam */
   @Since("2.1.0")
   def setFeaturesCol(value: String): this.type = set(featuresCol, value)

http://git-wip-us.apache.org/repos/asf/spark/blob/57d70d26/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 7e6e143..9d35942 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -54,7 +54,10 @@ object MimaExcludes {
     // [SPARK-19069] [CORE] Expose task 'status' and 'duration' in spark history server REST API.
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.this"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$10"),
-    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$11")
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$11"),
+
+    // [SPARK-17161] Removing Python-friendly constructors not needed
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this")
   )
 
   // Exclude rules for 2.1.x

http://git-wip-us.apache.org/repos/asf/spark/blob/57d70d26/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index f10556c..d41fc81 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -1517,6 +1517,11 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
     >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4))]).toDF()
     >>> model.transform(test2).head().prediction
     2.0
+    >>> model_path = temp_path + "/ovr_model"
+    >>> model.save(model_path)
+    >>> model2 = OneVsRestModel.load(model_path)
+    >>> model2.transform(test0).head().prediction
+    1.0
 
     .. versionadded:: 2.0.0
     """
@@ -1759,9 +1764,13 @@ class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable):
 
         :return: Java object equivalent to this instance.
         """
+        sc = SparkContext._active_spark_context
         java_models = [model._to_java() for model in self.models]
+        java_models_array = JavaWrapper._new_java_array(
+            java_models, sc._gateway.jvm.org.apache.spark.ml.classification.ClassificationModel)
+        metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata")
         _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel",
-                                             self.uid, java_models)
+                                             self.uid, metadata.empty(), java_models_array)
         _java_obj.set("classifier", self.getClassifier()._to_java())
         _java_obj.set("featuresCol", self.getFeaturesCol())
         _java_obj.set("labelCol", self.getLabelCol())

http://git-wip-us.apache.org/repos/asf/spark/blob/57d70d26/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 68f5bc3..53204cd 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -60,8 +60,8 @@ from pyspark.ml.recommendation import ALS
 from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, \
     GeneralizedLinearRegression
 from pyspark.ml.tuning import *
-from pyspark.ml.wrapper import JavaParams
-from pyspark.ml.common import _java2py
+from pyspark.ml.wrapper import JavaParams, JavaWrapper
+from pyspark.ml.common import _java2py, _py2java
 from pyspark.serializers import PickleSerializer
 from pyspark.sql import DataFrame, Row, SparkSession
 from pyspark.sql.functions import rand
@@ -1620,6 +1620,42 @@ class MatrixUDTTests(MLlibTestCase):
                 raise ValueError("Expected a matrix but got type %r" % type(m))
 
 
+class WrapperTests(MLlibTestCase):
+
+    def test_new_java_array(self):
+        # test array of strings
+        str_list = ["a", "b", "c"]
+        java_class = self.sc._gateway.jvm.java.lang.String
+        java_array = JavaWrapper._new_java_array(str_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), str_list)
+        # test array of integers
+        int_list = [1, 2, 3]
+        java_class = self.sc._gateway.jvm.java.lang.Integer
+        java_array = JavaWrapper._new_java_array(int_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), int_list)
+        # test array of floats
+        float_list = [0.1, 0.2, 0.3]
+        java_class = self.sc._gateway.jvm.java.lang.Double
+        java_array = JavaWrapper._new_java_array(float_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), float_list)
+        # test array of bools
+        bool_list = [False, True, True]
+        java_class = self.sc._gateway.jvm.java.lang.Boolean
+        java_array = JavaWrapper._new_java_array(bool_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), bool_list)
+        # test array of Java DenseVectors
+        v1 = DenseVector([0.0, 1.0])
+        v2 = DenseVector([1.0, 0.0])
+        vec_java_list = [_py2java(self.sc, v1), _py2java(self.sc, v2)]
+        java_class = self.sc._gateway.jvm.org.apache.spark.ml.linalg.DenseVector
+        java_array = JavaWrapper._new_java_array(vec_java_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), [v1, v2])
+        # test empty array
+        java_class = self.sc._gateway.jvm.java.lang.Integer
+        java_array = JavaWrapper._new_java_array([], java_class)
+        self.assertEqual(_java2py(self.sc, java_array), [])
+
+
 if __name__ == "__main__":
     from pyspark.ml.tests import *
     if xmlrunner:

http://git-wip-us.apache.org/repos/asf/spark/blob/57d70d26/python/pyspark/ml/wrapper.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 13b75e9..80a0b31 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -16,6 +16,9 @@
 #
 
 from abc import ABCMeta, abstractmethod
+import sys
+if sys.version >= '3':
+    xrange = range
 
 from pyspark import SparkContext
 from pyspark.sql import DataFrame
@@ -59,6 +62,32 @@ class JavaWrapper(object):
         java_args = [_py2java(sc, arg) for arg in args]
         return java_obj(*java_args)
 
+    @staticmethod
+    def _new_java_array(pylist, java_class):
+        """
+        Create a Java array of given java_class type. Useful for
+        calling a method with a Scala Array from Python with Py4J.
+
+        :param pylist:
+          Python list to convert to a Java Array.
+        :param java_class:
+          Java class to specify the type of Array. Should be in the
+          form of sc._gateway.jvm.* (sc is a valid Spark Context).
+        :return:
+          Java Array of converted pylist.
+
+        Example primitive Java classes:
+          - basestring -> sc._gateway.jvm.java.lang.String
+          - int -> sc._gateway.jvm.java.lang.Integer
+          - float -> sc._gateway.jvm.java.lang.Double
+          - bool -> sc._gateway.jvm.java.lang.Boolean
+        """
+        sc = SparkContext._active_spark_context
+        java_array = sc._gateway.new_array(java_class, len(pylist))
+        for i in xrange(len(pylist)):
+            java_array[i] = pylist[i]
+        return java_array
+
 
 @inherit_doc
 class JavaParams(JavaWrapper, Params):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org