Posted to commits@spark.apache.org by we...@apache.org on 2017/07/11 14:03:19 UTC

spark git commit: [SPARK-21365][PYTHON] Deduplicate logics parsing DDL type/schema definition

Repository: spark
Updated Branches:
  refs/heads/master 66d216865 -> ebc124d4c


[SPARK-21365][PYTHON] Deduplicate logics parsing DDL type/schema definition

## What changes were proposed in this pull request?

This PR addresses the following four points:

- Reuse existing DDL parser APIs rather than reimplementing within PySpark

- Support DDL-formatted strings, e.g. `field type, field type`.

- Support case-insensitive parsing (see the extra example after the **After** snippets below).

- Support nested data types as below:

  **Before**
  ```
  >>> spark.createDataFrame([[[1]]], "struct<a: struct<b: int>>").show()
  ...
  ValueError: The strcut field string format is: 'field_name:field_type', but got: a: struct<b: int>
  ```

  ```
  >>> spark.createDataFrame([[[1]]], "a: struct<b: int>").show()
  ...
  ValueError: The strcut field string format is: 'field_name:field_type', but got: a: struct<b: int>
  ```

  ```
  >>> spark.createDataFrame([[1]], "a int").show()
  ...
  ValueError: Could not parse datatype: a int
  ```

  **After**
  ```
  >>> spark.createDataFrame([[[1]]], "struct<a: struct<b: int>>").show()
  +---+
  |  a|
  +---+
  |[1]|
  +---+
  ```

  ```
  >>> spark.createDataFrame([[[1]]], "a: struct<b: int>").show()
  +---+
  |  a|
  +---+
  |[1]|
  +---+
  ```

  ```
  >>> spark.createDataFrame([[1]], "a int").show()
  +---+
  |  a|
  +---+
  |  1|
  +---+
  ```
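
  With case-insensitive parsing, upper-case type names now work as well (an extra illustration, not part of the original example set):
  ```
  >>> spark.createDataFrame([[1]], "a INT").show()
  +---+
  |  a|
  +---+
  |  1|
  +---+
  ```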

## How was this patch tested?

Doctests updated in `python/pyspark/sql/types.py` and a new unit test, `test_parse_datatype_string`, added in `python/pyspark/sql/tests.py` (see the diff below).

Author: hyukjinkwon <gu...@gmail.com>

Closes #18590 from HyukjinKwon/deduplicate-python-ddl.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ebc124d4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ebc124d4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ebc124d4

Branch: refs/heads/master
Commit: ebc124d4c44d4c84f7868f390f778c0ff5cd66cb
Parents: 66d2168
Author: hyukjinkwon <gu...@gmail.com>
Authored: Tue Jul 11 22:03:10 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Jul 11 22:03:10 2017 +0800

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 | 16 +++-
 python/pyspark/sql/tests.py                     | 25 ++++++
 python/pyspark/sql/types.py                     | 88 ++++++++------------
 .../spark/sql/api/python/PythonSQLUtils.scala   | 25 ++++++
 4 files changed, 97 insertions(+), 57 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ebc124d4/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f203d85..d45ff63 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2037,16 +2037,26 @@ class UserDefinedFunction(object):
                 "{0}".format(type(func)))
 
         self.func = func
-        self.returnType = (
-            returnType if isinstance(returnType, DataType)
-            else _parse_datatype_string(returnType))
+        self._returnType = returnType
         # Stores UserDefinedPythonFunctions jobj, once initialized
+        self._returnType_placeholder = None
         self._judf_placeholder = None
         self._name = name or (
             func.__name__ if hasattr(func, '__name__')
             else func.__class__.__name__)
 
     @property
+    def returnType(self):
+        # This makes sure this is called after SparkContext is initialized.
+        # ``_parse_datatype_string`` accesses the JVM to parse a DDL-formatted string.
+        if self._returnType_placeholder is None:
+            if isinstance(self._returnType, DataType):
+                self._returnType_placeholder = self._returnType
+            else:
+                self._returnType_placeholder = _parse_datatype_string(self._returnType)
+        return self._returnType_placeholder
+
+    @property
     def _judf(self):
         # It is possible that concurrent access, to newly created UDF,
         # will initialize multiple UserDefinedPythonFunctions.
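
The hunk above makes return-type parsing lazy, so a UDF can be declared before the SparkContext exists; the DDL string is only parsed, on the JVM, when `returnType` is first accessed. A minimal usage sketch (illustrative, assuming an active SparkSession `spark`):

```python
from pyspark.sql.functions import udf

# The return type may be a DDL-formatted string; it is parsed lazily, on
# first access of ``returnType``, once an active SparkContext is available.
plus_one = udf(lambda x: x + 1, "int")  # equivalent to passing IntegerType()

df = spark.createDataFrame([[1]], "a int")
df.select(plus_one(df["a"])).show()  # "int" is parsed when the UDF is first used
```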

http://git-wip-us.apache.org/repos/asf/spark/blob/ebc124d4/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index bd8477e..29e48a6 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1255,6 +1255,31 @@ class SQLTests(ReusedPySparkTestCase):
         with self.assertRaises(TypeError):
             not_a_field = struct1[9.9]
 
+    def test_parse_datatype_string(self):
+        from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
+        for k, t in _all_atomic_types.items():
+            if t != NullType:
+                self.assertEqual(t(), _parse_datatype_string(k))
+        self.assertEqual(IntegerType(), _parse_datatype_string("int"))
+        self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1  ,1)"))
+        self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )"))
+        self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)"))
+        self.assertEqual(
+            ArrayType(IntegerType()),
+            _parse_datatype_string("array<int >"))
+        self.assertEqual(
+            MapType(IntegerType(), DoubleType()),
+            _parse_datatype_string("map< int, double  >"))
+        self.assertEqual(
+            StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
+            _parse_datatype_string("struct<a:int, c:double >"))
+        self.assertEqual(
+            StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
+            _parse_datatype_string("a:int, c:double"))
+        self.assertEqual(
+            StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
+            _parse_datatype_string("a INT, c DOUBLE"))
+
     def test_metadata_null(self):
         from pyspark.sql.types import StructType, StringType, StructField
         schema = StructType([StructField("f1", StringType(), True, None),
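
The new test also pins down that several spellings parse to the same schema; an end-to-end sketch (illustrative, assuming an active SparkSession `spark`):

```python
# The three accepted schema spellings all yield the same StructType.
df1 = spark.createDataFrame([(1, 2.0)], "a INT, b DOUBLE")            # DDL format
df2 = spark.createDataFrame([(1, 2.0)], "a: int, b: double")          # legacy "name: type"
df3 = spark.createDataFrame([(1, 2.0)], "struct<a: int, b: double>")  # explicit struct<>
assert df1.schema == df2.schema == df3.schema
```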

http://git-wip-us.apache.org/repos/asf/spark/blob/ebc124d4/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index f5505ed..22fa273 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -32,6 +32,7 @@ if sys.version >= "3":
 from py4j.protocol import register_input_converter
 from py4j.java_gateway import JavaClass
 
+from pyspark import SparkContext
 from pyspark.serializers import CloudPickleSerializer
 
 __all__ = [
@@ -727,18 +728,6 @@ _FIXED_DECIMAL = re.compile("decimal\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)")
 _BRACKETS = {'(': ')', '[': ']', '{': '}'}
 
 
-def _parse_basic_datatype_string(s):
-    if s in _all_atomic_types.keys():
-        return _all_atomic_types[s]()
-    elif s == "int":
-        return IntegerType()
-    elif _FIXED_DECIMAL.match(s):
-        m = _FIXED_DECIMAL.match(s)
-        return DecimalType(int(m.group(1)), int(m.group(2)))
-    else:
-        raise ValueError("Could not parse datatype: %s" % s)
-
-
 def _ignore_brackets_split(s, separator):
     """
     Splits the given string by given separator, but ignore separators inside brackets pairs, e.g.
@@ -771,32 +760,23 @@ def _ignore_brackets_split(s, separator):
     return parts
 
 
-def _parse_struct_fields_string(s):
-    parts = _ignore_brackets_split(s, ",")
-    fields = []
-    for part in parts:
-        name_and_type = _ignore_brackets_split(part, ":")
-        if len(name_and_type) != 2:
-            raise ValueError("The strcut field string format is: 'field_name:field_type', " +
-                             "but got: %s" % part)
-        field_name = name_and_type[0].strip()
-        field_type = _parse_datatype_string(name_and_type[1])
-        fields.append(StructField(field_name, field_type))
-    return StructType(fields)
-
-
 def _parse_datatype_string(s):
     """
     Parses the given data type string to a :class:`DataType`. The data type string format equals
     to :class:`DataType.simpleString`, except that top level struct type can omit
     the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use ``byte`` instead
     of ``tinyint`` for :class:`ByteType`. We can also use ``int`` as a short name
-    for :class:`IntegerType`.
+    for :class:`IntegerType`. Since Spark 2.3, this also supports a schema in a DDL-formatted
+    string and case-insensitive strings.
 
     >>> _parse_datatype_string("int ")
     IntegerType
+    >>> _parse_datatype_string("INT ")
+    IntegerType
     >>> _parse_datatype_string("a: byte, b: decimal(  16 , 8   ) ")
     StructType(List(StructField(a,ByteType,true),StructField(b,DecimalType(16,8),true)))
+    >>> _parse_datatype_string("a DOUBLE, b STRING")
+    StructType(List(StructField(a,DoubleType,true),StructField(b,StringType,true)))
     >>> _parse_datatype_string("a: array< short>")
     StructType(List(StructField(a,ArrayType(ShortType,true),true)))
     >>> _parse_datatype_string(" map<string , string > ")
@@ -806,43 +786,43 @@ def _parse_datatype_string(s):
     >>> _parse_datatype_string("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
-    ValueError:...
+    ParseException:...
     >>> _parse_datatype_string("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
-    ValueError:...
+    ParseException:...
     >>> _parse_datatype_string("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
-    ValueError:...
+    ParseException:...
     >>> _parse_datatype_string("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
-    ValueError:...
+    ParseException:...
     """
-    s = s.strip()
-    if s.startswith("array<"):
-        if s[-1] != ">":
-            raise ValueError("'>' should be the last char, but got: %s" % s)
-        return ArrayType(_parse_datatype_string(s[6:-1]))
-    elif s.startswith("map<"):
-        if s[-1] != ">":
-            raise ValueError("'>' should be the last char, but got: %s" % s)
-        parts = _ignore_brackets_split(s[4:-1], ",")
-        if len(parts) != 2:
-            raise ValueError("The map type string format is: 'map<key_type,value_type>', " +
-                             "but got: %s" % s)
-        kt = _parse_datatype_string(parts[0])
-        vt = _parse_datatype_string(parts[1])
-        return MapType(kt, vt)
-    elif s.startswith("struct<"):
-        if s[-1] != ">":
-            raise ValueError("'>' should be the last char, but got: %s" % s)
-        return _parse_struct_fields_string(s[7:-1])
-    elif ":" in s:
-        return _parse_struct_fields_string(s)
-    else:
-        return _parse_basic_datatype_string(s)
+    sc = SparkContext._active_spark_context
+
+    def from_ddl_schema(type_str):
+        return _parse_datatype_json_string(
+            sc._jvm.org.apache.spark.sql.types.StructType.fromDDL(type_str).json())
+
+    def from_ddl_datatype(type_str):
+        return _parse_datatype_json_string(
+            sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str).json())
+
+    try:
+        # DDL format, "fieldname datatype, fieldname datatype".
+        return from_ddl_schema(s)
+    except Exception as e:
+        try:
+            # For backwards compatibility, e.g. "integer", "struct<fieldname: datatype>", etc.
+            return from_ddl_datatype(s)
+        except:
+            try:
+                # For backwards compatibility, "fieldname: datatype, fieldname: datatype" case.
+                return from_ddl_datatype("struct<%s>" % s.strip())
+            except:
+                raise e
 
 
 def _parse_datatype_json_string(json_string):
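
The rewritten function tries three interpretations in order and, if all fail, re-raises the error from the first attempt. A standalone sketch of that fallback pattern (the `from_ddl_*` arguments are stand-ins for the JVM-backed helpers defined above):

```python
def parse_type_string(s, from_ddl_schema, from_ddl_datatype):
    """Illustrative fallback chain: the documented DDL format first, legacy formats after."""
    try:
        # DDL schema format: "fieldname datatype, fieldname datatype".
        return from_ddl_schema(s)
    except Exception as e:
        try:
            # A single data type, e.g. "integer" or "struct<fieldname: datatype>".
            return from_ddl_datatype(s)
        except Exception:
            try:
                # Legacy format: "fieldname: datatype, fieldname: datatype".
                return from_ddl_datatype("struct<%s>" % s.strip())
            except Exception:
                raise e  # surface the error from the primary, documented format
```

Re-raising `e` rather than the last failure keeps the error message aligned with the primary DDL format, which is what the updated doctests above expect (`ParseException`).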

http://git-wip-us.apache.org/repos/asf/spark/blob/ebc124d4/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
new file mode 100644
index 0000000..731feb9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.python
+
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.types.DataType
+
+private[sql] object PythonSQLUtils {
+  def parseDataType(typeText: String): DataType = CatalystSqlParser.parseDataType(typeText)
+}
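
On the Python side this helper is reached over Py4J; a minimal sketch of the round trip (assuming an active SparkContext, mirroring `from_ddl_datatype` above):

```python
from pyspark import SparkContext
from pyspark.sql.types import _parse_datatype_json_string

sc = SparkContext._active_spark_context
# Parse the type text on the JVM, then rebuild the Python DataType from its JSON form.
jdt = sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType("array<int>")
dt = _parse_datatype_json_string(jdt.json())  # -> ArrayType(IntegerType(), True)
```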

