You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by mi...@apache.org on 2024/02/12 22:21:08 UTC

(superset) 02/09: fix: column values with NaN (#26946)

This is an automated email from the ASF dual-hosted git repository.

michaelsmolina pushed a commit to branch 3.0
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 383b4d949fdffb6e86e5e3ab53fcaeeb9eb27301
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Fri Feb 2 12:56:02 2024 -0500

    fix: column values with NaN (#26946)
    
    (cherry picked from commit d8a98475036a4fba28b3d3eb508b3d1f3f5072aa)
---
 superset/models/helpers.py                      |  7 ++++++-
 tests/integration_tests/conftest.py             | 21 +++++++++++----------
 tests/integration_tests/datasource/api_tests.py | 10 ++++++++++
 tests/integration_tests/datasource_tests.py     |  9 ++++++++-
 4 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index f9b30c3d42..849f3db176 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -1332,7 +1332,10 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
         return and_(*l)
 
     def values_for_column(
-        self, column_name: str, limit: int = 10000, denormalize_column: bool = False
+        self,
+        column_name: str,
+        limit: int = 10000,
+        denormalize_column: bool = False,
     ) -> list[Any]:
         # denormalize column name before querying for values
         # unless disabled in the dataset configuration
@@ -1370,6 +1373,8 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
             sql = self.mutate_query_from_config(sql)
 
             df = pd.read_sql_query(sql=sql, con=engine)
+            # replace NaN with None to ensure it can be serialized to JSON
+            df = df.replace({np.nan: None})
             return df["column_values"].to_list()
 
     def get_timestamp_expression(
diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py
index 28da7b7913..4134514af3 100644
--- a/tests/integration_tests/conftest.py
+++ b/tests/integration_tests/conftest.py
@@ -296,25 +296,25 @@ def virtual_dataset():
     dataset = SqlaTable(
         table_name="virtual_dataset",
         sql=(
-            "SELECT 0 as col1, 'a' as col2, 1.0 as col3, NULL as col4, '2000-01-01 00:00:00' as col5 "
+            "SELECT 0 as col1, 'a' as col2, 1.0 as col3, NULL as col4, '2000-01-01 00:00:00' as col5, 1 as col6 "
             "UNION ALL "
-            "SELECT 1, 'b', 1.1, NULL, '2000-01-02 00:00:00' "
+            "SELECT 1, 'b', 1.1, NULL, '2000-01-02 00:00:00', NULL "
             "UNION ALL "
-            "SELECT 2 as col1, 'c' as col2, 1.2, NULL, '2000-01-03 00:00:00' "
+            "SELECT 2 as col1, 'c' as col2, 1.2, NULL, '2000-01-03 00:00:00', 3 "
             "UNION ALL "
-            "SELECT 3 as col1, 'd' as col2, 1.3, NULL, '2000-01-04 00:00:00' "
+            "SELECT 3 as col1, 'd' as col2, 1.3, NULL, '2000-01-04 00:00:00', 4 "
             "UNION ALL "
-            "SELECT 4 as col1, 'e' as col2, 1.4, NULL, '2000-01-05 00:00:00' "
+            "SELECT 4 as col1, 'e' as col2, 1.4, NULL, '2000-01-05 00:00:00', 5 "
             "UNION ALL "
-            "SELECT 5 as col1, 'f' as col2, 1.5, NULL, '2000-01-06 00:00:00' "
+            "SELECT 5 as col1, 'f' as col2, 1.5, NULL, '2000-01-06 00:00:00', 6 "
             "UNION ALL "
-            "SELECT 6 as col1, 'g' as col2, 1.6, NULL, '2000-01-07 00:00:00' "
+            "SELECT 6 as col1, 'g' as col2, 1.6, NULL, '2000-01-07 00:00:00', 7 "
             "UNION ALL "
-            "SELECT 7 as col1, 'h' as col2, 1.7, NULL, '2000-01-08 00:00:00' "
+            "SELECT 7 as col1, 'h' as col2, 1.7, NULL, '2000-01-08 00:00:00', 8 "
             "UNION ALL "
-            "SELECT 8 as col1, 'i' as col2, 1.8, NULL, '2000-01-09 00:00:00' "
+            "SELECT 8 as col1, 'i' as col2, 1.8, NULL, '2000-01-09 00:00:00', 9 "
             "UNION ALL "
-            "SELECT 9 as col1, 'j' as col2, 1.9, NULL, '2000-01-10 00:00:00' "
+            "SELECT 9 as col1, 'j' as col2, 1.9, NULL, '2000-01-10 00:00:00', 10"
         ),
         database=get_example_database(),
     )
@@ -324,6 +324,7 @@ def virtual_dataset():
     TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
     # Different database dialect datetime type is not consistent, so temporarily use varchar
     TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset)
+    TableColumn(column_name="col6", type="INTEGER", table=dataset)
 
     SqlMetric(metric_name="count", expression="count(*)", table=dataset)
     db.session.merge(dataset)
diff --git a/tests/integration_tests/datasource/api_tests.py b/tests/integration_tests/datasource/api_tests.py
index b6a6af105d..2605ec399f 100644
--- a/tests/integration_tests/datasource/api_tests.py
+++ b/tests/integration_tests/datasource/api_tests.py
@@ -72,6 +72,16 @@ class TestDatasourceApi(SupersetTestCase):
         response = json.loads(rv.data.decode("utf-8"))
         self.assertEqual(response["result"], [None])
 
+    @pytest.mark.usefixtures("app_context", "virtual_dataset")
+    def test_get_column_values_integers_with_nulls(self):
+        self.login(username="admin")
+        table = self.get_virtual_dataset()
+        rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col6/values/")
+        self.assertEqual(rv.status_code, 200)
+        response = json.loads(rv.data.decode("utf-8"))
+        for val in [1, None, 3, 4, 5, 6, 7, 8, 9, 10]:
+            assert val in response["result"]
+
     @pytest.mark.usefixtures("app_context", "virtual_dataset")
     def test_get_column_values_invalid_datasource_type(self):
         self.login(username="admin")
diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py
index 12293325d1..59ed9117e7 100644
--- a/tests/integration_tests/datasource_tests.py
+++ b/tests/integration_tests/datasource_tests.py
@@ -545,7 +545,14 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
         },
     )
     assert rv.status_code == 200
-    assert rv.json["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
+    assert rv.json["result"]["colnames"] == [
+        "col1",
+        "col2",
+        "col3",
+        "col4",
+        "col5",
+        "col6",
+    ]
     assert rv.json["result"]["rowcount"] == 1
 
     # empty results