You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by dp...@apache.org on 2020/09/22 13:29:10 UTC

[incubator-superset] 01/09: fix(db_engine_specs): improve Presto column type matching (#10658)

This is an automated email from the ASF dual-hosted git repository.

dpgaspar pushed a commit to branch 0.38
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git

commit 4f3fcfd43040d99a63de0f1a222b92de922cd855
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Mon Aug 24 22:42:07 2020 +0300

    fix(db_engine_specs): improve Presto column type matching (#10658)
    
    * fix: improve Presto column type matching
    
    * add optional callback to type map and add tests
    
    * lint
    
    * change private to public
---
 superset/db_engine_specs/base.py              | 13 +++-
 superset/db_engine_specs/mssql.py             | 15 ++---
 superset/db_engine_specs/presto.py            | 89 ++++++++++++++++++---------
 superset/models/sql_types/presto_sql_types.py | 24 --------
 tests/db_engine_specs/presto_tests.py         | 21 +++++++
 5 files changed, 97 insertions(+), 65 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 9c9b6a8..cfb3671 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -24,8 +24,10 @@ from contextlib import closing
 from datetime import datetime
 from typing import (
     Any,
+    Callable,
     Dict,
     List,
+    Match,
     NamedTuple,
     Optional,
     Pattern,
@@ -142,6 +144,9 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     ] = None  # used for user messages, overridden in child classes
     _date_trunc_functions: Dict[str, str] = {}
     _time_grain_expressions: Dict[Optional[str], str] = {}
+    column_type_mappings: Tuple[
+        Tuple[Pattern[str], Union[TypeEngine, Callable[[Match[str]], TypeEngine]]], ...,
+    ] = ()
     time_groupby_inline = False
     limit_method = LimitMethod.FORCE_LIMIT
     time_secondary_columns = False
@@ -888,12 +893,18 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         """
         Return a sqlalchemy native column type that corresponds to the column type
         defined in the data source (return None to use default type inferred by
-        SQLAlchemy). Needs to be overridden if column requires special handling
+        SQLAlchemy). Override `column_type_mappings` for specific needs
         (see MSSQL for example of NCHAR/NVARCHAR handling).
 
         :param type_: Column type returned by inspector
         :return: SqlAlchemy column type
         """
+        for regex, sqla_type in cls.column_type_mappings:  # first match wins
+            match = regex.match(type_)
+            if match:
+                if callable(sqla_type):  # factory: build type from the match
+                    return sqla_type(match)
+                return sqla_type
         return None
 
     @staticmethod
diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py
index ead06f8..d1bb99c 100644
--- a/superset/db_engine_specs/mssql.py
+++ b/superset/db_engine_specs/mssql.py
@@ -19,7 +19,7 @@ import re
 from datetime import datetime
 from typing import Any, List, Optional, Tuple, TYPE_CHECKING
 
-from sqlalchemy.types import String, TypeEngine, UnicodeText
+from sqlalchemy.types import String, UnicodeText
 
 from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
 from superset.utils import core as utils
@@ -75,19 +75,12 @@ class MssqlEngineSpec(BaseEngineSpec):
         # Lists of `pyodbc.Row` need to be unpacked further
         return cls.pyodbc_rows_to_tuples(data)
 
-    column_types = (
-        (String(), re.compile(r"^(?<!N)((VAR){0,1}CHAR|TEXT|STRING)", re.IGNORECASE)),
-        (UnicodeText(), re.compile(r"^N((VAR){0,1}CHAR|TEXT)", re.IGNORECASE)),
+    column_type_mappings = (
+        (re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE), UnicodeText()),
+        (re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE), String()),
     )
 
     @classmethod
-    def get_sqla_column_type(cls, type_: str) -> Optional[TypeEngine]:
-        for sqla_type, regex in cls.column_types:
-            if regex.match(type_):
-                return sqla_type
-        return None
-
-    @classmethod
     def extract_error_message(cls, ex: Exception) -> str:
         if str(ex).startswith("(8155,"):
             return (
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 16e6a4c..9a53d5d 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -28,7 +28,7 @@ from urllib import parse
 import pandas as pd
 import simplejson as json
 from flask_babel import lazy_gettext as _
-from sqlalchemy import Column, literal_column
+from sqlalchemy import Column, literal_column, types
 from sqlalchemy.engine.base import Engine
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.result import RowProxy
@@ -40,7 +40,13 @@ from superset import app, cache, is_feature_enabled, security_manager
 from superset.db_engine_specs.base import BaseEngineSpec
 from superset.exceptions import SupersetTemplateException
 from superset.models.sql_lab import Query
-from superset.models.sql_types.presto_sql_types import type_map as presto_type_map
+from superset.models.sql_types.presto_sql_types import (
+    Array,
+    Interval,
+    Map,
+    Row,
+    TinyInteger,
+)
 from superset.result_set import destringify
 from superset.sql_parse import ParsedQuery
 from superset.utils import core as utils
@@ -260,13 +266,16 @@ class PrestoEngineSpec(BaseEngineSpec):
                         field_info = cls._split_data_type(single_field, r"\s")
                         # check if there is a structural data type within
                         # overall structural data type
+                        column_type = cls.get_sqla_column_type(field_info[1])
+                        if column_type is None:
+                            raise NotImplementedError(
+                                _("Unknown column type: %(col)s", col=field_info[1])
+                            )
                         if field_info[1] == "array" or field_info[1] == "row":
                             stack.append((field_info[0], field_info[1]))
                             full_parent_path = cls._get_full_name(stack)
                             result.append(
-                                cls._create_column_info(
-                                    full_parent_path, presto_type_map[field_info[1]]()
-                                )
+                                cls._create_column_info(full_parent_path, column_type)
                             )
                         else:  # otherwise this field is a basic data type
                             full_parent_path = cls._get_full_name(stack)
@@ -274,9 +283,7 @@ class PrestoEngineSpec(BaseEngineSpec):
                                 full_parent_path, field_info[0]
                             )
                             result.append(
-                                cls._create_column_info(
-                                    column_name, presto_type_map[field_info[1]]()
-                                )
+                                cls._create_column_info(column_name, column_type)
                             )
                     # If the component type ends with a structural data type, do not pop
                     # the stack. We have run across a structural data type within the
@@ -318,6 +325,34 @@ class PrestoEngineSpec(BaseEngineSpec):
         columns = inspector.bind.execute("SHOW COLUMNS FROM {}".format(full_table))
         return columns
 
+    column_type_mappings = (  # first matching pattern wins; longer prefixes first
+        (re.compile(r"^boolean.*", re.IGNORECASE), types.Boolean()),
+        (re.compile(r"^tinyint.*", re.IGNORECASE), TinyInteger()),
+        (re.compile(r"^smallint.*", re.IGNORECASE), types.SmallInteger()),
+        (re.compile(r"^integer.*", re.IGNORECASE), types.Integer()),
+        (re.compile(r"^bigint.*", re.IGNORECASE), types.BigInteger()),
+        (re.compile(r"^real.*", re.IGNORECASE), types.Float()),
+        (re.compile(r"^double.*", re.IGNORECASE), types.Float()),
+        (re.compile(r"^decimal.*", re.IGNORECASE), types.DECIMAL()),
+        (
+            re.compile(r"^varchar(\((\d+)\))?$", re.IGNORECASE),
+            lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(),
+        ),
+        (
+            re.compile(r"^char(\((\d+)\))?$", re.IGNORECASE),
+            lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(),
+        ),
+        (re.compile(r"^varbinary.*", re.IGNORECASE), types.VARBINARY()),
+        (re.compile(r"^json.*", re.IGNORECASE), types.JSON()),
+        (re.compile(r"^date.*", re.IGNORECASE), types.DATE()),
+        (re.compile(r"^timestamp.*", re.IGNORECASE), types.TIMESTAMP()),
+        (re.compile(r"^time.*", re.IGNORECASE), types.Time()),
+        (re.compile(r"^interval.*", re.IGNORECASE), Interval()),
+        (re.compile(r"^array.*", re.IGNORECASE), Array()),
+        (re.compile(r"^map.*", re.IGNORECASE), Map()),
+        (re.compile(r"^row.*", re.IGNORECASE), Row()),
+    )
+
     @classmethod
     def get_columns(
         cls, inspector: Inspector, table_name: str, schema: Optional[str]
@@ -334,28 +369,24 @@ class PrestoEngineSpec(BaseEngineSpec):
         columns = cls._show_columns(inspector, table_name, schema)
         result: List[Dict[str, Any]] = []
         for column in columns:
-            try:
-                # parse column if it is a row or array
-                if is_feature_enabled("PRESTO_EXPAND_DATA") and (
-                    "array" in column.Type or "row" in column.Type
-                ):
-                    structural_column_index = len(result)
-                    cls._parse_structural_column(column.Column, column.Type, result)
-                    result[structural_column_index]["nullable"] = getattr(
-                        column, "Null", True
-                    )
-                    result[structural_column_index]["default"] = None
-                    continue
-
-                # otherwise column is a basic data type
-                column_type = presto_type_map[column.Type]()
-            except KeyError:
-                logger.info(
-                    "Did not recognize type {} of column {}".format(  # pylint: disable=logging-format-interpolation
-                        column.Type, column.Column
-                    )
+            # parse column if it is a row or array
+            if is_feature_enabled("PRESTO_EXPAND_DATA") and (
+                "array" in column.Type or "row" in column.Type
+            ):
+                structural_column_index = len(result)
+                cls._parse_structural_column(column.Column, column.Type, result)
+                result[structural_column_index]["nullable"] = getattr(
+                    column, "Null", True
+                )
+                result[structural_column_index]["default"] = None
+                continue
+
+            # otherwise column is a basic data type
+            column_type = cls.get_sqla_column_type(column.Type)
+            if column_type is None:  # no column_type_mappings pattern matched
+                raise NotImplementedError(  # NOTE(review): old code fell back to OTHER
+                    _("Unknown column type: %(col)s", col=column.Type)
                 )
-                column_type = "OTHER"
             column_info = cls._create_column_info(column.Column, column_type)
             column_info["nullable"] = getattr(column, "Null", True)
             column_info["default"] = None
diff --git a/superset/models/sql_types/presto_sql_types.py b/superset/models/sql_types/presto_sql_types.py
index d6f6d39..a314639 100644
--- a/superset/models/sql_types/presto_sql_types.py
+++ b/superset/models/sql_types/presto_sql_types.py
@@ -16,7 +16,6 @@
 # under the License.
 from typing import Any, Dict, List, Optional, Type
 
-from sqlalchemy import types
 from sqlalchemy.sql.sqltypes import Integer
 from sqlalchemy.sql.type_api import TypeEngine
 from sqlalchemy.sql.visitors import Visitable
@@ -92,26 +91,3 @@ class Row(TypeEngine):
     @classmethod
     def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
         return "ROW"
-
-
-type_map = {
-    "boolean": types.Boolean,
-    "tinyint": TinyInteger,
-    "smallint": types.SmallInteger,
-    "integer": types.Integer,
-    "bigint": types.BigInteger,
-    "real": types.Float,
-    "double": types.Float,
-    "decimal": types.DECIMAL,
-    "varchar": types.String,
-    "char": types.CHAR,
-    "varbinary": types.VARBINARY,
-    "JSON": types.JSON,
-    "date": types.DATE,
-    "time": types.Time,
-    "timestamp": types.TIMESTAMP,
-    "interval": Interval,
-    "array": Array,
-    "map": Map,
-    "row": Row,
-}
diff --git a/tests/db_engine_specs/presto_tests.py b/tests/db_engine_specs/presto_tests.py
index 9d1d384..3a0346b 100644
--- a/tests/db_engine_specs/presto_tests.py
+++ b/tests/db_engine_specs/presto_tests.py
@@ -17,6 +17,7 @@
 from unittest import mock, skipUnless
 
 import pandas as pd
+from sqlalchemy import types
 from sqlalchemy.engine.result import RowProxy
 from sqlalchemy.sql import select
 
@@ -490,3 +491,23 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
         self.assertEqual(actual_cols, expected_cols)
         self.assertEqual(actual_data, expected_data)
         self.assertEqual(actual_expanded_cols, expected_expanded_cols)
+
+    def test_get_sqla_column_type(self):  # NOTE(review): no time/timestamp case
+        sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar(255)")
+        assert isinstance(sqla_type, types.VARCHAR)
+        assert sqla_type.length == 255
+
+        sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar")
+        assert isinstance(sqla_type, types.String)
+        assert sqla_type.length is None
+
+        sqla_type = PrestoEngineSpec.get_sqla_column_type("char(10)")
+        assert isinstance(sqla_type, types.CHAR)
+        assert sqla_type.length == 10
+
+        sqla_type = PrestoEngineSpec.get_sqla_column_type("char")
+        assert isinstance(sqla_type, types.CHAR)
+        assert sqla_type.length is None
+
+        sqla_type = PrestoEngineSpec.get_sqla_column_type("integer")
+        assert isinstance(sqla_type, types.Integer)