You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ho...@apache.org on 2017/09/08 18:57:40 UTC

spark git commit: [SPARK-15243][ML][SQL][PYTHON] Add missing support for unicode in Param methods & functions in dataframe

Repository: spark
Updated Branches:
  refs/heads/master 8a4f228dc -> 8598d03a0


[SPARK-15243][ML][SQL][PYTHON] Add missing support for unicode in Param methods & functions in dataframe

## What changes were proposed in this pull request?

This PR proposes to support unicode strings in the Param methods in ML and in other DataFrame functions where that support was missing.

For example, this causes a `TypeError` in Python 2.x when param is a unicode string:

```python
>>> from pyspark.ml.classification import LogisticRegression
>>> lr = LogisticRegression()
>>> lr.hasParam("threshold")
True
>>> lr.hasParam(u"threshold")
Traceback (most recent call last):
 ...
    raise TypeError("hasParam(): paramName must be a string")
TypeError: hasParam(): paramName must be a string
```

This PR is based on https://github.com/apache/spark/pull/13036

## How was this patch tested?

Unit tests in `python/pyspark/ml/tests.py` and `python/pyspark/sql/tests.py`.

Author: hyukjinkwon <gu...@gmail.com>
Author: sethah <se...@gmail.com>

Closes #17096 from HyukjinKwon/SPARK-15243.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8598d03a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8598d03a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8598d03a

Branch: refs/heads/master
Commit: 8598d03a00a39dd23646bf752f9fed5d28e271c6
Parents: 8a4f228
Author: hyukjinkwon <gu...@gmail.com>
Authored: Fri Sep 8 11:57:33 2017 -0700
Committer: Holden Karau <ho...@us.ibm.com>
Committed: Fri Sep 8 11:57:33 2017 -0700

----------------------------------------------------------------------
 python/pyspark/ml/param/__init__.py |  4 ++--
 python/pyspark/ml/tests.py          | 15 +++++++++++++++
 python/pyspark/sql/dataframe.py     | 22 +++++++++++-----------
 python/pyspark/sql/tests.py         | 22 ++++++++++++++--------
 4 files changed, 42 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8598d03a/python/pyspark/ml/param/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 1334207..043c25c 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -330,7 +330,7 @@ class Params(Identifiable):
         Tests whether this instance contains a param with a given
         (string) name.
         """
-        if isinstance(paramName, str):
+        if isinstance(paramName, basestring):
             p = getattr(self, paramName, None)
             return isinstance(p, Param)
         else:
@@ -413,7 +413,7 @@ class Params(Identifiable):
         if isinstance(param, Param):
             self._shouldOwn(param)
             return param
-        elif isinstance(param, str):
+        elif isinstance(param, basestring):
             return self.getParam(param)
         else:
             raise ValueError("Cannot resolve %r as a param." % param)

http://git-wip-us.apache.org/repos/asf/spark/blob/8598d03a/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 6076b3c..509698f 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 #
 # Licensed to the Apache Software Foundation (ASF) under one or more
 # contributor license agreements.  See the NOTICE file distributed with
@@ -352,6 +353,20 @@ class ParamTests(PySparkTestCase):
         testParams = TestParams()
         self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params]))
         self.assertFalse(testParams.hasParam("notAParameter"))
+        self.assertTrue(testParams.hasParam(u"maxIter"))
+
+    def test_resolveparam(self):
+        testParams = TestParams()
+        self.assertEqual(testParams._resolveParam(testParams.maxIter), testParams.maxIter)
+        self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter)
+
+        self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter)
+        if sys.version_info[0] >= 3:
+            # In Python 3, it is allowed to get/set attributes with non-ascii characters.
+            e_cls = AttributeError
+        else:
+            e_cls = UnicodeEncodeError
+        self.assertRaises(e_cls, lambda: testParams._resolveParam(u"아"))
 
     def test_params(self):
         testParams = TestParams()

http://git-wip-us.apache.org/repos/asf/spark/blob/8598d03a/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1cea130..8f88545 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -748,7 +748,7 @@ class DataFrame(object):
         +---+-----+
 
         """
-        if not isinstance(col, str):
+        if not isinstance(col, basestring):
             raise ValueError("col must be a string, but got %r" % type(col))
         if not isinstance(fractions, dict):
             raise ValueError("fractions must be a dict but got %r" % type(fractions))
@@ -1664,18 +1664,18 @@ class DataFrame(object):
            Added support for multiple columns.
         """
 
-        if not isinstance(col, (str, list, tuple)):
+        if not isinstance(col, (basestring, list, tuple)):
             raise ValueError("col should be a string, list or tuple, but got %r" % type(col))
 
-        isStr = isinstance(col, str)
+        isStr = isinstance(col, basestring)
 
         if isinstance(col, tuple):
             col = list(col)
-        elif isinstance(col, str):
+        elif isStr:
             col = [col]
 
         for c in col:
-            if not isinstance(c, str):
+            if not isinstance(c, basestring):
                 raise ValueError("columns should be strings, but got %r" % type(c))
         col = _to_list(self._sc, col)
 
@@ -1707,9 +1707,9 @@ class DataFrame(object):
         :param col2: The name of the second column
         :param method: The correlation method. Currently only supports "pearson"
         """
-        if not isinstance(col1, str):
+        if not isinstance(col1, basestring):
             raise ValueError("col1 should be a string.")
-        if not isinstance(col2, str):
+        if not isinstance(col2, basestring):
             raise ValueError("col2 should be a string.")
         if not method:
             method = "pearson"
@@ -1727,9 +1727,9 @@ class DataFrame(object):
         :param col1: The name of the first column
         :param col2: The name of the second column
         """
-        if not isinstance(col1, str):
+        if not isinstance(col1, basestring):
             raise ValueError("col1 should be a string.")
-        if not isinstance(col2, str):
+        if not isinstance(col2, basestring):
             raise ValueError("col2 should be a string.")
         return self._jdf.stat().cov(col1, col2)
 
@@ -1749,9 +1749,9 @@ class DataFrame(object):
         :param col2: The name of the second column. Distinct items will make the column names
             of the DataFrame.
         """
-        if not isinstance(col1, str):
+        if not isinstance(col1, basestring):
             raise ValueError("col1 should be a string.")
-        if not isinstance(col2, str):
+        if not isinstance(col2, basestring):
             raise ValueError("col2 should be a string.")
         return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8598d03a/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1bc889c..4d65abc 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1140,11 +1140,12 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_approxQuantile(self):
         df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF()
-        aq = df.stat.approxQuantile("a", [0.1, 0.5, 0.9], 0.1)
-        self.assertTrue(isinstance(aq, list))
-        self.assertEqual(len(aq), 3)
+        for f in ["a", u"a"]:
+            aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1)
+            self.assertTrue(isinstance(aq, list))
+            self.assertEqual(len(aq), 3)
         self.assertTrue(all(isinstance(q, float) for q in aq))
-        aqs = df.stat.approxQuantile(["a", "b"], [0.1, 0.5, 0.9], 0.1)
+        aqs = df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1)
         self.assertTrue(isinstance(aqs, list))
         self.assertEqual(len(aqs), 2)
         self.assertTrue(isinstance(aqs[0], list))
@@ -1153,7 +1154,7 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertTrue(isinstance(aqs[1], list))
         self.assertEqual(len(aqs[1]), 3)
         self.assertTrue(all(isinstance(q, float) for q in aqs[1]))
-        aqt = df.stat.approxQuantile(("a", "b"), [0.1, 0.5, 0.9], 0.1)
+        aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1)
         self.assertTrue(isinstance(aqt, list))
         self.assertEqual(len(aqt), 2)
         self.assertTrue(isinstance(aqt[0], list))
@@ -1169,17 +1170,22 @@ class SQLTests(ReusedPySparkTestCase):
     def test_corr(self):
         import math
         df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
-        corr = df.stat.corr("a", "b")
+        corr = df.stat.corr(u"a", "b")
         self.assertTrue(abs(corr - 0.95734012) < 1e-6)
 
+    def test_sampleby(self):
+        df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(10)]).toDF()
+        sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0)
+        self.assertTrue(sampled.count() == 3)
+
     def test_cov(self):
         df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
-        cov = df.stat.cov("a", "b")
+        cov = df.stat.cov(u"a", "b")
         self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
 
     def test_crosstab(self):
         df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
-        ct = df.stat.crosstab("a", "b").collect()
+        ct = df.stat.crosstab(u"a", "b").collect()
         ct = sorted(ct, key=lambda x: x[0])
         for i, row in enumerate(ct):
             self.assertEqual(row[0], str(i))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org