You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text rendering.
Posted to commits@superset.apache.org by mi...@apache.org on 2023/10/02 14:47:55 UTC
[superset] 03/05: fix(mysql): handle string typed decimal results (#24241)
This is an automated email from the ASF dual-hosted git repository.
michaelsmolina pushed a commit to branch 3.0
in repository https://gitbox.apache.org/repos/asf/superset.git
commit 615d7f5ccc7f46dcbdc075e378e60bd674855fbd
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Fri Sep 29 10:48:08 2023 -0700
fix(mysql): handle string typed decimal results (#24241)
(cherry picked from commit 7eab59af513ccccb3b1fed7aca5798c98c35fdb8)
---
superset/db_engine_specs/base.py | 29 ++++++++++++++++++-
superset/db_engine_specs/mysql.py | 6 +++-
tests/unit_tests/db_engine_specs/test_mysql.py | 40 ++++++++++++++++++++++++++
3 files changed, 73 insertions(+), 2 deletions(-)
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index e7f8675423..6be3ab24b0 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -309,6 +309,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# engine-specific type mappings to check prior to the defaults
column_type_mappings: tuple[ColumnTypeMapping, ...] = ()
+ # type-specific functions to mutate values received from the database.
+ # Needed on certain databases that return values in an unexpected format
+ column_type_mutators: dict[TypeEngine, Callable[[Any], Any]] = {}
+
# Does database support join-free timeslot grouping
time_groupby_inline = False
limit_method = LimitMethod.FORCE_LIMIT
@@ -730,7 +734,30 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
try:
if cls.limit_method == LimitMethod.FETCH_MANY and limit:
return cursor.fetchmany(limit)
- return cursor.fetchall()
+ data = cursor.fetchall()
+ description = cursor.description or []
+ # Create a mapping between column name and a mutator function to normalize
+ # values with. The first two items in the description row are
+ # the column name and type.
+ column_mutators = {
+ row[0]: func
+ for row in description
+ if (
+ func := cls.column_type_mutators.get(
+ type(cls.get_sqla_column_type(cls.get_datatype(row[1])))
+ )
+ )
+ }
+ if column_mutators:
+ indexes = {row[0]: idx for idx, row in enumerate(description)}
+ for row_idx, row in enumerate(data):
+ new_row = list(row)
+ for col, func in column_mutators.items():
+ col_idx = indexes[col]
+ new_row[col_idx] = func(row[col_idx])
+ data[row_idx] = tuple(new_row)
+
+ return data
except Exception as ex:
raise cls.get_dbapi_mapped_exception(ex) from ex
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index e83e53e426..eaa7d9377d 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -16,8 +16,9 @@
# under the License.
import re
from datetime import datetime
+from decimal import Decimal
from re import Pattern
-from typing import Any, Optional
+from typing import Any, Callable, Optional
from urllib import parse
from flask_babel import gettext as __
@@ -125,6 +126,9 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
GenericDataType.STRING,
),
)
+ column_type_mutators: dict[types.TypeEngine, Callable[[Any], Any]] = {
+ DECIMAL: lambda val: Decimal(val) if isinstance(val, str) else val
+ }
_time_grain_expressions = {
None: "{col}",
diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py
index 89abf2321d..ed64347017 100644
--- a/tests/unit_tests/db_engine_specs/test_mysql.py
+++ b/tests/unit_tests/db_engine_specs/test_mysql.py
@@ -16,6 +16,7 @@
# under the License.
from datetime import datetime
+from decimal import Decimal
from typing import Any, Optional
from unittest.mock import Mock, patch
@@ -220,3 +221,42 @@ def test_get_schema_from_engine_params() -> None:
)
== "db1"
)
+
+
+@pytest.mark.parametrize(
+ "data,description,expected_result",  # rows from fetchall(), cursor.description, expected rows
+ [
+ (
+ [("1.23456", "abc")],  # DECIMAL value delivered as a str
+ [("dec", "decimal(12,6)"), ("str", "varchar(3)")],
+ [(Decimal("1.23456"), "abc")],  # str converted to Decimal; varchar value untouched
+ ),
+ (
+ [(Decimal("1.23456"), "abc")],  # value already a Decimal
+ [("dec", "decimal(12,6)"), ("str", "varchar(3)")],
+ [(Decimal("1.23456"), "abc")],  # passed through unchanged
+ ),
+ (
+ [(None, "abc")],  # NULL in the DECIMAL column
+ [("dec", "decimal(12,6)"), ("str", "varchar(3)")],
+ [(None, "abc")],  # non-str values are returned as-is
+ ),
+ (
+ [("1.23456", "abc")],
+ [("dec", "varchar(255)"), ("str", "varchar(3)")],  # no DECIMAL column here
+ [("1.23456", "abc")],  # so the str is not converted
+ ),
+ ],  # NOTE(review): no case covers duplicate column names or a tuple from fetchall()
+)
+def test_column_type_mutator(
+ data: list[tuple[Any, ...]],
+ description: list[Any],  # (name, type, ...) rows, DB-API cursor.description style
+ expected_result: list[tuple[Any, ...]],
+):
+ from superset.db_engine_specs.mysql import MySQLEngineSpec as spec
+
+ mock_cursor = Mock()  # stand-in for a DB-API cursor
+ mock_cursor.fetchall.return_value = data
+ mock_cursor.description = description
+
+ assert spec.fetch_data(mock_cursor) == expected_result  # no limit given; presumably the fetchall() path — confirm default