You are viewing a plain text version of this content. The canonical link to the original message was not preserved in this copy.
Posted to commits@superset.apache.org by vi...@apache.org on 2020/03/04 14:33:58 UTC

[incubator-superset] branch master updated: fix: share column type matching between model and result set (#9161)

This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 7a91498  fix: share column type matching between model and result set (#9161)
7a91498 is described below

commit 7a91498cf1a9e56d4b3d7b3805076137525ea277
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Wed Mar 4 16:33:22 2020 +0200

    fix: share column type matching between model and result set (#9161)
    
    * Share column type matching between model and result set
    
    * Address comments
---
 superset/common/query_context.py    |  2 +-
 superset/connectors/base/models.py  |  6 ++---
 superset/connectors/druid/models.py |  2 +-
 superset/connectors/sqla/models.py  | 37 +++++++++++++++++++------
 superset/db_engine_specs/base.py    | 54 ++++++++++++++++++++++++++++++++++++-
 superset/result_set.py              | 11 ++++----
 superset/utils/core.py              | 14 +++++++++-
 superset/views/core.py              |  1 -
 tests/result_set_tests.py           | 19 ++++++-------
 tests/sqla_models_tests.py          | 48 ++++++++++++++++++++++++---------
 10 files changed, 151 insertions(+), 43 deletions(-)

diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index edfcfaa..40df7c9 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -123,7 +123,7 @@ class QueryContext:
 
     @staticmethod
     def get_data(  # pylint: disable=invalid-name,no-self-use
-        df: pd.DataFrame
+        df: pd.DataFrame,
     ) -> List[Dict]:
         return df.to_dict(orient="records")
 
diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index 2f06245..eac6cbb 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -387,15 +387,15 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
         "DECIMAL",
         "MONEY",
     )
-    date_types = ("DATE", "TIME", "DATETIME")
+    date_types = ("DATE", "TIME")
     str_types = ("VARCHAR", "STRING", "CHAR")
 
     @property
-    def is_num(self) -> bool:
+    def is_numeric(self) -> bool:
         return self.type and any(map(lambda t: t in self.type.upper(), self.num_types))
 
     @property
-    def is_time(self) -> bool:
+    def is_temporal(self) -> bool:
         return self.type and any(map(lambda t: t in self.type.upper(), self.date_types))
 
     @property
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index ca61a47..517a458 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -527,7 +527,7 @@ class DruidDatasource(Model, BaseDatasource):
 
     @property
     def num_cols(self) -> List[str]:
-        return [c.column_name for c in self.columns if c.is_num]
+        return [c.column_name for c in self.columns if c.is_numeric]
 
     @property
     def name(self) -> str:  # type: ignore
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 9236bc6..98ba5b2 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -145,6 +145,27 @@ class TableColumn(Model, BaseColumn):
     update_from_object_fields = [s for s in export_fields if s not in ("table_id",)]
     export_parent = "table"
 
+    @property
+    def is_numeric(self) -> bool:
+        db_engine_spec = self.table.database.db_engine_spec
+        return db_engine_spec.is_db_column_type_match(
+            self.type, utils.DbColumnType.NUMERIC
+        )
+
+    @property
+    def is_string(self) -> bool:
+        db_engine_spec = self.table.database.db_engine_spec
+        return db_engine_spec.is_db_column_type_match(
+            self.type, utils.DbColumnType.STRING
+        )
+
+    @property
+    def is_temporal(self) -> bool:
+        db_engine_spec = self.table.database.db_engine_spec
+        return db_engine_spec.is_db_column_type_match(
+            self.type, utils.DbColumnType.TEMPORAL
+        )
+
     def get_sqla_col(self, label: Optional[str] = None) -> Column:
         label = label or self.column_name
         if self.expression:
@@ -489,7 +510,7 @@ class SqlaTable(Model, BaseDatasource):
 
     @property
     def num_cols(self) -> List:
-        return [c.column_name for c in self.columns if c.is_num]
+        return [c.column_name for c in self.columns if c.is_numeric]
 
     @property
     def any_dttm_col(self) -> Optional[str]:
@@ -809,7 +830,7 @@ class SqlaTable(Model, BaseDatasource):
                 is_list_target = op in ("in", "not in")
                 eq = self.filter_values_handler(
                     flt.get("val"),
-                    target_column_is_numeric=col_obj.is_num,
+                    target_column_is_numeric=col_obj.is_numeric,
                     is_list_target=is_list_target,
                 )
                 if op in ("in", "not in"):
@@ -820,7 +841,7 @@ class SqlaTable(Model, BaseDatasource):
                         cond = ~cond
                     where_clause_and.append(cond)
                 else:
-                    if col_obj.is_num:
+                    if col_obj.is_numeric:
                         eq = utils.string_to_num(flt["val"])
                     if op == "==":
                         where_clause_and.append(col_obj.get_sqla_col() == eq)
@@ -1074,17 +1095,17 @@ class SqlaTable(Model, BaseDatasource):
                 logger.exception(e)
             dbcol = dbcols.get(col.name, None)
             if not dbcol:
-                dbcol = TableColumn(column_name=col.name, type=datatype)
-                dbcol.sum = dbcol.is_num
-                dbcol.avg = dbcol.is_num
-                dbcol.is_dttm = dbcol.is_time
+                dbcol = TableColumn(column_name=col.name, type=datatype, table=self)
+                dbcol.sum = dbcol.is_numeric
+                dbcol.avg = dbcol.is_numeric
+                dbcol.is_dttm = dbcol.is_temporal
                 db_engine_spec.alter_new_orm_column(dbcol)
             else:
                 dbcol.type = datatype
             dbcol.groupby = True
             dbcol.filterable = True
             self.columns.append(dbcol)
-            if not any_date_col and dbcol.is_time:
+            if not any_date_col and dbcol.is_temporal:
                 any_date_col = col.name
 
         metrics.append(
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 7c63e93..ba74a42 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -20,7 +20,17 @@ import os
 import re
 from contextlib import closing
 from datetime import datetime
-from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TYPE_CHECKING, Union
+from typing import (
+    Any,
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Pattern,
+    Tuple,
+    TYPE_CHECKING,
+    Union,
+)
 
 import pandas as pd
 import sqlparse
@@ -134,6 +144,48 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     max_column_name_length = 0
     try_remove_schema_from_table_name = True  # pylint: disable=invalid-name
 
+    # default matching patterns for identifying column types
+    db_column_types: Dict[utils.DbColumnType, Tuple[Pattern, ...]] = {
+        utils.DbColumnType.NUMERIC: (
+            re.compile(r".*DOUBLE.*", re.IGNORECASE),
+            re.compile(r".*FLOAT.*", re.IGNORECASE),
+            re.compile(r".*INT.*", re.IGNORECASE),
+            re.compile(r".*NUMBER.*", re.IGNORECASE),
+            re.compile(r".*LONG.*", re.IGNORECASE),
+            re.compile(r".*REAL.*", re.IGNORECASE),
+            re.compile(r".*NUMERIC.*", re.IGNORECASE),
+            re.compile(r".*DECIMAL.*", re.IGNORECASE),
+            re.compile(r".*MONEY.*", re.IGNORECASE),
+        ),
+        utils.DbColumnType.STRING: (
+            re.compile(r".*CHAR.*", re.IGNORECASE),
+            re.compile(r".*STRING.*", re.IGNORECASE),
+        ),
+        utils.DbColumnType.TEMPORAL: (
+            re.compile(r".*DATE.*", re.IGNORECASE),
+            re.compile(r".*TIME.*", re.IGNORECASE),
+        ),
+    }
+
+    @classmethod
+    def is_db_column_type_match(
+        cls, db_column_type: Optional[str], target_column_type: utils.DbColumnType
+    ) -> bool:
+        """
+        Check if a column type satisfies a pattern in a collection of regexes found in
+        `db_column_types`. For example, if `db_column_type == "NVARCHAR"`,
+        it would be a match for "STRING" due to being a match for the regex ".*CHAR.*".
+
+        :param db_column_type: Column type to evaluate
+        :param target_column_type: The target type to evaluate for
+        :return: `True` if a `db_column_type` matches any pattern corresponding to
+        `target_column_type`
+        """
+        if not db_column_type:
+            return False
+        patterns = cls.db_column_types[target_column_type]
+        return any(pattern.match(db_column_type) for pattern in patterns)
+
     @classmethod
     def get_allow_cost_estimate(cls, version: Optional[str] = None) -> bool:
         return False
diff --git a/superset/result_set.py b/superset/result_set.py
index 8c4c567..bc7299b 100644
--- a/superset/result_set.py
+++ b/superset/result_set.py
@@ -20,7 +20,6 @@
 import datetime
 import json
 import logging
-import re
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type
 
 import numpy as np
@@ -75,6 +74,7 @@ class SupersetResultSet:
         cursor_description: Tuple[Any, ...],
         db_engine_spec: Type[db_engine_specs.BaseEngineSpec],
     ):
+        self.db_engine_spec = db_engine_spec
         data = data or []
         column_names: List[str] = []
         pa_data: List[pa.Array] = []
@@ -173,9 +173,10 @@ class SupersetResultSet:
     def first_nonempty(items: List) -> Any:
         return next((i for i in items if i), None)
 
-    @staticmethod
-    def is_date(db_type_str: Optional[str]) -> bool:
-        return db_type_str in ("DATETIME", "TIMESTAMP")
+    def is_temporal(self, db_type_str: Optional[str]) -> bool:
+        return self.db_engine_spec.is_db_column_type_match(
+            db_type_str, utils.DbColumnType.TEMPORAL
+        )
 
     def data_type(self, col_name: str, pa_dtype: pa.DataType) -> Optional[str]:
         """Given a pyarrow data type, Returns a generic database type"""
@@ -211,7 +212,7 @@ class SupersetResultSet:
             column = {
                 "name": col.name,
                 "type": db_type_str,
-                "is_date": self.is_date(db_type_str),
+                "is_date": self.is_temporal(db_type_str),
             }
             columns.append(column)
 
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 55c2395..0e820e1 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -22,6 +22,7 @@ import functools
 import json
 import logging
 import os
+import re
 import signal
 import smtplib
 import traceback
@@ -922,6 +923,7 @@ def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs):
     database = (
         db.session.query(models.Database).filter_by(database_name=database_name).first()
     )
+
     if not database:
         logger.info(f"Creating database reference for {database_name}")
         database = models.Database(database_name=database_name, *args, **kwargs)
@@ -1225,7 +1227,7 @@ class TimeRangeEndpoint(str, Enum):
     UNKNOWN = "unknown"
 
 
-class ReservedUrlParameters(Enum):
+class ReservedUrlParameters(str, Enum):
     """
     Reserved URL parameters that are used internally by Superset. These will not be
     passed to chart queries, as they control the behavior of the UI.
@@ -1243,3 +1245,13 @@ class QuerySource(Enum):
     CHART = 0
     DASHBOARD = 1
     SQL_LAB = 2
+
+
+class DbColumnType(Enum):
+    """
+    Generic database column type
+    """
+
+    NUMERIC = 0
+    STRING = 1
+    TEMPORAL = 2
diff --git a/superset/views/core.py b/superset/views/core.py
index 4ecc097..527dbb9 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -48,7 +48,6 @@ import superset.models.core as models
 from superset import (
     app,
     appbuilder,
-    cache,
     conf,
     dataframe,
     db,
diff --git a/tests/result_set_tests.py b/tests/result_set_tests.py
index f1d78a4..184511a 100644
--- a/tests/result_set_tests.py
+++ b/tests/result_set_tests.py
@@ -17,9 +17,6 @@
 # isort:skip_file
 from datetime import datetime
 
-import numpy as np
-import pandas as pd
-
 import tests.test_app
 from superset.dataframe import df_to_records
 from superset.db_engine_specs import BaseEngineSpec
@@ -88,12 +85,16 @@ class SupersetResultSetTestCase(SupersetTestCase):
         )
 
     def test_is_date(self):
-        is_date = SupersetResultSet.is_date
-        self.assertEqual(is_date("DATETIME"), True)
-        self.assertEqual(is_date("TIMESTAMP"), True)
-        self.assertEqual(is_date("STRING"), False)
-        self.assertEqual(is_date(""), False)
-        self.assertEqual(is_date(None), False)
+        data = [("a", 1), ("a", 2)]
+        cursor_descr = (("a", "string"), ("a", "string"))
+        results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
+        self.assertEqual(results.is_temporal("DATE"), True)
+        self.assertEqual(results.is_temporal("DATETIME"), True)
+        self.assertEqual(results.is_temporal("TIME"), True)
+        self.assertEqual(results.is_temporal("TIMESTAMP"), True)
+        self.assertEqual(results.is_temporal("STRING"), False)
+        self.assertEqual(results.is_temporal(""), False)
+        self.assertEqual(results.is_temporal(None), False)
 
     def test_dedup_with_data(self):
         data = [("a", 1), ("a", 2)]
diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py
index e959989..07d1d06 100644
--- a/tests/sqla_models_tests.py
+++ b/tests/sqla_models_tests.py
@@ -15,10 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 # isort:skip_file
-import tests.test_app
+from typing import Dict
+
 from superset.connectors.sqla.models import SqlaTable, TableColumn
 from superset.db_engine_specs.druid import DruidEngineSpec
-from superset.utils.core import get_example_database
+from superset.models.core import Database
+from superset.utils.core import DbColumnType, get_example_database
 
 from .base_tests import SupersetTestCase
 
@@ -26,23 +28,43 @@ from .base_tests import SupersetTestCase
 class DatabaseModelTestCase(SupersetTestCase):
     def test_is_time_druid_time_col(self):
         """Druid has a special __time column"""
-        col = TableColumn(column_name="__time", type="INTEGER")
+
+        database = Database(database_name="druid_db", sqlalchemy_uri="druid://db")
+        tbl = SqlaTable(table_name="druid_tbl", database=database)
+        col = TableColumn(column_name="__time", type="INTEGER", table=tbl)
         self.assertEqual(col.is_dttm, None)
         DruidEngineSpec.alter_new_orm_column(col)
         self.assertEqual(col.is_dttm, True)
 
-        col = TableColumn(column_name="__not_time", type="INTEGER")
-        self.assertEqual(col.is_time, False)
+        col = TableColumn(column_name="__not_time", type="INTEGER", table=tbl)
+        self.assertEqual(col.is_temporal, False)
 
-    def test_is_time_by_type(self):
-        col = TableColumn(column_name="foo", type="DATE")
-        self.assertEqual(col.is_time, True)
-
-        col = TableColumn(column_name="foo", type="DATETIME")
-        self.assertEqual(col.is_time, True)
+    def test_db_column_types(self):
+        test_cases: Dict[str, DbColumnType] = {
+            # string
+            "CHAR": DbColumnType.STRING,
+            "VARCHAR": DbColumnType.STRING,
+            "NVARCHAR": DbColumnType.STRING,
+            "STRING": DbColumnType.STRING,
+            # numeric
+            "INT": DbColumnType.NUMERIC,
+            "BIGINT": DbColumnType.NUMERIC,
+            "FLOAT": DbColumnType.NUMERIC,
+            "DECIMAL": DbColumnType.NUMERIC,
+            "MONEY": DbColumnType.NUMERIC,
+            # temporal
+            "DATE": DbColumnType.TEMPORAL,
+            "DATETIME": DbColumnType.TEMPORAL,
+            "TIME": DbColumnType.TEMPORAL,
+            "TIMESTAMP": DbColumnType.TEMPORAL,
+        }
 
-        col = TableColumn(column_name="foo", type="STRING")
-        self.assertEqual(col.is_time, False)
+        tbl = SqlaTable(table_name="col_type_test_tbl", database=get_example_database())
+        for str_type, db_col_type in test_cases.items():
+            col = TableColumn(column_name="foo", type=str_type, table=tbl)
+            self.assertEqual(col.is_temporal, db_col_type == DbColumnType.TEMPORAL)
+            self.assertEqual(col.is_numeric, db_col_type == DbColumnType.NUMERIC)
+            self.assertEqual(col.is_string, db_col_type == DbColumnType.STRING)
 
     def test_has_extra_cache_keys(self):
         query = "SELECT '{{ cache_key_wrapper('user_1') }}' as user"