You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@superset.apache.org by GitBox <gi...@apache.org> on 2018/02/04 04:22:11 UTC

[GitHub] mistercrunch closed pull request #4298: Refactor import csv

mistercrunch closed pull request #4298: Refactor import csv
URL: https://github.com/apache/incubator-superset/pull/4298
 
 
   

This is a pull request merged from a forked (foreign) repository.
Because GitHub hides the original diff of such a pull request once it
is merged, the diff is reproduced below for the sake of provenance:

diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index d26f633bbd..e55fa94829 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -133,13 +133,14 @@ def _allowed_file(filename):
             'table': table,
             'df': df,
             'name': form.name.data,
-            'con': create_engine(form.con.data, echo=False),
+            'con': create_engine(form.con.data.sqlalchemy_uri, echo=False),
             'schema': form.schema.data,
             'if_exists': form.if_exists.data,
             'index': form.index.data,
             'index_label': form.index_label.data,
             'chunksize': 10000,
         }
+
         BaseEngineSpec.df_to_db(**df_to_db_kwargs)
 
     @classmethod
diff --git a/superset/forms.py b/superset/forms.py
index a07790440f..cacb9067eb 100644
--- a/superset/forms.py
+++ b/superset/forms.py
@@ -10,14 +10,20 @@
 from flask_wtf.file import FileAllowed, FileField, FileRequired
 from wtforms import (
     BooleanField, IntegerField, SelectField, StringField)
+from wtforms.ext.sqlalchemy.fields import QuerySelectField
 from wtforms.validators import DataRequired, NumberRange, Optional
 
-from superset import app
+from superset import app, db
+from superset.models import core as models
 
 config = app.config
 
 
 class CsvToDatabaseForm(DynamicForm):
+    # pylint: disable=E0211
+    def all_db_items():
+        return db.session.query(models.Database)
+
     name = StringField(
         _('Table Name'),
         description=_('Name of table to be created from csv data.'),
@@ -28,12 +34,9 @@ class CsvToDatabaseForm(DynamicForm):
         description=_('Select a CSV file to be uploaded to a database.'),
         validators=[
             FileRequired(), FileAllowed(['csv'], _('CSV Files Only!'))])
-
-    con = SelectField(
-        _('Database'),
-        description=_('database in which to add above table.'),
-        validators=[DataRequired()],
-        choices=[])
+    con = QuerySelectField(
+         query_factory=all_db_items,
+         get_pk=lambda a: a.id, get_label=lambda a: a.database_name)
     sep = StringField(
         _('Delimiter'),
         description=_('Delimiter used by CSV file (for whitespace use \s+).'),
@@ -49,7 +52,6 @@ class CsvToDatabaseForm(DynamicForm):
             ('fail', _('Fail')), ('replace', _('Replace')),
             ('append', _('Append'))],
         validators=[DataRequired()])
-
     schema = StringField(
         _('Schema'),
         description=_('Specify a schema (if database flavour supports this).'),
diff --git a/superset/views/core.py b/superset/views/core.py
index db905f3a82..48767b8433 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -24,10 +24,11 @@
 from flask_babel import gettext as __
 from flask_babel import lazy_gettext as _
 import pandas as pd
+from six import text_type
 import sqlalchemy as sqla
 from sqlalchemy import create_engine
 from sqlalchemy.engine.url import make_url
-from sqlalchemy.exc import OperationalError
+from sqlalchemy.exc import IntegrityError, OperationalError
 from unidecode import unidecode
 from werkzeug.routing import BaseConverter
 from werkzeug.utils import secure_filename
@@ -163,8 +164,6 @@ def apply(self, query, func):  # noqa
         return query
 
 
-
-
 class DatabaseView(SupersetModelView, DeleteMixin, YamlExportMixin):  # noqa
     datamodel = SQLAInterface(models.Database)
 
@@ -319,49 +318,36 @@ def form_get(self, form):
         form.infer_datetime_format.data = True
         form.decimal.data = '.'
         form.if_exists.data = 'append'
-        all_datasources = (
-            db.session.query(
-                models.Database.sqlalchemy_uri,
-                models.Database.database_name)
-            .all()
-        )
-        form.con.choices += all_datasources
 
     def form_post(self, form):
-        def _upload_file(csv_file):
-            if csv_file and csv_file.filename:
-                filename = secure_filename(csv_file.filename)
-                csv_file.save(os.path.join(config['UPLOAD_FOLDER'], filename))
-                return filename
-
         csv_file = form.csv_file.data
-        _upload_file(csv_file)
-        table = SqlaTable(table_name=form.name.data)
-        database = (
-            db.session.query(models.Database)
-            .filter_by(sqlalchemy_uri=form.data.get('con'))
-            .one()
-        )
-        table.database = database
-        table.database_id = database.id
+        form.csv_file.data.filename = secure_filename(form.csv_file.data.filename)
+        csv_filename = form.csv_file.data.filename
         try:
-            database.db_engine_spec.create_table_from_csv(form, table)
+            csv_file.save(os.path.join(config['UPLOAD_FOLDER'], csv_filename))
+            table = SqlaTable(table_name=form.name.data)
+            table.database = form.data.get('con')
+            table.database_id = table.database.id
+            table.database.db_engine_spec.create_table_from_csv(form, table)
         except Exception as e:
-            os.remove(os.path.join(config['UPLOAD_FOLDER'], csv_file.filename))
-            flash(e, 'error')
-            return redirect('/tablemodelview/list/')
+            try:
+                os.remove(os.path.join(config['UPLOAD_FOLDER'], csv_filename))
+            except OSError:
+                pass
+            message = u'Table name {} already exists. Please pick another'.format(
+                    form.name.data) if isinstance(e, IntegrityError) else text_type(e)
+            flash(
+                message,
+                'danger')
+            return redirect('/csvtodatabaseview/form')
 
-        os.remove(os.path.join(config['UPLOAD_FOLDER'], csv_file.filename))
+        os.remove(os.path.join(config['UPLOAD_FOLDER'], csv_filename))
         # Go back to welcome page / splash screen
-        db_name = (
-            db.session.query(models.Database.database_name)
-            .filter_by(sqlalchemy_uri=form.data.get('con'))
-            .one()
-        )
-        message = _('CSV file "{0}" uploaded to table "{1}" in '
-                    'database "{2}"'.format(form.csv_file.data.filename,
+        db_name = table.database.database_name
+        message = _(u'CSV file "{0}" uploaded to table "{1}" in '
+                    'database "{2}"'.format(csv_filename,
                                             form.name.data,
-                                            db_name[0]))
+                                            db_name))
         flash(message, 'info')
         return redirect('/tablemodelview/list/')
 
diff --git a/tests/core_tests.py b/tests/core_tests.py
index a7edc4ec16..ca8f246012 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -803,20 +803,22 @@ def test_import_csv(self):
         test_file.write('john,1\n')
         test_file.write('paul,2\n')
         test_file.close()
-        main_db_uri = db.session.query(
-            models.Database.sqlalchemy_uri)\
-            .filter_by(database_name='main').all()
+        main_db_uri = (
+            db.session.query(models.Database)
+            .filter_by(database_name='main')
+            .all()
+        )
 
         test_file = open(filename, 'rb')
         form_data = {
             'csv_file': test_file,
             'sep': ',',
             'name': table_name,
-            'con': main_db_uri[0][0],
+            'con': main_db_uri[0].id,
             'if_exists': 'append',
             'index_label': 'test_label',
-            'mangle_dupe_cols': False}
-
+            'mangle_dupe_cols': False,
+        }
         url = '/databaseview/list/'
         add_datasource_page = self.get_resp(url)
         assert 'Upload a CSV' in add_datasource_page


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services