You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by vi...@apache.org on 2020/03/04 14:33:58 UTC
[incubator-superset] branch master updated: fix: share column type matching between model and result set (#9161)
This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 7a91498 fix: share column type matching between model and result set (#9161)
7a91498 is described below
commit 7a91498cf1a9e56d4b3d7b3805076137525ea277
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Wed Mar 4 16:33:22 2020 +0200
fix: share column type matching between model and result set (#9161)
* Share column type matching between model and result set
* Address comments
---
superset/common/query_context.py | 2 +-
superset/connectors/base/models.py | 6 ++---
superset/connectors/druid/models.py | 2 +-
superset/connectors/sqla/models.py | 37 +++++++++++++++++++------
superset/db_engine_specs/base.py | 54 ++++++++++++++++++++++++++++++++++++-
superset/result_set.py | 11 ++++----
superset/utils/core.py | 14 +++++++++-
superset/views/core.py | 1 -
tests/result_set_tests.py | 19 ++++++-------
tests/sqla_models_tests.py | 48 ++++++++++++++++++++++++---------
10 files changed, 151 insertions(+), 43 deletions(-)
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index edfcfaa..40df7c9 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -123,7 +123,7 @@ class QueryContext:
@staticmethod
def get_data( # pylint: disable=invalid-name,no-self-use
- df: pd.DataFrame
+ df: pd.DataFrame,
) -> List[Dict]:
return df.to_dict(orient="records")
diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index 2f06245..eac6cbb 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -387,15 +387,15 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
"DECIMAL",
"MONEY",
)
- date_types = ("DATE", "TIME", "DATETIME")
+ date_types = ("DATE", "TIME")
str_types = ("VARCHAR", "STRING", "CHAR")
@property
- def is_num(self) -> bool:
+ def is_numeric(self) -> bool:
return self.type and any(map(lambda t: t in self.type.upper(), self.num_types))
@property
- def is_time(self) -> bool:
+ def is_temporal(self) -> bool:
return self.type and any(map(lambda t: t in self.type.upper(), self.date_types))
@property
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index ca61a47..517a458 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -527,7 +527,7 @@ class DruidDatasource(Model, BaseDatasource):
@property
def num_cols(self) -> List[str]:
- return [c.column_name for c in self.columns if c.is_num]
+ return [c.column_name for c in self.columns if c.is_numeric]
@property
def name(self) -> str: # type: ignore
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 9236bc6..98ba5b2 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -145,6 +145,27 @@ class TableColumn(Model, BaseColumn):
update_from_object_fields = [s for s in export_fields if s not in ("table_id",)]
export_parent = "table"
+ @property
+ def is_numeric(self) -> bool:
+ db_engine_spec = self.table.database.db_engine_spec
+ return db_engine_spec.is_db_column_type_match(
+ self.type, utils.DbColumnType.NUMERIC
+ )
+
+ @property
+ def is_string(self) -> bool:
+ db_engine_spec = self.table.database.db_engine_spec
+ return db_engine_spec.is_db_column_type_match(
+ self.type, utils.DbColumnType.STRING
+ )
+
+ @property
+ def is_temporal(self) -> bool:
+ db_engine_spec = self.table.database.db_engine_spec
+ return db_engine_spec.is_db_column_type_match(
+ self.type, utils.DbColumnType.TEMPORAL
+ )
+
def get_sqla_col(self, label: Optional[str] = None) -> Column:
label = label or self.column_name
if self.expression:
@@ -489,7 +510,7 @@ class SqlaTable(Model, BaseDatasource):
@property
def num_cols(self) -> List:
- return [c.column_name for c in self.columns if c.is_num]
+ return [c.column_name for c in self.columns if c.is_numeric]
@property
def any_dttm_col(self) -> Optional[str]:
@@ -809,7 +830,7 @@ class SqlaTable(Model, BaseDatasource):
is_list_target = op in ("in", "not in")
eq = self.filter_values_handler(
flt.get("val"),
- target_column_is_numeric=col_obj.is_num,
+ target_column_is_numeric=col_obj.is_numeric,
is_list_target=is_list_target,
)
if op in ("in", "not in"):
@@ -820,7 +841,7 @@ class SqlaTable(Model, BaseDatasource):
cond = ~cond
where_clause_and.append(cond)
else:
- if col_obj.is_num:
+ if col_obj.is_numeric:
eq = utils.string_to_num(flt["val"])
if op == "==":
where_clause_and.append(col_obj.get_sqla_col() == eq)
@@ -1074,17 +1095,17 @@ class SqlaTable(Model, BaseDatasource):
logger.exception(e)
dbcol = dbcols.get(col.name, None)
if not dbcol:
- dbcol = TableColumn(column_name=col.name, type=datatype)
- dbcol.sum = dbcol.is_num
- dbcol.avg = dbcol.is_num
- dbcol.is_dttm = dbcol.is_time
+ dbcol = TableColumn(column_name=col.name, type=datatype, table=self)
+ dbcol.sum = dbcol.is_numeric
+ dbcol.avg = dbcol.is_numeric
+ dbcol.is_dttm = dbcol.is_temporal
db_engine_spec.alter_new_orm_column(dbcol)
else:
dbcol.type = datatype
dbcol.groupby = True
dbcol.filterable = True
self.columns.append(dbcol)
- if not any_date_col and dbcol.is_time:
+ if not any_date_col and dbcol.is_temporal:
any_date_col = col.name
metrics.append(
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 7c63e93..ba74a42 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -20,7 +20,17 @@ import os
import re
from contextlib import closing
from datetime import datetime
-from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TYPE_CHECKING, Union
+from typing import (
+ Any,
+ Dict,
+ List,
+ NamedTuple,
+ Optional,
+ Pattern,
+ Tuple,
+ TYPE_CHECKING,
+ Union,
+)
import pandas as pd
import sqlparse
@@ -134,6 +144,48 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
max_column_name_length = 0
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
+ # default matching patterns for identifying column types
+ db_column_types: Dict[utils.DbColumnType, Tuple[Pattern, ...]] = {
+ utils.DbColumnType.NUMERIC: (
+ re.compile(r".*DOUBLE.*", re.IGNORECASE),
+ re.compile(r".*FLOAT.*", re.IGNORECASE),
+ re.compile(r".*INT.*", re.IGNORECASE),
+ re.compile(r".*NUMBER.*", re.IGNORECASE),
+ re.compile(r".*LONG.*", re.IGNORECASE),
+ re.compile(r".*REAL.*", re.IGNORECASE),
+ re.compile(r".*NUMERIC.*", re.IGNORECASE),
+ re.compile(r".*DECIMAL.*", re.IGNORECASE),
+ re.compile(r".*MONEY.*", re.IGNORECASE),
+ ),
+ utils.DbColumnType.STRING: (
+ re.compile(r".*CHAR.*", re.IGNORECASE),
+ re.compile(r".*STRING.*", re.IGNORECASE),
+ ),
+ utils.DbColumnType.TEMPORAL: (
+ re.compile(r".*DATE.*", re.IGNORECASE),
+ re.compile(r".*TIME.*", re.IGNORECASE),
+ ),
+ }
+
+ @classmethod
+ def is_db_column_type_match(
+ cls, db_column_type: Optional[str], target_column_type: utils.DbColumnType
+ ) -> bool:
+ """
+ Check if a column type satisfies a pattern in a collection of regexes found in
+ `db_column_types`. For example, if `db_column_type == "NVARCHAR"`,
+ it would be a match for "STRING" due to being a match for the regex ".*CHAR.*".
+
+ :param db_column_type: Column type to evaluate
+ :param target_column_type: The target type to evaluate for
+ :return: `True` if a `db_column_type` matches any pattern corresponding to
+ `target_column_type`
+ """
+ if not db_column_type:
+ return False
+ patterns = cls.db_column_types[target_column_type]
+ return any(pattern.match(db_column_type) for pattern in patterns)
+
@classmethod
def get_allow_cost_estimate(cls, version: Optional[str] = None) -> bool:
return False
diff --git a/superset/result_set.py b/superset/result_set.py
index 8c4c567..bc7299b 100644
--- a/superset/result_set.py
+++ b/superset/result_set.py
@@ -20,7 +20,6 @@
import datetime
import json
import logging
-import re
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
import numpy as np
@@ -75,6 +74,7 @@ class SupersetResultSet:
cursor_description: Tuple[Any, ...],
db_engine_spec: Type[db_engine_specs.BaseEngineSpec],
):
+ self.db_engine_spec = db_engine_spec
data = data or []
column_names: List[str] = []
pa_data: List[pa.Array] = []
@@ -173,9 +173,10 @@ class SupersetResultSet:
def first_nonempty(items: List) -> Any:
return next((i for i in items if i), None)
- @staticmethod
- def is_date(db_type_str: Optional[str]) -> bool:
- return db_type_str in ("DATETIME", "TIMESTAMP")
+ def is_temporal(self, db_type_str: Optional[str]) -> bool:
+ return self.db_engine_spec.is_db_column_type_match(
+ db_type_str, utils.DbColumnType.TEMPORAL
+ )
def data_type(self, col_name: str, pa_dtype: pa.DataType) -> Optional[str]:
"""Given a pyarrow data type, Returns a generic database type"""
@@ -211,7 +212,7 @@ class SupersetResultSet:
column = {
"name": col.name,
"type": db_type_str,
- "is_date": self.is_date(db_type_str),
+ "is_date": self.is_temporal(db_type_str),
}
columns.append(column)
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 55c2395..0e820e1 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -22,6 +22,7 @@ import functools
import json
import logging
import os
+import re
import signal
import smtplib
import traceback
@@ -922,6 +923,7 @@ def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs):
database = (
db.session.query(models.Database).filter_by(database_name=database_name).first()
)
+
if not database:
logger.info(f"Creating database reference for {database_name}")
database = models.Database(database_name=database_name, *args, **kwargs)
@@ -1225,7 +1227,7 @@ class TimeRangeEndpoint(str, Enum):
UNKNOWN = "unknown"
-class ReservedUrlParameters(Enum):
+class ReservedUrlParameters(str, Enum):
"""
Reserved URL parameters that are used internally by Superset. These will not be
passed to chart queries, as they control the behavior of the UI.
@@ -1243,3 +1245,13 @@ class QuerySource(Enum):
CHART = 0
DASHBOARD = 1
SQL_LAB = 2
+
+
+class DbColumnType(Enum):
+ """
+ Generic database column type
+ """
+
+ NUMERIC = 0
+ STRING = 1
+ TEMPORAL = 2
diff --git a/superset/views/core.py b/superset/views/core.py
index 4ecc097..527dbb9 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -48,7 +48,6 @@ import superset.models.core as models
from superset import (
app,
appbuilder,
- cache,
conf,
dataframe,
db,
diff --git a/tests/result_set_tests.py b/tests/result_set_tests.py
index f1d78a4..184511a 100644
--- a/tests/result_set_tests.py
+++ b/tests/result_set_tests.py
@@ -17,9 +17,6 @@
# isort:skip_file
from datetime import datetime
-import numpy as np
-import pandas as pd
-
import tests.test_app
from superset.dataframe import df_to_records
from superset.db_engine_specs import BaseEngineSpec
@@ -88,12 +85,16 @@ class SupersetResultSetTestCase(SupersetTestCase):
)
def test_is_date(self):
- is_date = SupersetResultSet.is_date
- self.assertEqual(is_date("DATETIME"), True)
- self.assertEqual(is_date("TIMESTAMP"), True)
- self.assertEqual(is_date("STRING"), False)
- self.assertEqual(is_date(""), False)
- self.assertEqual(is_date(None), False)
+ data = [("a", 1), ("a", 2)]
+ cursor_descr = (("a", "string"), ("a", "string"))
+ results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
+ self.assertEqual(results.is_temporal("DATE"), True)
+ self.assertEqual(results.is_temporal("DATETIME"), True)
+ self.assertEqual(results.is_temporal("TIME"), True)
+ self.assertEqual(results.is_temporal("TIMESTAMP"), True)
+ self.assertEqual(results.is_temporal("STRING"), False)
+ self.assertEqual(results.is_temporal(""), False)
+ self.assertEqual(results.is_temporal(None), False)
def test_dedup_with_data(self):
data = [("a", 1), ("a", 2)]
diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py
index e959989..07d1d06 100644
--- a/tests/sqla_models_tests.py
+++ b/tests/sqla_models_tests.py
@@ -15,10 +15,12 @@
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
-import tests.test_app
+from typing import Dict
+
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.db_engine_specs.druid import DruidEngineSpec
-from superset.utils.core import get_example_database
+from superset.models.core import Database
+from superset.utils.core import DbColumnType, get_example_database
from .base_tests import SupersetTestCase
@@ -26,23 +28,43 @@ from .base_tests import SupersetTestCase
class DatabaseModelTestCase(SupersetTestCase):
def test_is_time_druid_time_col(self):
"""Druid has a special __time column"""
- col = TableColumn(column_name="__time", type="INTEGER")
+
+ database = Database(database_name="druid_db", sqlalchemy_uri="druid://db")
+ tbl = SqlaTable(table_name="druid_tbl", database=database)
+ col = TableColumn(column_name="__time", type="INTEGER", table=tbl)
self.assertEqual(col.is_dttm, None)
DruidEngineSpec.alter_new_orm_column(col)
self.assertEqual(col.is_dttm, True)
- col = TableColumn(column_name="__not_time", type="INTEGER")
- self.assertEqual(col.is_time, False)
+ col = TableColumn(column_name="__not_time", type="INTEGER", table=tbl)
+ self.assertEqual(col.is_temporal, False)
- def test_is_time_by_type(self):
- col = TableColumn(column_name="foo", type="DATE")
- self.assertEqual(col.is_time, True)
-
- col = TableColumn(column_name="foo", type="DATETIME")
- self.assertEqual(col.is_time, True)
+ def test_db_column_types(self):
+ test_cases: Dict[str, DbColumnType] = {
+ # string
+ "CHAR": DbColumnType.STRING,
+ "VARCHAR": DbColumnType.STRING,
+ "NVARCHAR": DbColumnType.STRING,
+ "STRING": DbColumnType.STRING,
+ # numeric
+ "INT": DbColumnType.NUMERIC,
+ "BIGINT": DbColumnType.NUMERIC,
+ "FLOAT": DbColumnType.NUMERIC,
+ "DECIMAL": DbColumnType.NUMERIC,
+ "MONEY": DbColumnType.NUMERIC,
+ # temporal
+ "DATE": DbColumnType.TEMPORAL,
+ "DATETIME": DbColumnType.TEMPORAL,
+ "TIME": DbColumnType.TEMPORAL,
+ "TIMESTAMP": DbColumnType.TEMPORAL,
+ }
- col = TableColumn(column_name="foo", type="STRING")
- self.assertEqual(col.is_time, False)
+ tbl = SqlaTable(table_name="col_type_test_tbl", database=get_example_database())
+ for str_type, db_col_type in test_cases.items():
+ col = TableColumn(column_name="foo", type=str_type, table=tbl)
+ self.assertEqual(col.is_temporal, db_col_type == DbColumnType.TEMPORAL)
+ self.assertEqual(col.is_numeric, db_col_type == DbColumnType.NUMERIC)
+ self.assertEqual(col.is_string, db_col_type == DbColumnType.STRING)
def test_has_extra_cache_keys(self):
query = "SELECT '{{ cache_key_wrapper('user_1') }}' as user"