You are viewing a plain-text version of this content. The canonical link to it (originally a hyperlink on the word "here") was not preserved in this copy.
Posted to commits@superset.apache.org by vi...@apache.org on 2020/03/17 19:35:03 UTC

[incubator-superset] branch master updated: fix: remove character set and collate column info by default (#9316)

This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 982c234  fix: remove character set and collate column info by default (#9316)
982c234 is described below

commit 982c234a5060f350dad2015318daa908208bc330
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Tue Mar 17 21:34:39 2020 +0200

    fix: remove character set and collate column info by default (#9316)
    
    * fix: remove character set and collate column info by default
    
    * lint
    
    * remove collation and charset info before compile
---
 superset/db_engine_specs/base.py      | 10 ++++++--
 superset/db_engine_specs/mssql.py     | 18 ++------------
 superset/db_engine_specs/mysql.py     | 15 ------------
 tests/db_engine_specs/mssql_tests.py  | 45 +++++++++++++++++++++++++----------
 tests/db_engine_specs/mysql_tests.py  | 21 ++++++++++++++++
 tests/db_engine_specs/oracle_tests.py | 35 +++++++++++++++++++--------
 tests/sqla_models_tests.py            |  2 ++
 7 files changed, 90 insertions(+), 56 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 708dae0..b2ede28 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -161,6 +161,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         utils.DbColumnType.STRING: (
             re.compile(r".*CHAR.*", re.IGNORECASE),
             re.compile(r".*STRING.*", re.IGNORECASE),
+            re.compile(r".*TEXT.*", re.IGNORECASE),
         ),
         utils.DbColumnType.TEMPORAL: (
             re.compile(r".*DATE.*", re.IGNORECASE),
@@ -911,13 +912,18 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     ) -> str:
         """
         Convert sqlalchemy column type to string representation.
-        Can be overridden to remove unnecessary details, especially
-        collation info (see mysql, mssql).
+        By default removes collation and character encoding info to avoid unnecessarily
+        long datatypes.
 
         :param sqla_column_type: SqlAlchemy column type
         :param dialect: Sqlalchemy dialect
         :return: Compiled column type
         """
+        sqla_column_type = sqla_column_type.copy()
+        if hasattr(sqla_column_type, "collation"):
+            sqla_column_type.collation = None
+        if hasattr(sqla_column_type, "charset"):
+            sqla_column_type.charset = None
         return sqla_column_type.compile(dialect=dialect).upper()
 
     @classmethod
diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py
index b7ced5f..6a231b2 100644
--- a/superset/db_engine_specs/mssql.py
+++ b/superset/db_engine_specs/mssql.py
@@ -18,7 +18,6 @@ import re
 from datetime import datetime
 from typing import Any, List, Optional, Tuple
 
-from sqlalchemy.engine.interfaces import Dialect
 from sqlalchemy.types import String, TypeEngine, UnicodeText
 
 from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
@@ -66,10 +65,10 @@ class MssqlEngineSpec(BaseEngineSpec):
         # Lists of `pyodbc.Row` need to be unpacked further
         return cls.pyodbc_rows_to_tuples(data)
 
-    column_types = [
+    column_types = (
         (String(), re.compile(r"^(?<!N)((VAR){0,1}CHAR|TEXT|STRING)", re.IGNORECASE)),
         (UnicodeText(), re.compile(r"^N((VAR){0,1}CHAR|TEXT)", re.IGNORECASE)),
-    ]
+    )
 
     @classmethod
     def get_sqla_column_type(cls, type_: str) -> Optional[TypeEngine]:
@@ -77,16 +76,3 @@ class MssqlEngineSpec(BaseEngineSpec):
             if regex.match(type_):
                 return sqla_type
         return None
-
-    @classmethod
-    def column_datatype_to_string(
-        cls, sqla_column_type: TypeEngine, dialect: Dialect
-    ) -> str:
-        datatype = super().column_datatype_to_string(sqla_column_type, dialect)
-        # MSSQL returns long overflowing datatype
-        # as in 'VARCHAR(255) COLLATE SQL_LATIN1_GENERAL_CP1_CI_AS'
-        # and we don't need the verbose collation type
-        str_cutoff = " COLLATE "
-        if str_cutoff in datatype:
-            datatype = datatype.split(str_cutoff)[0]
-        return datatype
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index cf33298..b19527f 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -18,9 +18,7 @@ from datetime import datetime
 from typing import Any, Dict, Optional
 from urllib import parse
 
-from sqlalchemy.engine.interfaces import Dialect
 from sqlalchemy.engine.url import URL
-from sqlalchemy.types import TypeEngine
 
 from superset.db_engine_specs.base import BaseEngineSpec
 
@@ -97,16 +95,3 @@ class MySQLEngineSpec(BaseEngineSpec):
         except Exception:  # pylint: disable=broad-except
             pass
         return message
-
-    @classmethod
-    def column_datatype_to_string(
-        cls, sqla_column_type: TypeEngine, dialect: Dialect
-    ) -> str:
-        datatype = super().column_datatype_to_string(sqla_column_type, dialect)
-        # MySQL dialect started returning long overflowing datatype
-        # as in 'VARCHAR(255) COLLATE UTF8MB4_GENERAL_CI'
-        # and we don't need the verbose collation type
-        str_cutoff = " COLLATE "
-        if str_cutoff in datatype:
-            datatype = datatype.split(str_cutoff)[0]
-        return datatype
diff --git a/tests/db_engine_specs/mssql_tests.py b/tests/db_engine_specs/mssql_tests.py
index 95cde00..238dd2a 100644
--- a/tests/db_engine_specs/mssql_tests.py
+++ b/tests/db_engine_specs/mssql_tests.py
@@ -18,6 +18,7 @@ import unittest.mock as mock
 
 from sqlalchemy import column, table
 from sqlalchemy.dialects import mssql
+from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR
 from sqlalchemy.sql import select
 from sqlalchemy.types import String, UnicodeText
 
@@ -75,21 +76,23 @@ class MssqlEngineSpecTest(DbEngineSpecTestCase):
 
     def test_convert_dttm(self):
         dttm = self.get_dttm()
-
-        self.assertEqual(
-            MssqlEngineSpec.convert_dttm("DATE", dttm),
-            "CONVERT(DATE, '2019-01-02', 23)",
-        )
-
-        self.assertEqual(
-            MssqlEngineSpec.convert_dttm("DATETIME", dttm),
-            "CONVERT(DATETIME, '2019-01-02T03:04:05.678', 126)",
+        test_cases = (
+            (
+                MssqlEngineSpec.convert_dttm("DATE", dttm),
+                "CONVERT(DATE, '2019-01-02', 23)",
+            ),
+            (
+                MssqlEngineSpec.convert_dttm("DATETIME", dttm),
+                "CONVERT(DATETIME, '2019-01-02T03:04:05.678', 126)",
+            ),
+            (
+                MssqlEngineSpec.convert_dttm("SMALLDATETIME", dttm),
+                "CONVERT(SMALLDATETIME, '2019-01-02 03:04:05', 20)",
+            ),
         )
 
-        self.assertEqual(
-            MssqlEngineSpec.convert_dttm("SMALLDATETIME", dttm),
-            "CONVERT(SMALLDATETIME, '2019-01-02 03:04:05', 20)",
-        )
+        for actual, expected in test_cases:
+            self.assertEqual(actual, expected)
 
     @mock.patch.object(
         MssqlEngineSpec, "pyodbc_rows_to_tuples", return_value="converted"
@@ -102,3 +105,19 @@ class MssqlEngineSpecTest(DbEngineSpecTestCase):
             result = MssqlEngineSpec.fetch_data(None, 0)
             mock_pyodbc_rows_to_tuples.assert_called_once_with(data)
             self.assertEqual(result, "converted")
+
+    def test_column_datatype_to_string(self):
+        test_cases = (
+            (DATE(), "DATE"),
+            (VARCHAR(length=255), "VARCHAR(255)"),
+            (VARCHAR(length=255, collation="utf8_general_ci"), "VARCHAR(255)"),
+            (NVARCHAR(length=128), "NVARCHAR(128)"),
+            (TEXT(), "TEXT"),
+            (NTEXT(collation="utf8_general_ci"), "NTEXT"),
+        )
+
+        for original, expected in test_cases:
+            actual = MssqlEngineSpec.column_datatype_to_string(
+                original, mssql.dialect()
+            )
+            self.assertEqual(actual, expected)
diff --git a/tests/db_engine_specs/mysql_tests.py b/tests/db_engine_specs/mysql_tests.py
index d47e707..897507d 100644
--- a/tests/db_engine_specs/mysql_tests.py
+++ b/tests/db_engine_specs/mysql_tests.py
@@ -16,6 +16,9 @@
 # under the License.
 import unittest
 
+from sqlalchemy.dialects import mysql
+from sqlalchemy.dialects.mysql import DATE, NVARCHAR, TEXT, VARCHAR
+
 from superset.db_engine_specs.mysql import MySQLEngineSpec
 from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
 
@@ -41,3 +44,21 @@ class MySQLEngineSpecsTestCase(DbEngineSpecTestCase):
             MySQLEngineSpec.convert_dttm("DATETIME", dttm),
             "STR_TO_DATE('2019-01-02 03:04:05.678900', '%Y-%m-%d %H:%i:%s.%f')",
         )
+
+    def test_column_datatype_to_string(self):
+        test_cases = (
+            (DATE(), "DATE"),
+            (VARCHAR(length=255), "VARCHAR(255)"),
+            (
+                VARCHAR(length=255, charset="latin1", collation="utf8mb4_general_ci"),
+                "VARCHAR(255)",
+            ),
+            (NVARCHAR(length=128), "NATIONAL VARCHAR(128)"),
+            (TEXT(), "TEXT"),
+        )
+
+        for original, expected in test_cases:
+            actual = MySQLEngineSpec.column_datatype_to_string(
+                original, mysql.dialect()
+            )
+            self.assertEqual(actual, expected)
diff --git a/tests/db_engine_specs/oracle_tests.py b/tests/db_engine_specs/oracle_tests.py
index 09806a0..8b821c7 100644
--- a/tests/db_engine_specs/oracle_tests.py
+++ b/tests/db_engine_specs/oracle_tests.py
@@ -16,6 +16,7 @@
 # under the License.
 from sqlalchemy import column
 from sqlalchemy.dialects import oracle
+from sqlalchemy.dialects.oracle import DATE, NVARCHAR, VARCHAR
 
 from superset.db_engine_specs.oracle import OracleEngineSpec
 from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
@@ -39,17 +40,31 @@ class OracleTestCase(DbEngineSpecTestCase):
     def test_convert_dttm(self):
         dttm = self.get_dttm()
 
-        self.assertEqual(
-            OracleEngineSpec.convert_dttm("DATE", dttm),
-            "TO_DATE('2019-01-02', 'YYYY-MM-DD')",
+        test_cases = (
+            (
+                OracleEngineSpec.convert_dttm("DATE", dttm),
+                "TO_DATE('2019-01-02', 'YYYY-MM-DD')",
+            ),
+            (
+                OracleEngineSpec.convert_dttm("DATETIME", dttm),
+                """TO_DATE('2019-01-02T03:04:05', 'YYYY-MM-DD"T"HH24:MI:SS')""",
+            ),
+            (
+                OracleEngineSpec.convert_dttm("TIMESTAMP", dttm),
+                """TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""",
+            ),
         )
 
-        self.assertEqual(
-            OracleEngineSpec.convert_dttm("DATETIME", dttm),
-            """TO_DATE('2019-01-02T03:04:05', 'YYYY-MM-DD"T"HH24:MI:SS')""",
+    def test_column_datatype_to_string(self):
+        test_cases = (
+            (DATE(), "DATE"),
+            (VARCHAR(length=255), "VARCHAR(255 CHAR)"),
+            (VARCHAR(length=255, collation="utf8"), "VARCHAR(255 CHAR)"),
+            (NVARCHAR(length=128), "NVARCHAR2(128)"),
         )
 
-        self.assertEqual(
-            OracleEngineSpec.convert_dttm("TIMESTAMP", dttm),
-            """TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""",
-        )
+        for original, expected in test_cases:
+            actual = OracleEngineSpec.column_datatype_to_string(
+                original, oracle.dialect()
+            )
+            self.assertEqual(actual, expected)
diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py
index 07d1d06..700c2e2 100644
--- a/tests/sqla_models_tests.py
+++ b/tests/sqla_models_tests.py
@@ -46,6 +46,8 @@ class DatabaseModelTestCase(SupersetTestCase):
             "VARCHAR": DbColumnType.STRING,
             "NVARCHAR": DbColumnType.STRING,
             "STRING": DbColumnType.STRING,
+            "TEXT": DbColumnType.STRING,
+            "NTEXT": DbColumnType.STRING,
             # numeric
             "INT": DbColumnType.NUMERIC,
             "BIGINT": DbColumnType.NUMERIC,