Posted to commits@superset.apache.org by el...@apache.org on 2023/04/24 19:54:04 UTC

[superset] branch master updated: feat: create dtype option for csv upload (#23716)

This is an automated email from the ASF dual-hosted git repository.

elizabeth pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 71106cfd97 feat: create dtype option for csv upload (#23716)
71106cfd97 is described below

commit 71106cfd9791300fa3217bd46884381dde7e7b23
Author: Elizabeth Thompson <es...@gmail.com>
AuthorDate: Mon Apr 24 12:53:53 2023 -0700

    feat: create dtype option for csv upload (#23716)
---
 superset/db_engine_specs/redshift.py               | 40 ++++++++++++++
 .../form_view/csv_to_database_view/edit.html       |  4 ++
 superset/views/database/forms.py                   | 10 ++++
 superset/views/database/views.py                   |  3 ++
 tests/integration_tests/csv_upload_tests.py        | 44 +++++++++++++++-
 .../db_engine_specs/redshift_tests.py              | 61 ++++++++++++++++++++++
 6 files changed, 160 insertions(+), 2 deletions(-)

diff --git a/superset/db_engine_specs/redshift.py b/superset/db_engine_specs/redshift.py
index 7e2717d776..27b749e418 100644
--- a/superset/db_engine_specs/redshift.py
+++ b/superset/db_engine_specs/redshift.py
@@ -18,12 +18,16 @@ import logging
 import re
 from typing import Any, Dict, Optional, Pattern, Tuple
 
+import pandas as pd
 from flask_babel import gettext as __
+from sqlalchemy.types import NVARCHAR
 
 from superset.db_engine_specs.base import BasicParametersMixin
 from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
 from superset.errors import SupersetErrorType
+from superset.models.core import Database
 from superset.models.sql_lab import Query
+from superset.sql_parse import Table
 
 logger = logging.getLogger()
 
@@ -96,6 +100,42 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
         ),
     }
 
+    @classmethod
+    def df_to_sql(
+        cls,
+        database: Database,
+        table: Table,
+        df: pd.DataFrame,
+        to_sql_kwargs: Dict[str, Any],
+    ) -> None:
+        """
+        Upload data from a Pandas DataFrame to a database.
+
+        For regular engines this calls the `pandas.DataFrame.to_sql` method.
+        Overrides the base class to allow pandas string types to be
+        uploaded as nvarchar(max) columns, as Redshift does not support
+        text data types.
+
+        Note this method does not create metadata for the table.
+
+        :param database: The database to upload the data to
+        :param table: The table to upload the data to
+        :param df: The dataframe with data to be uploaded
+        :param to_sql_kwargs: The kwargs to be passed to the `pandas.DataFrame.to_sql` method
+        """
+        to_sql_kwargs = to_sql_kwargs or {}
+        to_sql_kwargs["dtype"] = {
+            # uses the max size for redshift nvarchar(65535)
+            # the default object and string types create a varchar(256)
+            col_name: NVARCHAR(length=65535)
+            for col_name, col_type in zip(df.columns, df.dtypes)
+            if isinstance(col_type, pd.StringDtype)
+        }
+
+        super().df_to_sql(
+            df=df, database=database, table=table, to_sql_kwargs=to_sql_kwargs
+        )
+
     @staticmethod
     def _mutate_label(label: str) -> str:
         """
diff --git a/superset/templates/superset/form_view/csv_to_database_view/edit.html b/superset/templates/superset/form_view/csv_to_database_view/edit.html
index b09f9bd383..a0ae43792e 100644
--- a/superset/templates/superset/form_view/csv_to_database_view/edit.html
+++ b/superset/templates/superset/form_view/csv_to_database_view/edit.html
@@ -104,6 +104,10 @@
         {{ lib.render_field(form.overwrite_duplicate, begin_sep_label, end_sep_label, begin_sep_field,
         end_sep_field) }}
       </tr>
+      <tr>
+        {{ lib.render_field(form.dtype, begin_sep_label, end_sep_label, begin_sep_field,
+        end_sep_field) }}
+      </tr>
     {% endcall %}
     {% call csv_macros.render_collapsable_form_group("accordion3", "Rows") %}
       <tr>
diff --git a/superset/views/database/forms.py b/superset/views/database/forms.py
index 91ab38dc2f..99b64e38ab 100644
--- a/superset/views/database/forms.py
+++ b/superset/views/database/forms.py
@@ -140,6 +140,16 @@ class CsvToDatabaseForm(UploadToDatabaseForm):
         get_pk=lambda a: a.id,
         get_label=lambda a: a.database_name,
     )
+    dtype = StringField(
+        _("Column Data Types"),
+        description=_(
+            "A dictionary with column names and their data types"
+            " if you need to change the defaults."
+            ' Example: {"user_id":"integer"}'
+        ),
+        validators=[Optional()],
+        widget=BS3TextFieldWidget(),
+    )
     schema = StringField(
         _("Schema"),
         description=_("Select a schema if the database supports this"),
diff --git a/superset/views/database/views.py b/superset/views/database/views.py
index 037128ee16..a9137e59ed 100644
--- a/superset/views/database/views.py
+++ b/superset/views/database/views.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import io
+import json
 import os
 import tempfile
 import zipfile
@@ -189,6 +190,7 @@ class CsvToDatabaseView(CustomFormView):
             delimiter_input = form.otherInput.data
 
         try:
+            kwargs = {"dtype": json.loads(form.dtype.data)} if form.dtype.data else {}
             df = pd.concat(
                 pd.read_csv(
                     chunksize=1000,
@@ -208,6 +210,7 @@ class CsvToDatabaseView(CustomFormView):
                     skip_blank_lines=form.skip_blank_lines.data,
                     skipinitialspace=form.skip_initial_space.data,
                     skiprows=form.skiprows.data,
+                    **kwargs,
                 )
             )
 
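
The effect on the read path is that the parsed dictionary is forwarded to pandas.read_csv as its dtype argument. A sketch of the round trip, with an invented inline CSV standing in for the uploaded file:

    import io
    import json
    import pandas as pd

    form_dtype = '{"user_id": "int64", "name": "string"}'  # hypothetical form value
    kwargs = {"dtype": json.loads(form_dtype)} if form_dtype else {}

    df = pd.read_csv(io.StringIO("user_id,name\n1,alice\n2,bob\n"), **kwargs)
    print(df.dtypes)  # user_id -> int64, name -> string (instead of object)
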
diff --git a/tests/integration_tests/csv_upload_tests.py b/tests/integration_tests/csv_upload_tests.py
index d3b55f7bfe..97b83bb8fc 100644
--- a/tests/integration_tests/csv_upload_tests.py
+++ b/tests/integration_tests/csv_upload_tests.py
@@ -20,7 +20,7 @@ import json
 import logging
 import os
 import shutil
-from typing import Dict, Optional
+from typing import Dict, Optional, Union
 
 from unittest import mock
 
@@ -129,7 +129,12 @@ def get_upload_db():
     return db.session.query(Database).filter_by(database_name=CSV_UPLOAD_DATABASE).one()
 
 
-def upload_csv(filename: str, table_name: str, extra: Optional[Dict[str, str]] = None):
+def upload_csv(
+    filename: str,
+    table_name: str,
+    extra: Optional[Dict[str, str]] = None,
+    dtype: Union[str, None] = None,
+):
     csv_upload_db_id = get_upload_db().id
     schema = utils.get_example_default_schema()
     form_data = {
@@ -145,6 +150,8 @@ def upload_csv(filename: str, table_name: str, extra: Optional[Dict[str, str]] =
         form_data["schema"] = schema
     if extra:
         form_data.update(extra)
+    if dtype:
+        form_data["dtype"] = dtype
     return get_resp(test_client, "/csvtodatabaseview/form", data=form_data)
 
 
@@ -386,6 +393,39 @@ def test_import_csv(mock_event_logger):
         data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
         assert data == [("john", 1, "x"), ("paul", 2, None)]
 
+    # cleanup
+    with get_upload_db().get_sqla_engine_with_context() as engine:
+        engine.execute(f"DROP TABLE {full_table_name}")
+
+    # with dtype
+    upload_csv(
+        CSV_FILENAME1,
+        CSV_UPLOAD_TABLE,
+        dtype='{"a": "string", "b": "float64"}',
+    )
+
+    # changing a column to a compatible type (e.g. object to string,
+    # or int to float) should leave the file upload working as normal
+    with test_db.get_sqla_engine_with_context() as engine:
+        data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
+        assert data == [("john", 1), ("paul", 2)]
+
+    # cleanup
+    with get_upload_db().get_sqla_engine_with_context() as engine:
+        engine.execute(f"DROP TABLE {full_table_name}")
+
+    # with dtype - wrong type
+    resp = upload_csv(
+        CSV_FILENAME1,
+        CSV_UPLOAD_TABLE,
+        dtype='{"a": "int"}',
+    )
+
+    # passing an incompatible dtype should fail the upload
+    fail_msg = f"Unable to upload CSV file {escaped_double_quotes(CSV_FILENAME1)} to table {escaped_double_quotes(CSV_UPLOAD_TABLE)}"
+    assert fail_msg in resp
+
 
 @pytest.mark.usefixtures("setup_csv_upload_with_context")
 @pytest.mark.usefixtures("create_excel_files")
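
The failure case asserted above comes straight from pandas: asking read_csv to coerce non-numeric text to an integer raises a ValueError, which the upload view surfaces as the error message the test checks for. A tiny reproduction (the inline CSV stands in for CSV_FILENAME1):

    import io
    import pandas as pd

    try:
        pd.read_csv(io.StringIO("a,b\njohn,1\npaul,2\n"), dtype={"a": "int"})
    except ValueError as exc:
        print(exc)  # invalid literal for int() with base 10: 'john'
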
diff --git a/tests/integration_tests/db_engine_specs/redshift_tests.py b/tests/integration_tests/db_engine_specs/redshift_tests.py
index cdfe8d16cb..2d46c73fca 100644
--- a/tests/integration_tests/db_engine_specs/redshift_tests.py
+++ b/tests/integration_tests/db_engine_specs/redshift_tests.py
@@ -14,11 +14,18 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import unittest.mock as mock
 from textwrap import dedent
 
+import numpy as np
+import pandas as pd
+from sqlalchemy.types import NVARCHAR
+
 from superset.db_engine_specs.redshift import RedshiftEngineSpec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from superset.sql_parse import Table
 from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
+from tests.integration_tests.test_app import app
 
 
 class TestRedshiftDbEngineSpec(TestDbEngineSpec):
@@ -183,3 +190,57 @@ psql: error: could not connect to server: Operation timed out
                 },
             )
         ]
+
+    def test_df_to_sql_no_dtype(self):
+        mock_database = mock.MagicMock()
+        mock_database.get_df.return_value.empty = False
+        table_name = "foobar"
+        data = [
+            ("foo", "bar", pd.NA, None),
+            ("foo", "bar", pd.NA, True),
+            ("foo", "bar", pd.NA, None),
+        ]
+        numpy_dtype = [
+            ("id", "object"),
+            ("value", "object"),
+            ("num", "object"),
+            ("bool", "object"),
+        ]
+        column_names = ["id", "value", "num", "bool"]
+
+        test_array = np.array(data, dtype=numpy_dtype)
+
+        df = pd.DataFrame(test_array, columns=column_names)
+        df.to_sql = mock.MagicMock()
+
+        with app.app_context():
+            RedshiftEngineSpec.df_to_sql(
+                mock_database, Table(table=table_name), df, to_sql_kwargs={}
+            )
+
+        assert df.to_sql.call_args[1]["dtype"] == {}
+
+    def test_df_to_sql_with_string_dtype(self):
+        mock_database = mock.MagicMock()
+        mock_database.get_df.return_value.empty = False
+        table_name = "foobar"
+        data = [
+            ("foo", "bar", pd.NA, None),
+            ("foo", "bar", pd.NA, True),
+            ("foo", "bar", pd.NA, None),
+        ]
+        column_names = ["id", "value", "num", "bool"]
+
+        df = pd.DataFrame(data, columns=column_names)
+        df = df.astype(dtype={"value": "string"})
+        df.to_sql = mock.MagicMock()
+
+        with app.app_context():
+            RedshiftEngineSpec.df_to_sql(
+                mock_database, Table(table=table_name), df, to_sql_kwargs={}
+            )
+
+        # nvarchar length should be the redshift max, 65535
+        dtype = df.to_sql.call_args[1]["dtype"]
+        assert isinstance(dtype["value"], NVARCHAR)
+        assert dtype["value"].length == 65535