You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by "ueshin (via GitHub)" <gi...@apache.org> on 2023/03/06 21:21:16 UTC

[GitHub] [spark] ueshin commented on a diff in pull request #40276: [SPARK-42630][CONNECT][PYTHON] Implement data type string parser

ueshin commented on code in PR #40276:
URL: https://github.com/apache/spark/pull/40276#discussion_r1127045146


##########
python/pyspark/sql/connect/types.py:
##########
@@ -342,20 +343,325 @@ def from_arrow_schema(arrow_schema: "pa.Schema") -> StructType:
 
 
 def parse_data_type(data_type: str) -> DataType:
-    # Currently we don't have a way to have a current Spark session in Spark Connect, and
-    # pyspark.sql.SparkSession has a centralized logic to control the session creation.
-    # So uses pyspark.sql.SparkSession for now. Should replace this to using the current
-    # Spark session for Spark Connect in the future.
-    from pyspark.sql import SparkSession as PySparkSession
-
-    assert is_remote()
-    return_type_schema = (
-        PySparkSession.builder.getOrCreate().createDataFrame(data=[], schema=data_type).schema
+    """
+    Parses the given data type string to a :class:`DataType`. The data type string format equals
+    :class:`DataType.simpleString`, except that the top level struct type can omit
+    the ``struct<>``. Since Spark 2.3, this also supports a schema in a DDL-formatted
+    string and case-insensitive strings.
+
+    Examples
+    --------
+    >>> parse_data_type("int ")
+    IntegerType()
+    >>> parse_data_type("INT ")
+    IntegerType()
+    >>> parse_data_type("a: byte, b: decimal(  16 , 8   ) ")
+    StructType([StructField('a', ByteType(), True), StructField('b', DecimalType(16,8), True)])
+    >>> parse_data_type("a DOUBLE, b STRING")
+    StructType([StructField('a', DoubleType(), True), StructField('b', StringType(), True)])
+    >>> parse_data_type("a DOUBLE, b CHAR( 50 )")
+    StructType([StructField('a', DoubleType(), True), StructField('b', CharType(50), True)])
+    >>> parse_data_type("a DOUBLE, b VARCHAR( 50 )")
+    StructType([StructField('a', DoubleType(), True), StructField('b', VarcharType(50), True)])
+    >>> parse_data_type("a: array< short>")
+    StructType([StructField('a', ArrayType(ShortType(), True), True)])
+    >>> parse_data_type(" map<string , string > ")
+    MapType(StringType(), StringType(), True)
+
+    >>> # Error cases
+    >>> parse_data_type("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    ParseException:...
+    >>> parse_data_type("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    ParseException:...
+    >>> parse_data_type("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    ParseException:...
+    >>> parse_data_type("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    ParseException:...
+    """
+    try:
+        # DDL format, "fieldname datatype, fieldname datatype".
+        return DDLSchemaParser(data_type).from_ddl_schema()
+    except ParseException as e:
+        try:
+            # For backwards compatibility, "integer", "struct<fieldname: datatype>" and etc.
+            return DDLDataTypeParser(data_type).from_ddl_datatype()
+        except ParseException:
+            try:
+                # For backwards compatibility, "fieldname: datatype, fieldname: datatype" case.
+                return DDLDataTypeParser(f"struct<{data_type}>").from_ddl_datatype()
+            except ParseException:
+                raise e from None
+
+
+class DataTypeParserBase:
+    REGEXP_IDENTIFIER: Final[Pattern] = re.compile("\\w+|`(?:``|[^`])*`", re.MULTILINE)
+    REGEXP_INTEGER_VALUES: Final[Pattern] = re.compile(
+        "\\(\\s*(?:[+-]?\\d+)\\s*(?:,\\s*(?:[+-]?\\d+)\\s*)*\\)", re.MULTILINE
     )
-    with_col_name = " " in data_type.strip()
-    if len(return_type_schema.fields) == 1 and not with_col_name:
-        # To match pyspark.sql.types._parse_datatype_string
-        return_type = return_type_schema.fields[0].dataType
-    else:
-        return_type = return_type_schema
-    return return_type
+    REGEXP_INTERVAL_TYPE: Final[Pattern] = re.compile(
+        "(day|hour|minute|second)(?:\\s+to\\s+(hour|minute|second))?", re.IGNORECASE | re.MULTILINE
+    )
+    REGEXP_NOT_NULL_COMMENT: Final[Pattern] = re.compile(
+        "(not\\s+null)?(?:(?(1)\\s+)comment\\s+'((?:\\\\'|[^'])*)')?", re.IGNORECASE | re.MULTILINE
+    )
+
+    def __init__(self, type_str: str):
+        self._type_str = type_str
+        self._pos = 0
+        self._lstrip()
+
+    def _lstrip(self) -> None:
+        remaining = self._type_str[self._pos :]
+        self._pos = self._pos + (len(remaining) - len(remaining.lstrip()))
+
+    def _parse_data_type(self) -> DataType:
+        type_str = self._type_str[self._pos :]
+        m = self.REGEXP_IDENTIFIER.match(type_str)
+        if m:
+            data_type_name = m.group(0).lower().strip("`").replace("``", "`")
+            self._pos = self._pos + len(m.group(0))
+            self._lstrip()
+            if data_type_name == "array":
+                return self._parse_array_type()
+            elif data_type_name == "map":
+                return self._parse_map_type()
+            elif data_type_name == "struct":
+                return self._parse_struct_type()
+            elif data_type_name == "interval":
+                return self._parse_interval_type()
+            else:
+                return self._parse_primitive_types(data_type_name)
+
+        raise ParseException(
+            error_class="PARSE_SYNTAX_ERROR",
+            message_parameters={"error": f"'{type_str}'", "hint": ""},
+        )
+
+    def _parse_array_type(self) -> ArrayType:
+        type_str = self._type_str[self._pos :]
+        if len(type_str) > 0 and type_str[0] == "<":
+            self._pos = self._pos + 1
+            self._lstrip()
+            element_type = self._parse_data_type()
+            remaining = self._type_str[self._pos :]
+            if len(remaining) and remaining[0] == ">":
+                self._pos = self._pos + 1
+                self._lstrip()
+                return ArrayType(element_type)
+        raise ParseException(error_class="INCOMPLETE_TYPE_DEFINITION.ARRAY", message_parameters={})
+
+    def _parse_map_type(self) -> MapType:
+        type_str = self._type_str[self._pos :]
+        if len(type_str) > 0 and type_str[0] == "<":
+            self._pos = self._pos + 1
+            self._lstrip()
+            key_type = self._parse_data_type()
+            remaining = self._type_str[self._pos :]
+            if len(remaining) > 0 and remaining[0] == ",":
+                self._pos = self._pos + 1
+                self._lstrip()
+                value_type = self._parse_data_type()
+                remaining = self._type_str[self._pos :]
+                if len(remaining) > 0 and remaining[0] == ">":
+                    self._pos = self._pos + 1
+                    self._lstrip()
+                    return MapType(key_type, value_type)
+        raise ParseException(error_class="INCOMPLETE_TYPE_DEFINITION.MAP", message_parameters={})
+
+    def _parse_struct_type(self) -> StructType:
+        type_str = self._type_str[self._pos :]
+        if len(type_str) > 0 and type_str[0] == "<":
+            self._pos = self._pos + 1
+            self._lstrip()
+            fields = self._parse_struct_fields()
+            remaining = self._type_str[self._pos :]
+            if len(remaining) > 0 and remaining[0] == ">":
+                self._pos = self._pos + 1
+                self._lstrip()
+                return StructType(fields)
+        raise ParseException(error_class="INCOMPLETE_TYPE_DEFINITION.STRUCT", message_parameters={})
+
+    def _parse_struct_fields(self, sep_with_colon: bool = True) -> List[StructField]:
+        type_str = self._type_str[self._pos :]
+        m = self.REGEXP_IDENTIFIER.match(type_str)
+        if m:
+            field_name = m.group(0).lower().strip("`").replace("``", "`")

Review Comment:
   Now that it supports `NOT NULL` and `COMMENT`, this still needs to check a config:
   
   https://github.com/apache/spark/blob/f9c8a246ecc34ef4bb93c319c7c9b6ff732c962e/python/pyspark/sql/connect/types.py#L574-L575
   
   In Scala, it checks `spark.sql.timestampType` to decide whether it should be `TimestampType` or `TimestampNTZType`:
   
   https://github.com/apache/spark/blob/f9c8a246ecc34ef4bb93c319c7c9b6ff732c962e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala#L2878
   https://github.com/apache/spark/blob/f9c8a246ecc34ef4bb93c319c7c9b6ff732c962e/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4858-L4865



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org