You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by vi...@apache.org on 2021/03/12 08:37:45 UTC
[superset] branch master updated: feat(explore): Postgres datatype
conversion (#13294)
This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 609c359 feat(explore): Postgres datatype conversion (#13294)
609c359 is described below
commit 609c3594ef74ad875d1f47e7f2a4c631036503c3
Author: Nikola Gigić <ni...@gmail.com>
AuthorDate: Fri Mar 12 09:36:43 2021 +0100
feat(explore): Postgres datatype conversion (#13294)
* test
* unnecessary import
* fix lint
* changes
* fix lint
* changes
* changes
* changes
* changes
* answering comments & changes
* answering comments
* answering comments
* changes
* changes
* changes
* fix tests
* fix tests
* fix tests
* fix tests
* fix tests
* fix tests
* fix tests
* fix tests
* fix tests
* fix tests
* fix tests
* fix tests
* fix tests
* fix tests
* fix tests
* fix tests
* fix tests
* fix tests
---
superset/connectors/sqla/models.py | 28 ++---
superset/db_engine_specs/base.py | 191 +++++++++++++++++++++++++---------
superset/db_engine_specs/mssql.py | 13 +--
superset/db_engine_specs/mysql.py | 68 +++++++++++-
superset/db_engine_specs/postgres.py | 54 +++++++++-
superset/db_engine_specs/presto.py | 138 ++++++++++++++++++++----
superset/result_set.py | 7 +-
superset/utils/core.py | 18 +++-
tests/db_engine_specs/mssql_tests.py | 31 +++---
tests/db_engine_specs/mysql_tests.py | 15 +--
tests/db_engine_specs/presto_tests.py | 57 +++++-----
tests/sqla_models_tests.py | 4 +-
12 files changed, 471 insertions(+), 153 deletions(-)
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 9f745f9..ff4f819 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -69,6 +69,7 @@ from superset.result_set import SupersetResultSet
from superset.sql_parse import ParsedQuery
from superset.typing import Metric, QueryObjectDict
from superset.utils import core as utils
+from superset.utils.core import GenericDataType
config = app.config
metadata = Model.metadata # pylint: disable=no-member
@@ -186,20 +187,20 @@ class TableColumn(Model, BaseColumn):
"""
Check if the column has a numeric datatype.
"""
- db_engine_spec = self.table.database.db_engine_spec
- return db_engine_spec.is_db_column_type_match(
- self.type, utils.GenericDataType.NUMERIC
- )
+ column_spec = self.table.database.db_engine_spec.get_column_spec(self.type)
+ if column_spec is None:
+ return False
+ return column_spec.generic_type == GenericDataType.NUMERIC
@property
def is_string(self) -> bool:
"""
Check if the column has a string datatype.
"""
- db_engine_spec = self.table.database.db_engine_spec
- return db_engine_spec.is_db_column_type_match(
- self.type, utils.GenericDataType.STRING
- )
+ column_spec = self.table.database.db_engine_spec.get_column_spec(self.type)
+ if column_spec is None:
+ return False
+ return column_spec.generic_type == GenericDataType.STRING
@property
def is_temporal(self) -> bool:
@@ -211,10 +212,10 @@ class TableColumn(Model, BaseColumn):
"""
if self.is_dttm is not None:
return self.is_dttm
- db_engine_spec = self.table.database.db_engine_spec
- return db_engine_spec.is_db_column_type_match(
- self.type, utils.GenericDataType.TEMPORAL
- )
+ column_spec = self.table.database.db_engine_spec.get_column_spec(self.type)
+ if column_spec is None:
+ return False
+ return column_spec.is_dttm
def get_sqla_col(self, label: Optional[str] = None) -> Column:
label = label or self.column_name
@@ -222,7 +223,8 @@ class TableColumn(Model, BaseColumn):
col = literal_column(self.expression)
else:
db_engine_spec = self.table.database.db_engine_spec
- type_ = db_engine_spec.get_sqla_column_type(self.type)
+ column_spec = db_engine_spec.get_column_spec(self.type)
+ type_ = column_spec.sqla_type if column_spec else None
col = column(self.column_name, type_=type_)
col = self.table.make_sqla_column_compatible(col, label)
return col
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 3052ff7..e3e83cd 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -41,7 +41,7 @@ import pandas as pd
import sqlparse
from flask import g
from flask_babel import gettext as __, lazy_gettext as _
-from sqlalchemy import column, DateTime, select
+from sqlalchemy import column, DateTime, select, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.interfaces import Compiled, Dialect
from sqlalchemy.engine.reflection import Inspector
@@ -50,13 +50,14 @@ from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Session
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, TextAsFrom
-from sqlalchemy.types import TypeEngine
+from sqlalchemy.types import String, TypeEngine, UnicodeText
from superset import app, security_manager, sql_parse
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery, Table
from superset.utils import core as utils
+from superset.utils.core import ColumnSpec, GenericDataType
if TYPE_CHECKING:
# prevent circular imports
@@ -145,8 +146,87 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
_date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {}
column_type_mappings: Tuple[
- Tuple[Pattern[str], Union[TypeEngine, Callable[[Match[str]], TypeEngine]]], ...,
- ] = ()
+ Tuple[
+ Pattern[str],
+ Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
+ GenericDataType,
+ ],
+ ...,
+ ] = (
+ (
+ re.compile(r"^smallint", re.IGNORECASE),
+ types.SmallInteger(),
+ GenericDataType.NUMERIC,
+ ),
+ (
+ re.compile(r"^integer", re.IGNORECASE),
+ types.Integer(),
+ GenericDataType.NUMERIC,
+ ),
+ (
+ re.compile(r"^bigint", re.IGNORECASE),
+ types.BigInteger(),
+ GenericDataType.NUMERIC,
+ ),
+ (
+ re.compile(r"^decimal", re.IGNORECASE),
+ types.Numeric(),
+ GenericDataType.NUMERIC,
+ ),
+ (
+ re.compile(r"^numeric", re.IGNORECASE),
+ types.Numeric(),
+ GenericDataType.NUMERIC,
+ ),
+ (re.compile(r"^real", re.IGNORECASE), types.REAL, GenericDataType.NUMERIC,),
+ (
+ re.compile(r"^smallserial", re.IGNORECASE),
+ types.SmallInteger(),
+ GenericDataType.NUMERIC,
+ ),
+ (
+ re.compile(r"^serial", re.IGNORECASE),
+ types.Integer(),
+ GenericDataType.NUMERIC,
+ ),
+ (
+ re.compile(r"^bigserial", re.IGNORECASE),
+ types.BigInteger(),
+ GenericDataType.NUMERIC,
+ ),
+ (
+ re.compile(r"^string", re.IGNORECASE),
+ types.String(),
+ utils.GenericDataType.STRING,
+ ),
+ (
+ re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE),
+ UnicodeText(),
+ utils.GenericDataType.STRING,
+ ),
+ (
+ re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE),
+ String(),
+ utils.GenericDataType.STRING,
+ ),
+ (re.compile(r"^date", re.IGNORECASE), types.Date(), GenericDataType.TEMPORAL,),
+ (
+ re.compile(r"^timestamp", re.IGNORECASE),
+ types.TIMESTAMP(),
+ GenericDataType.TEMPORAL,
+ ),
+ (
+ re.compile(r"^interval", re.IGNORECASE),
+ types.Interval(),
+ GenericDataType.TEMPORAL,
+ ),
+ (re.compile(r"^time", re.IGNORECASE), types.Time(), GenericDataType.TEMPORAL,),
+ (
+ re.compile(r"^boolean", re.IGNORECASE),
+ types.Boolean(),
+ GenericDataType.BOOLEAN,
+ ),
+ )
time_groupby_inline = False
limit_method = LimitMethod.FORCE_LIMIT
time_secondary_columns = False
@@ -160,25 +240,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
run_multiple_statements_as_one = False
- # default matching patterns to convert database specific column types to
- # more generic types
- db_column_types: Dict[utils.GenericDataType, Tuple[Pattern[str], ...]] = {
- utils.GenericDataType.NUMERIC: (
- re.compile(r"BIT", re.IGNORECASE),
- re.compile(
- r".*(DOUBLE|FLOAT|INT|NUMBER|REAL|NUMERIC|DECIMAL|MONEY).*",
- re.IGNORECASE,
- ),
- re.compile(r".*LONG$", re.IGNORECASE),
- ),
- utils.GenericDataType.STRING: (
- re.compile(r".*(CHAR|STRING|TEXT).*", re.IGNORECASE),
- ),
- utils.GenericDataType.TEMPORAL: (
- re.compile(r".*(DATE|TIME).*", re.IGNORECASE),
- ),
- }
-
@classmethod
def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
"""
@@ -209,25 +270,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return new_exception(str(exception))
@classmethod
- def is_db_column_type_match(
- cls, db_column_type: Optional[str], target_column_type: utils.GenericDataType
- ) -> bool:
- """
- Check if a column type satisfies a pattern in a collection of regexes found in
- `db_column_types`. For example, if `db_column_type == "NVARCHAR"`,
- it would be a match for "STRING" due to being a match for the regex ".*CHAR.*".
-
- :param db_column_type: Column type to evaluate
- :param target_column_type: The target type to evaluate for
- :return: `True` if a `db_column_type` matches any pattern corresponding to
- `target_column_type`
- """
- if not db_column_type:
- return False
- patterns = cls.db_column_types[target_column_type]
- return any(pattern.match(db_column_type) for pattern in patterns)
-
- @classmethod
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
return False
@@ -967,24 +1009,35 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return label_mutated
@classmethod
- def get_sqla_column_type(cls, type_: Optional[str]) -> Optional[TypeEngine]:
+ def get_sqla_column_type(
+ cls,
+ column_type: Optional[str],
+ column_type_mappings: Tuple[
+ Tuple[
+ Pattern[str],
+ Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
+ GenericDataType,
+ ],
+ ...,
+ ] = column_type_mappings,
+ ) -> Union[Tuple[TypeEngine, GenericDataType], None]:
"""
Return a sqlalchemy native column type that corresponds to the column type
defined in the data source (return None to use default type inferred by
SQLAlchemy). Override `column_type_mappings` for specific needs
(see MSSQL for example of NCHAR/NVARCHAR handling).
- :param type_: Column type returned by inspector
+ :param column_type: Column type returned by inspector
:return: SqlAlchemy column type
"""
- if not type_:
+ if not column_type:
return None
- for regex, sqla_type in cls.column_type_mappings:
- match = regex.match(type_)
+ for regex, sqla_type, generic_type in column_type_mappings:
+ match = regex.match(column_type)
if match:
if callable(sqla_type):
- return sqla_type(match)
- return sqla_type
+ return sqla_type(match), generic_type
+ return sqla_type, generic_type
return None
@staticmethod
@@ -1101,3 +1154,43 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
or parsed_query.is_explain()
or parsed_query.is_show()
)
+
+ @classmethod
+ def get_column_spec(
+ cls,
+ native_type: Optional[str],
+ source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
+ column_type_mappings: Tuple[
+ Tuple[
+ Pattern[str],
+ Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
+ GenericDataType,
+ ],
+ ...,
+ ] = column_type_mappings,
+ ) -> Union[ColumnSpec, None]:
+ """
+ Converts native database type to sqlalchemy column type.
+ :param native_type: Native database type
+ :param source: Type coming from the database table or cursor description
+ :return: ColumnSpec object
+ """
+ column_type = None
+
+ if (
+ cls.get_sqla_column_type(
+ native_type, column_type_mappings=column_type_mappings
+ )
+ is not None
+ ):
+ column_type, generic_type = cls.get_sqla_column_type( # type: ignore
+ native_type, column_type_mappings=column_type_mappings
+ )
+ is_dttm = generic_type == GenericDataType.TEMPORAL
+
+ if column_type:
+ return ColumnSpec(
+ sqla_type=column_type, generic_type=generic_type, is_dttm=is_dttm
+ )
+
+ return None
diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py
index b105c70..67b9ec1 100644
--- a/superset/db_engine_specs/mssql.py
+++ b/superset/db_engine_specs/mssql.py
@@ -15,18 +15,12 @@
# specific language governing permissions and limitations
# under the License.
import logging
-import re
from datetime import datetime
-from typing import Any, List, Optional, Tuple, TYPE_CHECKING
-
-from sqlalchemy.types import String, UnicodeText
+from typing import Any, List, Optional, Tuple
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.utils import core as utils
-if TYPE_CHECKING:
- from superset.models.core import Database
-
logger = logging.getLogger(__name__)
@@ -77,11 +71,6 @@ class MssqlEngineSpec(BaseEngineSpec):
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)
- column_type_mappings = (
- (re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE), UnicodeText()),
- (re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE), String()),
- )
-
@classmethod
def extract_error_message(cls, ex: Exception) -> str:
if str(ex).startswith("(8155,"):
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index 481a769..3cb35e3 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -14,14 +14,29 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import re
from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Callable, Dict, Match, Optional, Pattern, Tuple, Union
from urllib import parse
+from sqlalchemy.dialects.mysql import (
+ BIT,
+ DECIMAL,
+ DOUBLE,
+ FLOAT,
+ INTEGER,
+ LONGTEXT,
+ MEDIUMINT,
+ MEDIUMTEXT,
+ TINYINT,
+ TINYTEXT,
+)
from sqlalchemy.engine.url import URL
+from sqlalchemy.types import TypeEngine
from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils
+from superset.utils.core import ColumnSpec, GenericDataType
class MySQLEngineSpec(BaseEngineSpec):
@@ -29,6 +44,34 @@ class MySQLEngineSpec(BaseEngineSpec):
engine_name = "MySQL"
max_column_name_length = 64
+ column_type_mappings: Tuple[
+ Tuple[
+ Pattern[str],
+ Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
+ GenericDataType,
+ ],
+ ...,
+ ] = (
+ (re.compile(r"^int.*", re.IGNORECASE), INTEGER(), GenericDataType.NUMERIC,),
+ (re.compile(r"^tinyint", re.IGNORECASE), TINYINT(), GenericDataType.NUMERIC,),
+ (
+ re.compile(r"^mediumint", re.IGNORECASE),
+ MEDIUMINT(),
+ GenericDataType.NUMERIC,
+ ),
+ (re.compile(r"^decimal", re.IGNORECASE), DECIMAL(), GenericDataType.NUMERIC,),
+ (re.compile(r"^float", re.IGNORECASE), FLOAT(), GenericDataType.NUMERIC,),
+ (re.compile(r"^double", re.IGNORECASE), DOUBLE(), GenericDataType.NUMERIC,),
+ (re.compile(r"^bit", re.IGNORECASE), BIT(), GenericDataType.NUMERIC,),
+ (re.compile(r"^tinytext", re.IGNORECASE), TINYTEXT(), GenericDataType.STRING,),
+ (
+ re.compile(r"^mediumtext", re.IGNORECASE),
+ MEDIUMTEXT(),
+ GenericDataType.STRING,
+ ),
+ (re.compile(r"^longtext", re.IGNORECASE), LONGTEXT(), GenericDataType.STRING,),
+ )
+
_time_grain_expressions = {
None: "{col}",
"PT1S": "DATE_ADD(DATE({col}), "
@@ -98,3 +141,26 @@ class MySQLEngineSpec(BaseEngineSpec):
except (AttributeError, KeyError):
pass
return message
+
+ @classmethod
+ def get_column_spec( # type: ignore
+ cls,
+ native_type: Optional[str],
+ source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
+ column_type_mappings: Tuple[
+ Tuple[
+ Pattern[str],
+ Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
+ GenericDataType,
+ ],
+ ...,
+ ] = column_type_mappings,
+ ) -> Union[ColumnSpec, None]:
+
+ column_spec = super().get_column_spec(native_type)
+ if column_spec:
+ return column_spec
+
+ return super().get_column_spec(
+ native_type, column_type_mappings=column_type_mappings
+ )
diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py
index a63ffdd..38c4a6d 100644
--- a/superset/db_engine_specs/postgres.py
+++ b/superset/db_engine_specs/postgres.py
@@ -18,14 +18,28 @@ import json
import logging
import re
from datetime import datetime
-from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Match,
+ Optional,
+ Pattern,
+ Tuple,
+ TYPE_CHECKING,
+ Union,
+)
from pytz import _FixedOffset # type: ignore
+from sqlalchemy.dialects.postgresql import ARRAY, DOUBLE_PRECISION, ENUM, JSON
from sqlalchemy.dialects.postgresql.base import PGInspector
+from sqlalchemy.types import String, TypeEngine
from superset.db_engine_specs.base import BaseEngineSpec
from superset.exceptions import SupersetException
from superset.utils import core as utils
+from superset.utils.core import ColumnSpec, GenericDataType
if TYPE_CHECKING:
from superset.models.core import Database # pragma: no cover
@@ -77,6 +91,21 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
max_column_name_length = 63
try_remove_schema_from_table_name = False
+ column_type_mappings = (
+ (
+ re.compile(r"^double precision", re.IGNORECASE),
+ DOUBLE_PRECISION(),
+ GenericDataType.NUMERIC,
+ ),
+ (
+ re.compile(r"^array.*", re.IGNORECASE),
+ lambda match: ARRAY(int(match[2])) if match[2] else String(),
+ utils.GenericDataType.STRING,
+ ),
+ (re.compile(r"^json.*", re.IGNORECASE), JSON(), utils.GenericDataType.STRING,),
+ (re.compile(r"^enum.*", re.IGNORECASE), ENUM(), utils.GenericDataType.STRING,),
+ )
+
@classmethod
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
return True
@@ -144,3 +173,26 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
engine_params["connect_args"] = connect_args
extra["engine_params"] = engine_params
return extra
+
+ @classmethod
+ def get_column_spec( # type: ignore
+ cls,
+ native_type: Optional[str],
+ source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
+ column_type_mappings: Tuple[
+ Tuple[
+ Pattern[str],
+ Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
+ GenericDataType,
+ ],
+ ...,
+ ] = column_type_mappings,
+ ) -> Union[ColumnSpec, None]:
+
+ column_spec = super().get_column_spec(native_type)
+ if column_spec:
+ return column_spec
+
+ return super().get_column_spec(
+ native_type, column_type_mappings=column_type_mappings
+ )
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 6ea687f..27fad22 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -23,7 +23,19 @@ from collections import defaultdict, deque
from contextlib import closing
from datetime import datetime
from distutils.version import StrictVersion
-from typing import Any, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import (
+ Any,
+ Callable,
+ cast,
+ Dict,
+ List,
+ Match,
+ Optional,
+ Pattern,
+ Tuple,
+ TYPE_CHECKING,
+ Union,
+)
from urllib import parse
import pandas as pd
@@ -36,6 +48,7 @@ from sqlalchemy.engine.result import RowProxy
from sqlalchemy.engine.url import make_url, URL
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select
+from sqlalchemy.types import TypeEngine
from superset import app, cache_manager, is_feature_enabled
from superset.db_engine_specs.base import BaseEngineSpec
@@ -52,6 +65,7 @@ from superset.models.sql_types.presto_sql_types import (
from superset.result_set import destringify
from superset.sql_parse import ParsedQuery
from superset.utils import core as utils
+from superset.utils.core import ColumnSpec, GenericDataType
if TYPE_CHECKING:
# prevent circular imports
@@ -293,7 +307,8 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
field_info = cls._split_data_type(single_field, r"\s")
# check if there is a structural data type within
# overall structural data type
- column_type = cls.get_sqla_column_type(field_info[1])
+ column_spec = cls.get_column_spec(field_info[1])
+ column_type = column_spec.sqla_type if column_spec else None
if column_type is None:
column_type = types.String()
logger.info(
@@ -356,31 +371,89 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
return columns
column_type_mappings = (
- (re.compile(r"^boolean.*", re.IGNORECASE), types.Boolean()),
- (re.compile(r"^tinyint.*", re.IGNORECASE), TinyInteger()),
- (re.compile(r"^smallint.*", re.IGNORECASE), types.SmallInteger()),
- (re.compile(r"^integer.*", re.IGNORECASE), types.Integer()),
- (re.compile(r"^bigint.*", re.IGNORECASE), types.BigInteger()),
- (re.compile(r"^real.*", re.IGNORECASE), types.Float()),
- (re.compile(r"^double.*", re.IGNORECASE), types.Float()),
- (re.compile(r"^decimal.*", re.IGNORECASE), types.DECIMAL()),
+ (
+ re.compile(r"^boolean.*", re.IGNORECASE),
+ types.BOOLEAN,
+ utils.GenericDataType.BOOLEAN,
+ ),
+ (
+ re.compile(r"^tinyint.*", re.IGNORECASE),
+ TinyInteger(),
+ utils.GenericDataType.NUMERIC,
+ ),
+ (
+ re.compile(r"^smallint.*", re.IGNORECASE),
+ types.SMALLINT(),
+ utils.GenericDataType.NUMERIC,
+ ),
+ (
+ re.compile(r"^integer.*", re.IGNORECASE),
+ types.INTEGER(),
+ utils.GenericDataType.NUMERIC,
+ ),
+ (
+ re.compile(r"^bigint.*", re.IGNORECASE),
+ types.BIGINT(),
+ utils.GenericDataType.NUMERIC,
+ ),
+ (
+ re.compile(r"^real.*", re.IGNORECASE),
+ types.FLOAT(),
+ utils.GenericDataType.NUMERIC,
+ ),
+ (
+ re.compile(r"^double.*", re.IGNORECASE),
+ types.FLOAT(),
+ utils.GenericDataType.NUMERIC,
+ ),
+ (
+ re.compile(r"^decimal.*", re.IGNORECASE),
+ types.DECIMAL(),
+ utils.GenericDataType.NUMERIC,
+ ),
(
re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE),
lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(),
+ utils.GenericDataType.STRING,
),
(
re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE),
lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(),
+ utils.GenericDataType.STRING,
),
- (re.compile(r"^varbinary.*", re.IGNORECASE), types.VARBINARY()),
- (re.compile(r"^json.*", re.IGNORECASE), types.JSON()),
- (re.compile(r"^date.*", re.IGNORECASE), types.DATE()),
- (re.compile(r"^timestamp.*", re.IGNORECASE), types.TIMESTAMP()),
- (re.compile(r"^time.*", re.IGNORECASE), types.Time()),
- (re.compile(r"^interval.*", re.IGNORECASE), Interval()),
- (re.compile(r"^array.*", re.IGNORECASE), Array()),
- (re.compile(r"^map.*", re.IGNORECASE), Map()),
- (re.compile(r"^row.*", re.IGNORECASE), Row()),
+ (
+ re.compile(r"^varbinary.*", re.IGNORECASE),
+ types.VARBINARY(),
+ utils.GenericDataType.STRING,
+ ),
+ (
+ re.compile(r"^json.*", re.IGNORECASE),
+ types.JSON(),
+ utils.GenericDataType.STRING,
+ ),
+ (
+ re.compile(r"^date.*", re.IGNORECASE),
+ types.DATE(),
+ utils.GenericDataType.TEMPORAL,
+ ),
+ (
+ re.compile(r"^timestamp.*", re.IGNORECASE),
+ types.TIMESTAMP(),
+ utils.GenericDataType.TEMPORAL,
+ ),
+ (
+ re.compile(r"^interval.*", re.IGNORECASE),
+ Interval(),
+ utils.GenericDataType.TEMPORAL,
+ ),
+ (
+ re.compile(r"^time.*", re.IGNORECASE),
+ types.Time(),
+ utils.GenericDataType.TEMPORAL,
+ ),
+ (re.compile(r"^array.*", re.IGNORECASE), Array(), utils.GenericDataType.STRING),
+ (re.compile(r"^map.*", re.IGNORECASE), Map(), utils.GenericDataType.STRING),
+ (re.compile(r"^row.*", re.IGNORECASE), Row(), utils.GenericDataType.STRING),
)
@classmethod
@@ -412,7 +485,8 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
continue
# otherwise column is a basic data type
- column_type = cls.get_sqla_column_type(column.Type)
+ column_spec = cls.get_column_spec(column.Type)
+ column_type = column_spec.sqla_type if column_spec else None
if column_type is None:
column_type = types.String()
logger.info(
@@ -1111,3 +1185,27 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
"""Pessimistic readonly, 100% sure statement won't mutate anything"""
return super().is_readonly_query(parsed_query) or parsed_query.is_show()
+
+ @classmethod
+ def get_column_spec( # type: ignore
+ cls,
+ native_type: Optional[str],
+ source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
+ column_type_mappings: Tuple[
+ Tuple[
+ Pattern[str],
+ Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
+ GenericDataType,
+ ],
+ ...,
+ ] = column_type_mappings,
+ ) -> Union[ColumnSpec, None]:
+
+ column_spec = super().get_column_spec(
+ native_type, column_type_mappings=column_type_mappings
+ )
+
+ if column_spec:
+ return column_spec
+
+ return super().get_column_spec(native_type)
diff --git a/superset/result_set.py b/superset/result_set.py
index f3f68ac..34d5dc9 100644
--- a/superset/result_set.py
+++ b/superset/result_set.py
@@ -181,9 +181,10 @@ class SupersetResultSet:
return next((i for i in items if i), None)
def is_temporal(self, db_type_str: Optional[str]) -> bool:
- return self.db_engine_spec.is_db_column_type_match(
- db_type_str, utils.GenericDataType.TEMPORAL
- )
+ column_spec = self.db_engine_spec.get_column_spec(db_type_str)
+ if column_spec is None:
+ return False
+ return column_spec.is_dttm
def data_type(self, col_name: str, pa_dtype: pa.DataType) -> Optional[str]:
"""Given a pyarrow data type, Returns a generic database type"""
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 893592b..a1e3d11 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -82,7 +82,7 @@ from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.type_api import Variant
-from sqlalchemy.types import TEXT, TypeDecorator
+from sqlalchemy.types import TEXT, TypeDecorator, TypeEngine
from typing_extensions import TypedDict
import _thread # pylint: disable=C0411
@@ -148,6 +148,10 @@ class GenericDataType(IntEnum):
STRING = 1
TEMPORAL = 2
BOOLEAN = 3
+ # ARRAY = 4 # Mapping all the complex data types to STRING for now
+ # JSON = 5 # and leaving these as a reminder.
+ # MAP = 6
+ # ROW = 7
class ChartDataResultFormat(str, Enum):
@@ -306,6 +310,18 @@ class TemporalType(str, Enum):
TIMESTAMP = "TIMESTAMP"
+class ColumnTypeSource(Enum):
+ GET_TABLE = 1
+ CURSOR_DESCRIPION = 2
+
+
+class ColumnSpec(NamedTuple):
+ sqla_type: Union[TypeEngine, str]
+ generic_type: GenericDataType
+ is_dttm: bool
+ python_date_format: Optional[str] = None
+
+
try:
# Having might not have been imported.
class DimSelector(Having):
diff --git a/tests/db_engine_specs/mssql_tests.py b/tests/db_engine_specs/mssql_tests.py
index 149ed69..74c3715 100644
--- a/tests/db_engine_specs/mssql_tests.py
+++ b/tests/db_engine_specs/mssql_tests.py
@@ -24,32 +24,37 @@ from sqlalchemy.types import String, UnicodeText
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
+from superset.utils.core import GenericDataType
from tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestMssqlEngineSpec(TestDbEngineSpec):
def test_mssql_column_types(self):
- def assert_type(type_string, type_expected):
- type_assigned = MssqlEngineSpec.get_sqla_column_type(type_string)
+ def assert_type(type_string, type_expected, generic_type_expected):
if type_expected is None:
+ type_assigned = MssqlEngineSpec.get_sqla_column_type(type_string)
self.assertIsNone(type_assigned)
else:
- self.assertIsInstance(type_assigned, type_expected)
+ column_spec = MssqlEngineSpec.get_column_spec(type_string)
+ if column_spec != None:
+ self.assertIsInstance(column_spec.sqla_type, type_expected)
+ self.assertEquals(column_spec.generic_type, generic_type_expected)
- assert_type("INT", None)
- assert_type("STRING", String)
- assert_type("CHAR(10)", String)
- assert_type("VARCHAR(10)", String)
- assert_type("TEXT", String)
- assert_type("NCHAR(10)", UnicodeText)
- assert_type("NVARCHAR(10)", UnicodeText)
- assert_type("NTEXT", UnicodeText)
+ assert_type("STRING", String, GenericDataType.STRING)
+ assert_type("CHAR(10)", String, GenericDataType.STRING)
+ assert_type("VARCHAR(10)", String, GenericDataType.STRING)
+ assert_type("TEXT", String, GenericDataType.STRING)
+ assert_type("NCHAR(10)", UnicodeText, GenericDataType.STRING)
+ assert_type("NVARCHAR(10)", UnicodeText, GenericDataType.STRING)
+ assert_type("NTEXT", UnicodeText, GenericDataType.STRING)
def test_where_clause_n_prefix(self):
dialect = mssql.dialect()
spec = MssqlEngineSpec
- str_col = column("col", type_=spec.get_sqla_column_type("VARCHAR(10)"))
- unicode_col = column("unicode_col", type_=spec.get_sqla_column_type("NTEXT"))
+ type_, _ = spec.get_sqla_column_type("VARCHAR(10)")
+ str_col = column("col", type_=type_)
+ type_, _ = spec.get_sqla_column_type("NTEXT")
+ unicode_col = column("unicode_col", type_=type_)
tbl = table("tbl")
sel = (
select([str_col, unicode_col])
diff --git a/tests/db_engine_specs/mysql_tests.py b/tests/db_engine_specs/mysql_tests.py
index ba56b6c..035b06f 100644
--- a/tests/db_engine_specs/mysql_tests.py
+++ b/tests/db_engine_specs/mysql_tests.py
@@ -89,18 +89,9 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
("TIME", GenericDataType.TEMPORAL),
)
- for type_expectation in type_expectations:
- type_str = type_expectation[0]
- col_type = type_expectation[1]
- assert MySQLEngineSpec.is_db_column_type_match(
- type_str, GenericDataType.NUMERIC
- ) is (col_type == GenericDataType.NUMERIC)
- assert MySQLEngineSpec.is_db_column_type_match(
- type_str, GenericDataType.STRING
- ) is (col_type == GenericDataType.STRING)
- assert MySQLEngineSpec.is_db_column_type_match(
- type_str, GenericDataType.TEMPORAL
- ) is (col_type == GenericDataType.TEMPORAL)
+ for type_str, col_type in type_expectations:
+ column_spec = MySQLEngineSpec.get_column_spec(type_str)
+ assert column_spec.generic_type == col_type
def test_extract_error_message(self):
from MySQLdb._exceptions import OperationalError
diff --git a/tests/db_engine_specs/presto_tests.py b/tests/db_engine_specs/presto_tests.py
index d0343a3..5fd16a6 100644
--- a/tests/db_engine_specs/presto_tests.py
+++ b/tests/db_engine_specs/presto_tests.py
@@ -24,7 +24,7 @@ from sqlalchemy.sql import select
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.sql_parse import ParsedQuery
-from superset.utils.core import DatasourceName
+from superset.utils.core import DatasourceName, GenericDataType
from tests.db_engine_specs.base_tests import TestDbEngineSpec
@@ -535,30 +535,37 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
def test_get_sqla_column_type(self):
- sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar(255)")
- assert isinstance(sqla_type, types.VARCHAR)
- assert sqla_type.length == 255
-
- sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar")
- assert isinstance(sqla_type, types.String)
- assert sqla_type.length is None
-
- sqla_type = PrestoEngineSpec.get_sqla_column_type("char(10)")
- assert isinstance(sqla_type, types.CHAR)
- assert sqla_type.length == 10
-
- sqla_type = PrestoEngineSpec.get_sqla_column_type("char")
- assert isinstance(sqla_type, types.CHAR)
- assert sqla_type.length is None
-
- sqla_type = PrestoEngineSpec.get_sqla_column_type("integer")
- assert isinstance(sqla_type, types.Integer)
-
- sqla_type = PrestoEngineSpec.get_sqla_column_type("time")
- assert isinstance(sqla_type, types.Time)
-
- sqla_type = PrestoEngineSpec.get_sqla_column_type("timestamp")
- assert isinstance(sqla_type, types.TIMESTAMP)
+ column_spec = PrestoEngineSpec.get_column_spec("varchar(255)")
+ assert isinstance(column_spec.sqla_type, types.VARCHAR)
+ assert column_spec.sqla_type.length == 255
+ self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
+
+ column_spec = PrestoEngineSpec.get_column_spec("varchar")
+ assert isinstance(column_spec.sqla_type, types.String)
+ assert column_spec.sqla_type.length is None
+ self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
+
+ column_spec = PrestoEngineSpec.get_column_spec("char(10)")
+ assert isinstance(column_spec.sqla_type, types.CHAR)
+ assert column_spec.sqla_type.length == 10
+ self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
+
+ column_spec = PrestoEngineSpec.get_column_spec("char")
+ assert isinstance(column_spec.sqla_type, types.CHAR)
+ assert column_spec.sqla_type.length is None
+ self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
+
+ column_spec = PrestoEngineSpec.get_column_spec("integer")
+ assert isinstance(column_spec.sqla_type, types.Integer)
+ self.assertEqual(column_spec.generic_type, GenericDataType.NUMERIC)
+
+ column_spec = PrestoEngineSpec.get_column_spec("time")
+ assert isinstance(column_spec.sqla_type, types.Time)
+ self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)
+
+ column_spec = PrestoEngineSpec.get_column_spec("timestamp")
+ assert isinstance(column_spec.sqla_type, types.TIMESTAMP)
+ self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)
sqla_type = PrestoEngineSpec.get_sqla_column_type(None)
assert sqla_type is None
diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py
index e034609..cdd77c2 100644
--- a/tests/sqla_models_tests.py
+++ b/tests/sqla_models_tests.py
@@ -84,11 +84,9 @@ class TestDatabaseModel(SupersetTestCase):
"TEXT": GenericDataType.STRING,
"NTEXT": GenericDataType.STRING,
# numeric
- "INT": GenericDataType.NUMERIC,
+ "INTEGER": GenericDataType.NUMERIC,
"BIGINT": GenericDataType.NUMERIC,
- "FLOAT": GenericDataType.NUMERIC,
"DECIMAL": GenericDataType.NUMERIC,
- "MONEY": GenericDataType.NUMERIC,
# temporal
"DATE": GenericDataType.TEMPORAL,
"DATETIME": GenericDataType.TEMPORAL,