Posted to commits@spark.apache.org by we...@apache.org on 2017/07/11 14:03:19 UTC
spark git commit: [SPARK-21365][PYTHON] Deduplicate logics parsing DDL type/schema definition
Repository: spark
Updated Branches:
refs/heads/master 66d216865 -> ebc124d4c
[SPARK-21365][PYTHON] Deduplicate logics parsing DDL type/schema definition
## What changes were proposed in this pull request?
This PR addresses four points:
- Reuse the existing DDL parser APIs rather than reimplementing them within PySpark.
- Support DDL-formatted strings such as `field type, field type`.
- Support case-insensitive parsing.
- Support nested data types, as shown below:
**Before**
```
>>> spark.createDataFrame([[[1]]], "struct<a: struct<b: int>>").show()
...
ValueError: The strcut field string format is: 'field_name:field_type', but got: a: struct<b: int>
```
```
>>> spark.createDataFrame([[[1]]], "a: struct<b: int>").show()
...
ValueError: The strcut field string format is: 'field_name:field_type', but got: a: struct<b: int>
```
```
>>> spark.createDataFrame([[1]], "a int").show()
...
ValueError: Could not parse datatype: a int
```
**After**
```
>>> spark.createDataFrame([[[1]]], "struct<a: struct<b: int>>").show()
+---+
| a|
+---+
|[1]|
+---+
```
```
>>> spark.createDataFrame([[[1]]], "a: struct<b: int>").show()
+---+
| a|
+---+
|[1]|
+---+
```
```
>>> spark.createDataFrame([[1]], "a int").show()
+---+
| a|
+---+
| 1|
+---+
```
## How was this patch tested?
Doctests were updated in `python/pyspark/sql/types.py` and a new unit test, `test_parse_datatype_string`, was added in `python/pyspark/sql/tests.py` (both visible in the diff below).
Author: hyukjinkwon <gu...@gmail.com>
Closes #18590 from HyukjinKwon/deduplicate-python-ddl.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ebc124d4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ebc124d4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ebc124d4
Branch: refs/heads/master
Commit: ebc124d4c44d4c84f7868f390f778c0ff5cd66cb
Parents: 66d2168
Author: hyukjinkwon <gu...@gmail.com>
Authored: Tue Jul 11 22:03:10 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Jul 11 22:03:10 2017 +0800
----------------------------------------------------------------------
python/pyspark/sql/functions.py | 16 +++-
python/pyspark/sql/tests.py | 25 ++++++
python/pyspark/sql/types.py | 88 ++++++++------------
.../spark/sql/api/python/PythonSQLUtils.scala | 25 ++++++
4 files changed, 97 insertions(+), 57 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/ebc124d4/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f203d85..d45ff63 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2037,16 +2037,26 @@ class UserDefinedFunction(object):
"{0}".format(type(func)))
self.func = func
- self.returnType = (
- returnType if isinstance(returnType, DataType)
- else _parse_datatype_string(returnType))
+ self._returnType = returnType
# Stores UserDefinedPythonFunctions jobj, once initialized
+ self._returnType_placeholder = None
self._judf_placeholder = None
self._name = name or (
func.__name__ if hasattr(func, '__name__')
else func.__class__.__name__)
@property
+ def returnType(self):
+ # This makes sure this is called after SparkContext is initialized.
+ # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string.
+ if self._returnType_placeholder is None:
+ if isinstance(self._returnType, DataType):
+ self._returnType_placeholder = self._returnType
+ else:
+ self._returnType_placeholder = _parse_datatype_string(self._returnType)
+ return self._returnType_placeholder
+
+ @property
def _judf(self):
# It is possible that concurrent access, to newly created UDF,
# will initialize multiple UserDefinedPythonFunctions.
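The `returnType` property above makes parsing lazy so that a `UserDefinedFunction` can be constructed before the SparkContext (and hence the JVM) exists; the JVM round-trip happens only on first access and is cached. An illustrative, self-contained sketch of the same pattern (these names are hypothetical, not part of the patch):
```
class LazilyParsed(object):
    """Defers an expensive parse until the value is first needed."""

    def __init__(self, raw):
        self._raw = raw
        self._parsed_placeholder = None  # filled on first access

    @property
    def parsed(self):
        # Mirrors UserDefinedFunction.returnType: cache the result so the
        # costly call (there, a JVM round-trip) runs at most once.
        if self._parsed_placeholder is None:
            self._parsed_placeholder = self._raw.strip().lower()  # stand-in for the real parse
        return self._parsed_placeholder
```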
http://git-wip-us.apache.org/repos/asf/spark/blob/ebc124d4/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index bd8477e..29e48a6 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1255,6 +1255,31 @@ class SQLTests(ReusedPySparkTestCase):
with self.assertRaises(TypeError):
not_a_field = struct1[9.9]
+ def test_parse_datatype_string(self):
+ from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
+ for k, t in _all_atomic_types.items():
+ if t != NullType:
+ self.assertEqual(t(), _parse_datatype_string(k))
+ self.assertEqual(IntegerType(), _parse_datatype_string("int"))
+ self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1 ,1)"))
+ self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )"))
+ self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)"))
+ self.assertEqual(
+ ArrayType(IntegerType()),
+ _parse_datatype_string("array<int >"))
+ self.assertEqual(
+ MapType(IntegerType(), DoubleType()),
+ _parse_datatype_string("map< int, double >"))
+ self.assertEqual(
+ StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
+ _parse_datatype_string("struct<a:int, c:double >"))
+ self.assertEqual(
+ StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
+ _parse_datatype_string("a:int, c:double"))
+ self.assertEqual(
+ StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
+ _parse_datatype_string("a INT, c DOUBLE"))
+
def test_metadata_null(self):
from pyspark.sql.types import StructType, StringType, StructField
schema = StructType([StructField("f1", StringType(), True, None),
http://git-wip-us.apache.org/repos/asf/spark/blob/ebc124d4/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index f5505ed..22fa273 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -32,6 +32,7 @@ if sys.version >= "3":
from py4j.protocol import register_input_converter
from py4j.java_gateway import JavaClass
+from pyspark import SparkContext
from pyspark.serializers import CloudPickleSerializer
__all__ = [
@@ -727,18 +728,6 @@ _FIXED_DECIMAL = re.compile("decimal\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)")
_BRACKETS = {'(': ')', '[': ']', '{': '}'}
-def _parse_basic_datatype_string(s):
- if s in _all_atomic_types.keys():
- return _all_atomic_types[s]()
- elif s == "int":
- return IntegerType()
- elif _FIXED_DECIMAL.match(s):
- m = _FIXED_DECIMAL.match(s)
- return DecimalType(int(m.group(1)), int(m.group(2)))
- else:
- raise ValueError("Could not parse datatype: %s" % s)
-
-
def _ignore_brackets_split(s, separator):
"""
Splits the given string by given separator, but ignore separators inside brackets pairs, e.g.
@@ -771,32 +760,23 @@ def _ignore_brackets_split(s, separator):
return parts
-def _parse_struct_fields_string(s):
- parts = _ignore_brackets_split(s, ",")
- fields = []
- for part in parts:
- name_and_type = _ignore_brackets_split(part, ":")
- if len(name_and_type) != 2:
- raise ValueError("The strcut field string format is: 'field_name:field_type', " +
- "but got: %s" % part)
- field_name = name_and_type[0].strip()
- field_type = _parse_datatype_string(name_and_type[1])
- fields.append(StructField(field_name, field_type))
- return StructType(fields)
-
-
def _parse_datatype_string(s):
"""
Parses the given data type string to a :class:`DataType`. The data type string format equals
to :class:`DataType.simpleString`, except that top level struct type can omit
the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use ``byte`` instead
of ``tinyint`` for :class:`ByteType`. We can also use ``int`` as a short name
- for :class:`IntegerType`.
+ for :class:`IntegerType`. Since Spark 2.3, this also supports a schema in a DDL-formatted
+ string and case-insensitive strings.
>>> _parse_datatype_string("int ")
IntegerType
+ >>> _parse_datatype_string("INT ")
+ IntegerType
>>> _parse_datatype_string("a: byte, b: decimal( 16 , 8 ) ")
StructType(List(StructField(a,ByteType,true),StructField(b,DecimalType(16,8),true)))
+ >>> _parse_datatype_string("a DOUBLE, b STRING")
+ StructType(List(StructField(a,DoubleType,true),StructField(b,StringType,true)))
>>> _parse_datatype_string("a: array< short>")
StructType(List(StructField(a,ArrayType(ShortType,true),true)))
>>> _parse_datatype_string(" map<string , string > ")
@@ -806,43 +786,43 @@ def _parse_datatype_string(s):
>>> _parse_datatype_string("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
- ValueError:...
+ ParseException:...
>>> _parse_datatype_string("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
- ValueError:...
+ ParseException:...
>>> _parse_datatype_string("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
- ValueError:...
+ ParseException:...
>>> _parse_datatype_string("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
- ValueError:...
+ ParseException:...
"""
- s = s.strip()
- if s.startswith("array<"):
- if s[-1] != ">":
- raise ValueError("'>' should be the last char, but got: %s" % s)
- return ArrayType(_parse_datatype_string(s[6:-1]))
- elif s.startswith("map<"):
- if s[-1] != ">":
- raise ValueError("'>' should be the last char, but got: %s" % s)
- parts = _ignore_brackets_split(s[4:-1], ",")
- if len(parts) != 2:
- raise ValueError("The map type string format is: 'map<key_type,value_type>', " +
- "but got: %s" % s)
- kt = _parse_datatype_string(parts[0])
- vt = _parse_datatype_string(parts[1])
- return MapType(kt, vt)
- elif s.startswith("struct<"):
- if s[-1] != ">":
- raise ValueError("'>' should be the last char, but got: %s" % s)
- return _parse_struct_fields_string(s[7:-1])
- elif ":" in s:
- return _parse_struct_fields_string(s)
- else:
- return _parse_basic_datatype_string(s)
+ sc = SparkContext._active_spark_context
+
+ def from_ddl_schema(type_str):
+ return _parse_datatype_json_string(
+ sc._jvm.org.apache.spark.sql.types.StructType.fromDDL(type_str).json())
+
+ def from_ddl_datatype(type_str):
+ return _parse_datatype_json_string(
+ sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str).json())
+
+ try:
+ # DDL format, "fieldname datatype, fieldname datatype".
+ return from_ddl_schema(s)
+ except Exception as e:
+ try:
+ # For backwards compatibility, "integer", "struct<fieldname: datatype>" and etc.
+ return from_ddl_datatype(s)
+ except:
+ try:
+ # For backwards compatibility, "fieldname: datatype, fieldname: datatype" case.
+ return from_ddl_datatype("struct<%s>" % s.strip())
+ except:
+ raise e
def _parse_datatype_json_string(json_string):
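The cascade in the new `_parse_datatype_string` tries the three accepted formats in order and, notably, re-raises the error from the *first* attempt if all of them fail, since the DDL schema form is the primary format. A self-contained sketch of that fallback pattern (the parser names are hypothetical stand-ins):
```
def _parse_with_fallbacks(s, parsers):
    """Try each parser in order; if all fail, surface the first error."""
    first_error = None
    for parse in parsers:
        try:
            return parse(s)
        except Exception as e:
            if first_error is None:
                first_error = e
    raise first_error

# Usage with stand-ins for the three JVM-backed attempts in the patch:
# _parse_with_fallbacks(s, [from_ddl_schema, from_ddl_datatype,
#                           lambda t: from_ddl_datatype("struct<%s>" % t.strip())])
```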
http://git-wip-us.apache.org/repos/asf/spark/blob/ebc124d4/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
new file mode 100644
index 0000000..731feb9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.python
+
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.types.DataType
+
+private[sql] object PythonSQLUtils {
+ def parseDataType(typeText: String): DataType = CatalystSqlParser.parseDataType(typeText)
+}
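Although the object is `private[sql]`, that modifier exists only at the Scala level, so the compiled class remains reachable through the Py4J gateway. A sketch of the round-trip used in `types.py` above, assuming an active `SparkContext`:
```
from pyspark import SparkContext

sc = SparkContext._active_spark_context
jdt = sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType("array<int>")
print(jdt.json())  # the JSON form that _parse_datatype_json_string consumes
```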