You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by mi...@apache.org on 2023/10/02 14:47:55 UTC

[superset] 03/05: fix(mysql): handle string typed decimal results (#24241)

This is an automated email from the ASF dual-hosted git repository.

michaelsmolina pushed a commit to branch 3.0
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 615d7f5ccc7f46dcbdc075e378e60bd674855fbd
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Fri Sep 29 10:48:08 2023 -0700

    fix(mysql): handle string typed decimal results (#24241)
    
    (cherry picked from commit 7eab59af513ccccb3b1fed7aca5798c98c35fdb8)
---
 superset/db_engine_specs/base.py               | 29 ++++++++++++++++++-
 superset/db_engine_specs/mysql.py              |  6 +++-
 tests/unit_tests/db_engine_specs/test_mysql.py | 40 ++++++++++++++++++++++++++
 3 files changed, 73 insertions(+), 2 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index e7f8675423..6be3ab24b0 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -309,6 +309,10 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     # engine-specific type mappings to check prior to the defaults
     column_type_mappings: tuple[ColumnTypeMapping, ...] = ()
 
+    # Value-normalizing functions keyed by SQLAlchemy column type class. Needed
+    # on certain databases that return values in an unexpected format
+    column_type_mutators: dict[TypeEngine, Callable[[Any], Any]] = {}
+
     # Does database support join-free timeslot grouping
     time_groupby_inline = False
     limit_method = LimitMethod.FORCE_LIMIT
@@ -730,7 +734,30 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         try:
             if cls.limit_method == LimitMethod.FETCH_MANY and limit:
                 return cursor.fetchmany(limit)
-            return cursor.fetchall()
+            data = cursor.fetchall()
+            description = cursor.description or []
+            # Map each column name to the mutator registered for its column type;
+            # columns without a registered mutator are left out. The first two
+            # items in the description row are the column name and type.
+            column_mutators = {
+                row[0]: func
+                for row in description
+                if (
+                    func := cls.column_type_mutators.get(
+                        type(cls.get_sqla_column_type(cls.get_datatype(row[1])))
+                    )
+                )
+            }
+            if column_mutators:
+                indexes = {row[0]: idx for idx, row in enumerate(description)}
+                for row_idx, row in enumerate(data):
+                    new_row = list(row)
+                    for col, func in column_mutators.items():
+                        col_idx = indexes[col]
+                        new_row[col_idx] = func(row[col_idx])
+                    data[row_idx] = tuple(new_row)
+
+            return data
         except Exception as ex:
             raise cls.get_dbapi_mapped_exception(ex) from ex
 
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index e83e53e426..eaa7d9377d 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -16,8 +16,9 @@
 # under the License.
 import re
 from datetime import datetime
+from decimal import Decimal
 from re import Pattern
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 from urllib import parse
 
 from flask_babel import gettext as __
@@ -125,6 +126,9 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
             GenericDataType.STRING,
         ),
     )
+    column_type_mutators: dict[types.TypeEngine, Callable[[Any], Any]] = {
+        DECIMAL: lambda val: Decimal(val) if isinstance(val, str) else val
+    }
 
     _time_grain_expressions = {
         None: "{col}",
diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py
index 89abf2321d..ed64347017 100644
--- a/tests/unit_tests/db_engine_specs/test_mysql.py
+++ b/tests/unit_tests/db_engine_specs/test_mysql.py
@@ -16,6 +16,7 @@
 # under the License.
 
 from datetime import datetime
+from decimal import Decimal
 from typing import Any, Optional
 from unittest.mock import Mock, patch
 
@@ -220,3 +221,42 @@ def test_get_schema_from_engine_params() -> None:
         )
         == "db1"
     )
+
+
+@pytest.mark.parametrize(
+    "data,description,expected_result",
+    [
+        (
+            [("1.23456", "abc")],
+            [("dec", "decimal(12,6)"), ("str", "varchar(3)")],
+            [(Decimal("1.23456"), "abc")],
+        ),
+        (
+            [(Decimal("1.23456"), "abc")],
+            [("dec", "decimal(12,6)"), ("str", "varchar(3)")],
+            [(Decimal("1.23456"), "abc")],
+        ),
+        (
+            [(None, "abc")],
+            [("dec", "decimal(12,6)"), ("str", "varchar(3)")],
+            [(None, "abc")],
+        ),
+        (
+            [("1.23456", "abc")],
+            [("dec", "varchar(255)"), ("str", "varchar(3)")],
+            [("1.23456", "abc")],
+        ),
+    ],
+)
+def test_column_type_mutator(
+    data: list[tuple[Any, ...]],
+    description: list[Any],
+    expected_result: list[tuple[Any, ...]],
+):
+    from superset.db_engine_specs.mysql import MySQLEngineSpec as spec
+
+    mock_cursor = Mock()
+    mock_cursor.fetchall.return_value = data
+    mock_cursor.description = description
+
+    assert spec.fetch_data(mock_cursor) == expected_result