You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by "hvanhovell (via GitHub)" <gi...@apache.org> on 2023/03/04 14:08:46 UTC

[GitHub] [spark] hvanhovell commented on a diff in pull request #40276: [SPARK-42630][CONNECT][PYTHON] Implement data type string parser

hvanhovell commented on code in PR #40276:
URL: https://github.com/apache/spark/pull/40276#discussion_r1125475957


##########
python/pyspark/sql/connect/types.py:
##########
@@ -342,20 +343,325 @@ def from_arrow_schema(arrow_schema: "pa.Schema") -> StructType:
 
 
 def parse_data_type(data_type: str) -> DataType:
-    # Currently we don't have a way to have a current Spark session in Spark Connect, and
-    # pyspark.sql.SparkSession has a centralized logic to control the session creation.
-    # So uses pyspark.sql.SparkSession for now. Should replace this to using the current
-    # Spark session for Spark Connect in the future.
-    from pyspark.sql import SparkSession as PySparkSession
-
-    assert is_remote()
-    return_type_schema = (
-        PySparkSession.builder.getOrCreate().createDataFrame(data=[], schema=data_type).schema
+    """
+    Parses the given data type string to a :class:`DataType`. The data type string format equals
+    :class:`DataType.simpleString`, except that the top level struct type can omit
+    the ``struct<>``. Since Spark 2.3, this also supports a schema in a DDL-formatted
+    string and case-insensitive strings.
+
+    Examples
+    --------
+    >>> parse_data_type("int ")
+    IntegerType()
+    >>> parse_data_type("INT ")
+    IntegerType()
+    >>> parse_data_type("a: byte, b: decimal(  16 , 8   ) ")
+    StructType([StructField('a', ByteType(), True), StructField('b', DecimalType(16,8), True)])
+    >>> parse_data_type("a DOUBLE, b STRING")
+    StructType([StructField('a', DoubleType(), True), StructField('b', StringType(), True)])
+    >>> parse_data_type("a DOUBLE, b CHAR( 50 )")
+    StructType([StructField('a', DoubleType(), True), StructField('b', CharType(50), True)])
+    >>> parse_data_type("a DOUBLE, b VARCHAR( 50 )")
+    StructType([StructField('a', DoubleType(), True), StructField('b', VarcharType(50), True)])
+    >>> parse_data_type("a: array< short>")
+    StructType([StructField('a', ArrayType(ShortType(), True), True)])
+    >>> parse_data_type(" map<string , string > ")
+    MapType(StringType(), StringType(), True)
+
+    >>> # Error cases
+    >>> parse_data_type("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    ParseException:...
+    >>> parse_data_type("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    ParseException:...
+    >>> parse_data_type("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    ParseException:...
+    >>> parse_data_type("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    ParseException:...
+    """
+    try:
+        # DDL format, "fieldname datatype, fieldname datatype".
+        return DDLSchemaParser(data_type).from_ddl_schema()
+    except ParseException as e:
+        try:
+            # For backwards compatibility, "integer", "struct<fieldname: datatype>" and etc.
+            return DDLDataTypeParser(data_type).from_ddl_datatype()
+        except ParseException:
+            try:
+                # For backwards compatibility, "fieldname: datatype, fieldname: datatype" case.
+                return DDLDataTypeParser(f"struct<{data_type}>").from_ddl_datatype()
+            except ParseException:
+                raise e from None
+
+
+class DataTypeParserBase:
+    REGEXP_IDENTIFIER: Final[Pattern] = re.compile("\\w+|`(?:``|[^`])*`", re.MULTILINE)
+    REGEXP_INTEGER_VALUES: Final[Pattern] = re.compile(
+        "\\(\\s*(?:[+-]?\\d+)\\s*(?:,\\s*(?:[+-]?\\d+)\\s*)*\\)", re.MULTILINE
     )
-    with_col_name = " " in data_type.strip()
-    if len(return_type_schema.fields) == 1 and not with_col_name:
-        # To match pyspark.sql.types._parse_datatype_string
-        return_type = return_type_schema.fields[0].dataType
-    else:
-        return_type = return_type_schema
-    return return_type
+    REGEXP_INTERVAL_TYPE: Final[Pattern] = re.compile(
+        "(day|hour|minute|second)(?:\\s+to\\s+(hour|minute|second))?", re.IGNORECASE | re.MULTILINE
+    )
+    REGEXP_NOT_NULL_COMMENT: Final[Pattern] = re.compile(
+        "(not\\s+null)?(?:(?(1)\\s+)comment\\s+'((?:\\\\'|[^'])*)')?", re.IGNORECASE | re.MULTILINE
+    )
+
+    def __init__(self, type_str: str):
+        self._type_str = type_str
+        self._pos = 0
+        self._lstrip()
+
+    def _lstrip(self) -> None:
+        remaining = self._type_str[self._pos :]
+        self._pos = self._pos + (len(remaining) - len(remaining.lstrip()))
+
+    def _parse_data_type(self) -> DataType:
+        type_str = self._type_str[self._pos :]
+        m = self.REGEXP_IDENTIFIER.match(type_str)
+        if m:
+            data_type_name = m.group(0).lower().strip("`").replace("``", "`")
+            self._pos = self._pos + len(m.group(0))
+            self._lstrip()
+            if data_type_name == "array":
+                return self._parse_array_type()
+            elif data_type_name == "map":
+                return self._parse_map_type()
+            elif data_type_name == "struct":
+                return self._parse_struct_type()
+            elif data_type_name == "interval":
+                return self._parse_interval_type()
+            else:
+                return self._parse_primitive_types(data_type_name)
+
+        raise ParseException(
+            error_class="PARSE_SYNTAX_ERROR",
+            message_parameters={"error": f"'{type_str}'", "hint": ""},
+        )
+
+    def _parse_array_type(self) -> ArrayType:
+        type_str = self._type_str[self._pos :]
+        if len(type_str) > 0 and type_str[0] == "<":
+            self._pos = self._pos + 1
+            self._lstrip()
+            element_type = self._parse_data_type()
+            remaining = self._type_str[self._pos :]
+            if len(remaining) and remaining[0] == ">":
+                self._pos = self._pos + 1
+                self._lstrip()
+                return ArrayType(element_type)
+        raise ParseException(error_class="INCOMPLETE_TYPE_DEFINITION.ARRAY", message_parameters={})
+
+    def _parse_map_type(self) -> MapType:
+        type_str = self._type_str[self._pos :]
+        if len(type_str) > 0 and type_str[0] == "<":
+            self._pos = self._pos + 1
+            self._lstrip()
+            key_type = self._parse_data_type()
+            remaining = self._type_str[self._pos :]
+            if len(remaining) > 0 and remaining[0] == ",":
+                self._pos = self._pos + 1
+                self._lstrip()
+                value_type = self._parse_data_type()
+                remaining = self._type_str[self._pos :]
+                if len(remaining) > 0 and remaining[0] == ">":
+                    self._pos = self._pos + 1
+                    self._lstrip()
+                    return MapType(key_type, value_type)
+        raise ParseException(error_class="INCOMPLETE_TYPE_DEFINITION.MAP", message_parameters={})
+
+    def _parse_struct_type(self) -> StructType:
+        type_str = self._type_str[self._pos :]
+        if len(type_str) > 0 and type_str[0] == "<":
+            self._pos = self._pos + 1
+            self._lstrip()
+            fields = self._parse_struct_fields()
+            remaining = self._type_str[self._pos :]
+            if len(remaining) > 0 and remaining[0] == ">":
+                self._pos = self._pos + 1
+                self._lstrip()
+                return StructType(fields)
+        raise ParseException(error_class="INCOMPLETE_TYPE_DEFINITION.STRUCT", message_parameters={})
+
+    def _parse_struct_fields(self, sep_with_colon: bool = True) -> List[StructField]:
+        type_str = self._type_str[self._pos :]
+        m = self.REGEXP_IDENTIFIER.match(type_str)
+        if m:
+            field_name = m.group(0).lower().strip("`").replace("``", "`")

Review Comment:
   @HyukjinKwon real engines write their own parsers ;)... In this case you could add a separate tokenization step to make the actual parsing a bit easier to read and maintain, and it could also take care of some of the case-sensitivity concerns. I think adding ANTLR for this is overkill.
   
   As for the worry about missing stuff: I guess having some language-agnostic specification test would not be a bad thing to add.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org