You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by el...@apache.org on 2023/04/24 19:54:04 UTC
[superset] branch master updated: feat: create dtype option for csv upload (#23716)
This is an automated email from the ASF dual-hosted git repository.
elizabeth pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 71106cfd97 feat: create dtype option for csv upload (#23716)
71106cfd97 is described below
commit 71106cfd9791300fa3217bd46884381dde7e7b23
Author: Elizabeth Thompson <es...@gmail.com>
AuthorDate: Mon Apr 24 12:53:53 2023 -0700
feat: create dtype option for csv upload (#23716)
---
superset/db_engine_specs/redshift.py | 40 ++++++++++++++
.../form_view/csv_to_database_view/edit.html | 4 ++
superset/views/database/forms.py | 10 ++++
superset/views/database/views.py | 3 ++
tests/integration_tests/csv_upload_tests.py | 44 +++++++++++++++-
.../db_engine_specs/redshift_tests.py | 61 ++++++++++++++++++++++
6 files changed, 160 insertions(+), 2 deletions(-)
diff --git a/superset/db_engine_specs/redshift.py b/superset/db_engine_specs/redshift.py
index 7e2717d776..27b749e418 100644
--- a/superset/db_engine_specs/redshift.py
+++ b/superset/db_engine_specs/redshift.py
@@ -18,12 +18,16 @@ import logging
import re
from typing import Any, Dict, Optional, Pattern, Tuple
+import pandas as pd
from flask_babel import gettext as __
+from sqlalchemy.types import NVARCHAR
from superset.db_engine_specs.base import BasicParametersMixin
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
from superset.errors import SupersetErrorType
+from superset.models.core import Database
from superset.models.sql_lab import Query
+from superset.sql_parse import Table
logger = logging.getLogger()
@@ -96,6 +100,42 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
),
}
+ @classmethod
+ def df_to_sql(
+ cls,
+ database: Database,
+ table: Table,
+ df: pd.DataFrame,
+ to_sql_kwargs: Dict[str, Any],
+ ) -> None:
+ """
+ Upload data from a Pandas DataFrame to a database.
+
+ For regular engines this calls the `pandas.DataFrame.to_sql` method.
+ Overrides the base class to allow for pandas string types to be
+ used as NVARCHAR(65535) columns, as Redshift does not support
+ text data types.
+
+ Note this method does not create metadata for the table.
+
+ :param database: The database to upload the data to
+ :param table: The table to upload the data to
+ :param df: The dataframe with data to be uploaded
+ :param to_sql_kwargs: The kwargs to be passed to the `pandas.DataFrame.to_sql` method
+ """
+ to_sql_kwargs = to_sql_kwargs or {}
+ to_sql_kwargs["dtype"] = {
+ # uses the max size for redshift nvarchar(65535)
+ # the default object and string types create a varchar(256)
+ col_name: NVARCHAR(length=65535)
+ for col_name, type in zip(df.columns, df.dtypes)
+ if isinstance(type, pd.StringDtype)
+ }
+
+ super().df_to_sql(
+ df=df, database=database, table=table, to_sql_kwargs=to_sql_kwargs
+ )
+
@staticmethod
def _mutate_label(label: str) -> str:
"""
diff --git a/superset/templates/superset/form_view/csv_to_database_view/edit.html b/superset/templates/superset/form_view/csv_to_database_view/edit.html
index b09f9bd383..a0ae43792e 100644
--- a/superset/templates/superset/form_view/csv_to_database_view/edit.html
+++ b/superset/templates/superset/form_view/csv_to_database_view/edit.html
@@ -104,6 +104,10 @@
{{ lib.render_field(form.overwrite_duplicate, begin_sep_label, end_sep_label, begin_sep_field,
end_sep_field) }}
</tr>
+ <tr>
+ {{ lib.render_field(form.dtype, begin_sep_label, end_sep_label, begin_sep_field,
+ end_sep_field) }}
+ </tr>
{% endcall %}
{% call csv_macros.render_collapsable_form_group("accordion3", "Rows") %}
<tr>
diff --git a/superset/views/database/forms.py b/superset/views/database/forms.py
index 91ab38dc2f..99b64e38ab 100644
--- a/superset/views/database/forms.py
+++ b/superset/views/database/forms.py
@@ -140,6 +140,16 @@ class CsvToDatabaseForm(UploadToDatabaseForm):
get_pk=lambda a: a.id,
get_label=lambda a: a.database_name,
)
+ dtype = StringField(
+ _("Column Data Types"),
+ description=_(
+ "A dictionary with column names and their data types"
+ " if you need to change the defaults."
+ ' Example: {"user_id":"integer"}'
+ ),
+ validators=[Optional()],
+ widget=BS3TextFieldWidget(),
+ )
schema = StringField(
_("Schema"),
description=_("Select a schema if the database supports this"),
diff --git a/superset/views/database/views.py b/superset/views/database/views.py
index 037128ee16..a9137e59ed 100644
--- a/superset/views/database/views.py
+++ b/superset/views/database/views.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import io
+import json
import os
import tempfile
import zipfile
@@ -189,6 +190,7 @@ class CsvToDatabaseView(CustomFormView):
delimiter_input = form.otherInput.data
try:
+ kwargs = {"dtype": json.loads(form.dtype.data)} if form.dtype.data else {}
df = pd.concat(
pd.read_csv(
chunksize=1000,
@@ -208,6 +210,7 @@ class CsvToDatabaseView(CustomFormView):
skip_blank_lines=form.skip_blank_lines.data,
skipinitialspace=form.skip_initial_space.data,
skiprows=form.skiprows.data,
+ **kwargs,
)
)
diff --git a/tests/integration_tests/csv_upload_tests.py b/tests/integration_tests/csv_upload_tests.py
index d3b55f7bfe..97b83bb8fc 100644
--- a/tests/integration_tests/csv_upload_tests.py
+++ b/tests/integration_tests/csv_upload_tests.py
@@ -20,7 +20,7 @@ import json
import logging
import os
import shutil
-from typing import Dict, Optional
+from typing import Dict, Optional, Union
from unittest import mock
@@ -129,7 +129,12 @@ def get_upload_db():
return db.session.query(Database).filter_by(database_name=CSV_UPLOAD_DATABASE).one()
-def upload_csv(filename: str, table_name: str, extra: Optional[Dict[str, str]] = None):
+def upload_csv(
+ filename: str,
+ table_name: str,
+ extra: Optional[Dict[str, str]] = None,
+ dtype: Union[str, None] = None,
+):
csv_upload_db_id = get_upload_db().id
schema = utils.get_example_default_schema()
form_data = {
@@ -145,6 +150,8 @@ def upload_csv(filename: str, table_name: str, extra: Optional[Dict[str, str]] =
form_data["schema"] = schema
if extra:
form_data.update(extra)
+ if dtype:
+ form_data["dtype"] = dtype
return get_resp(test_client, "/csvtodatabaseview/form", data=form_data)
@@ -386,6 +393,39 @@ def test_import_csv(mock_event_logger):
data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
assert data == [("john", 1, "x"), ("paul", 2, None)]
+ # cleanup
+ with get_upload_db().get_sqla_engine_with_context() as engine:
+ engine.execute(f"DROP TABLE {full_table_name}")
+
+ # with dtype
+ upload_csv(
+ CSV_FILENAME1,
+ CSV_UPLOAD_TABLE,
+ dtype='{"a": "string", "b": "float64"}',
+ )
+
+ # you can change the type to something compatible, like an object to string
+ # or an int to a float
+ # file upload should work as normal
+ with test_db.get_sqla_engine_with_context() as engine:
+ data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
+ assert data == [("john", 1), ("paul", 2)]
+
+ # cleanup
+ with get_upload_db().get_sqla_engine_with_context() as engine:
+ engine.execute(f"DROP TABLE {full_table_name}")
+
+ # with dtype - wrong type
+ resp = upload_csv(
+ CSV_FILENAME1,
+ CSV_UPLOAD_TABLE,
+ dtype='{"a": "int"}',
+ )
+
+ # you cannot pass an incompatible dtype
+ fail_msg = f"Unable to upload CSV file {escaped_double_quotes(CSV_FILENAME1)} to table {escaped_double_quotes(CSV_UPLOAD_TABLE)}"
+ assert fail_msg in resp
+
@pytest.mark.usefixtures("setup_csv_upload_with_context")
@pytest.mark.usefixtures("create_excel_files")
diff --git a/tests/integration_tests/db_engine_specs/redshift_tests.py b/tests/integration_tests/db_engine_specs/redshift_tests.py
index cdfe8d16cb..2d46c73fca 100644
--- a/tests/integration_tests/db_engine_specs/redshift_tests.py
+++ b/tests/integration_tests/db_engine_specs/redshift_tests.py
@@ -14,11 +14,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import unittest.mock as mock
from textwrap import dedent
+import numpy as np
+import pandas as pd
+from sqlalchemy.types import NVARCHAR
+
from superset.db_engine_specs.redshift import RedshiftEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from superset.sql_parse import Table
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
+from tests.integration_tests.test_app import app
class TestRedshiftDbEngineSpec(TestDbEngineSpec):
@@ -183,3 +190,57 @@ psql: error: could not connect to server: Operation timed out
},
)
]
+
+ def test_df_to_sql_no_dtype(self):
+ mock_database = mock.MagicMock()
+ mock_database.get_df.return_value.empty = False
+ table_name = "foobar"
+ data = [
+ ("foo", "bar", pd.NA, None),
+ ("foo", "bar", pd.NA, True),
+ ("foo", "bar", pd.NA, None),
+ ]
+ numpy_dtype = [
+ ("id", "object"),
+ ("value", "object"),
+ ("num", "object"),
+ ("bool", "object"),
+ ]
+ column_names = ["id", "value", "num", "bool"]
+
+ test_array = np.array(data, dtype=numpy_dtype)
+
+ df = pd.DataFrame(test_array, columns=column_names)
+ df.to_sql = mock.MagicMock()
+
+ with app.app_context():
+ RedshiftEngineSpec.df_to_sql(
+ mock_database, Table(table=table_name), df, to_sql_kwargs={}
+ )
+
+ assert df.to_sql.call_args[1]["dtype"] == {}
+
+ def test_df_to_sql_with_string_dtype(self):
+ mock_database = mock.MagicMock()
+ mock_database.get_df.return_value.empty = False
+ table_name = "foobar"
+ data = [
+ ("foo", "bar", pd.NA, None),
+ ("foo", "bar", pd.NA, True),
+ ("foo", "bar", pd.NA, None),
+ ]
+ column_names = ["id", "value", "num", "bool"]
+
+ df = pd.DataFrame(data, columns=column_names)
+ df = df.astype(dtype={"value": "string"})
+ df.to_sql = mock.MagicMock()
+
+ with app.app_context():
+ RedshiftEngineSpec.df_to_sql(
+ mock_database, Table(table=table_name), df, to_sql_kwargs={}
+ )
+
+ # varchar string length should be 65535
+ dtype = df.to_sql.call_args[1]["dtype"]
+ assert isinstance(dtype["value"], NVARCHAR)
+ assert dtype["value"].length == 65535