You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text copy).
Posted to commits@superset.apache.org by dp...@apache.org on 2020/05/14 16:00:15 UTC

[incubator-superset] branch master updated: fix(mssql): reverts #9644 and displays a better error msg (#9752)

This is an automated email from the ASF dual-hosted git repository.

dpgaspar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 3cc5400  fix(mssql): reverts #9644 and displays a better error msg (#9752)
3cc5400 is described below

commit 3cc540019f2aa6c3dac1d356f0f21eeca96b34f2
Author: Daniel Vaz Gaspar <da...@gmail.com>
AuthorDate: Thu May 14 17:00:02 2020 +0100

    fix(mssql): reverts #9644 and displays a better error msg (#9752)
---
 superset/db_engine_specs/mssql.py    | 11 +++--
 superset/sql_parse.py                | 45 +-------------------
 superset/utils/core.py               | 14 +++----
 tests/db_engine_specs/mssql_tests.py | 81 ++++++++++--------------------------
 4 files changed, 36 insertions(+), 115 deletions(-)

diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py
index 4fc6e6f..fde69b3 100644
--- a/superset/db_engine_specs/mssql.py
+++ b/superset/db_engine_specs/mssql.py
@@ -22,7 +22,6 @@ from typing import Any, List, Optional, Tuple, TYPE_CHECKING
 from sqlalchemy.types import String, TypeEngine, UnicodeText
 
 from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
-from superset.sql_parse import ParsedQuery
 
 if TYPE_CHECKING:
     from superset.models.core import Database  # pylint: disable=unused-import
@@ -85,6 +84,10 @@ class MssqlEngineSpec(BaseEngineSpec):
         return None
 
     @classmethod
-    def apply_limit_to_sql(cls, sql: str, limit: int, database: "Database") -> str:
-        new_sql = ParsedQuery(sql).set_alias()
-        return super().apply_limit_to_sql(new_sql, limit, database)
+    def extract_error_message(cls, ex: Exception) -> str:
+        if str(ex).startswith("(8155,"):
+            return (
+                f"{cls.engine} error: All your SQL functions need to "
+                "have an alias on MSSQL. For example: SELECT COUNT(*) AS C1 FROM TABLE1"
+            )
+        return f"{cls.engine} error: {cls._extract_error_message(ex)}"
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index bb6f341..e841db1 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -20,14 +20,7 @@ from urllib import parse
 
 import sqlparse
 from dataclasses import dataclass
-from sqlparse.sql import (
-    Function,
-    Identifier,
-    IdentifierList,
-    remove_quotes,
-    Token,
-    TokenList,
-)
+from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList
 from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
 from sqlparse.utils import imt
 
@@ -284,39 +277,3 @@ class ParsedQuery:
         for i in statement.tokens:
             str_res += str(i.value)
         return str_res
-
-    def set_alias(self) -> str:
-        """
-        Returns a new query string where all functions have alias.
-        This is particularly necessary for MSSQL engines.
-
-        :return: String with new aliased SQL query
-        """
-        new_sql = ""
-        changed_counter = 1
-        for token in self._parsed[0].tokens:
-            # Identifier list (list of columns)
-            if isinstance(token, IdentifierList) and token.ttype is None:
-                for i, identifier in enumerate(token.get_identifiers()):
-                    # Functions are anonymous on MSSQL
-                    if isinstance(identifier, Function) and not identifier.has_alias():
-                        identifier.value = (
-                            f"{identifier.value} AS"
-                            f" {identifier.get_real_name()}_{changed_counter}"
-                        )
-                        changed_counter += 1
-                    new_sql += str(identifier.value)
-                    # If not last identifier
-                    if i != len(list(token.get_identifiers())) - 1:
-                        new_sql += ", "
-            # Just a lonely function?
-            elif isinstance(token, Function) and token.ttype is None:
-                if not token.has_alias():
-                    token.value = (
-                        f"{token.value} AS {token.get_real_name()}_{changed_counter}"
-                    )
-                new_sql += str(token.value)
-            # Nothing to change, assemble what we have
-            else:
-                new_sql += str(token.value)
-        return new_sql
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 41deae5..47fe8c8 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -444,7 +444,7 @@ def json_dumps_w_dates(payload):
     return json.dumps(payload, default=json_int_dttm_ser)
 
 
-def error_msg_from_exception(e: Exception) -> str:
+def error_msg_from_exception(ex: Exception) -> str:
     """Translate exception into error message
 
     Database have different ways to handle exception. This function attempts
@@ -459,12 +459,12 @@ def error_msg_from_exception(e: Exception) -> str:
     The latter version is parsed correctly by this function.
     """
     msg = ""
-    if hasattr(e, "message"):
-        if isinstance(e.message, dict):  # type: ignore
-            msg = e.message.get("message")  # type: ignore
-        elif e.message:  # type: ignore
-            msg = e.message  # type: ignore
-    return msg or str(e)
+    if hasattr(ex, "message"):
+        if isinstance(ex.message, dict):  # type: ignore
+            msg = ex.message.get("message")  # type: ignore
+        elif ex.message:  # type: ignore
+            msg = ex.message  # type: ignore
+    return msg or str(ex)
 
 
 def markdown(s: str, markup_wrap: Optional[bool] = False) -> str:
diff --git a/tests/db_engine_specs/mssql_tests.py b/tests/db_engine_specs/mssql_tests.py
index 9f5351c..0a254de 100644
--- a/tests/db_engine_specs/mssql_tests.py
+++ b/tests/db_engine_specs/mssql_tests.py
@@ -15,18 +15,15 @@
 # specific language governing permissions and limitations
 # under the License.
 import unittest.mock as mock
-from typing import Optional
 
 from sqlalchemy import column, table
 from sqlalchemy.dialects import mssql
 from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR
-from sqlalchemy.sql import select, Select
+from sqlalchemy.sql import select
 from sqlalchemy.types import String, UnicodeText
 
 from superset.db_engine_specs.base import BaseEngineSpec
 from superset.db_engine_specs.mssql import MssqlEngineSpec
-from superset.extensions import db
-from superset.models.core import Database
 from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
 
 
@@ -97,64 +94,28 @@ class MssqlEngineSpecTest(DbEngineSpecTestCase):
         for actual, expected in test_cases:
             self.assertEqual(actual, expected)
 
-    def test_apply_limit(self):
-        def compile_sqla_query(qry: Select, schema: Optional[str] = None) -> str:
-            return str(
-                qry.compile(
-                    dialect=mssql.dialect(), compile_kwargs={"literal_binds": True}
-                )
-            )
-
-        database = Database(
-            database_name="mssql_test",
-            sqlalchemy_uri="mssql+pymssql://sa:Password_123@localhost:1433/msdb",
+    def test_extract_error_message(self):
+        test_mssql_exception = Exception(
+            "(8155, b\"No column name was specified for column 1 of 'inner_qry'."
+            "DB-Lib error message 20018, severity 16:\\nGeneral SQL Server error: "
+            'Check messages from the SQL Server\\n")'
         )
-        db.session.add(database)
-        db.session.commit()
-
-        with mock.patch.object(database, "compile_sqla_query", new=compile_sqla_query):
-            test_sql = "SELECT COUNT(*) FROM FOO_TABLE"
-
-            limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
-
-            expected_sql = (
-                "SELECT TOP 1000 * \n"
-                "FROM (SELECT COUNT(*) AS COUNT_1 FROM FOO_TABLE) AS inner_qry"
-            )
-            self.assertEqual(expected_sql, limited_sql)
-
-            test_sql = "SELECT COUNT(*), SUM(id) FROM FOO_TABLE"
-            limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
-
-            expected_sql = (
-                "SELECT TOP 1000 * \n"
-                "FROM (SELECT COUNT(*) AS COUNT_1, SUM(id) AS SUM_2 FROM FOO_TABLE) "
-                "AS inner_qry"
-            )
-            self.assertEqual(expected_sql, limited_sql)
-
-            test_sql = "SELECT COUNT(*), FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1"
-            limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
-
-            expected_sql = (
-                "SELECT TOP 1000 * \n"
-                "FROM (SELECT COUNT(*) AS COUNT_1, "
-                "FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1)"
-                " AS inner_qry"
-            )
-            self.assertEqual(expected_sql, limited_sql)
-
-            test_sql = "SELECT COUNT(*), COUNT(*) FROM FOO_TABLE"
-            limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
-            expected_sql = (
-                "SELECT TOP 1000 * \n"
-                "FROM (SELECT COUNT(*) AS COUNT_1, COUNT(*) AS COUNT_2 FROM FOO_TABLE)"
-                " AS inner_qry"
-            )
-            self.assertEqual(expected_sql, limited_sql)
+        error_message = MssqlEngineSpec.extract_error_message(test_mssql_exception)
+        expected_message = (
+            "mssql error: All your SQL functions need to "
+            "have an alias on MSSQL. For example: SELECT COUNT(*) AS C1 FROM TABLE1"
+        )
+        self.assertEqual(expected_message, error_message)
 
-        db.session.delete(database)
-        db.session.commit()
+        test_mssql_exception = Exception(
+            '(8200, b"A correlated expression is invalid because it is not in a '
+            "GROUP BY clause.\\n\")'"
+        )
+        error_message = MssqlEngineSpec.extract_error_message(test_mssql_exception)
+        expected_message = "mssql error: " + MssqlEngineSpec._extract_error_message(
+            test_mssql_exception
+        )
+        self.assertEqual(expected_message, error_message)
 
     @mock.patch.object(
         MssqlEngineSpec, "pyodbc_rows_to_tuples", return_value="converted"