You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@superset.apache.org by GitBox <gi...@apache.org> on 2018/02/08 17:43:26 UTC

[GitHub] mistercrunch closed pull request #4351: use enum.Enum to rewrite querystatus

mistercrunch closed pull request #4351: use enum.Enum to rewrite querystatus
URL: https://github.com/apache/incubator-superset/pull/4351
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (GitHub would not otherwise display it):

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 6ccddbe79c..b42f060315 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -42,11 +42,11 @@ def query(self, query_obj):
         qry = qry.filter(Annotation.layer_id == query_obj['filter'][0]['val'])
         qry = qry.filter(Annotation.start_dttm >= query_obj['from_dttm'])
         qry = qry.filter(Annotation.end_dttm <= query_obj['to_dttm'])
-        status = QueryStatus.SUCCESS
+        status = QueryStatus.SUCCESS.value
         try:
             df = pd.read_sql_query(qry.statement, db.engine)
         except Exception as e:
-            status = QueryStatus.FAILED
+            status = QueryStatus.FAILED.value
             logging.exception(e)
             error_message = (
                 utils.error_msg_from_exception(e))
@@ -679,13 +679,13 @@ def _get_top_groups(self, df, dimensions):
     def query(self, query_obj):
         qry_start_dttm = datetime.now()
         sql = self.get_query_str(query_obj)
-        status = QueryStatus.SUCCESS
+        status = QueryStatus.SUCCESS.value
         error_message = None
         df = None
         try:
             df = self.database.get_df(sql, self.schema)
         except Exception as e:
-            status = QueryStatus.FAILED
+            status = QueryStatus.FAILED.value
             logging.exception(e)
             error_message = (
                 self.database.db_engine_spec.extract_error_message(e))
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index d26f633bbd..1d73fdd626 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -585,7 +585,7 @@ def handle_cursor(cls, cursor, query, session):
             stats = polled.get('stats', {})
 
             query = session.query(type(query)).filter_by(id=query.id).one()
-            if query.status == QueryStatus.STOPPED:
+            if query.status == QueryStatus.STOPPED.value:
                 cursor.cancel()
                 break
 
@@ -914,7 +914,7 @@ def handle_cursor(cls, cursor, query, session):
         job_id = None
         while polled.operationState in unfinished_states:
             query = session.query(type(query)).filter_by(id=query.id).one()
-            if query.status == QueryStatus.STOPPED:
+            if query.status == QueryStatus.STOPPED.value:
                 cursor.cancel()
                 break
 
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 948cf0d49e..5e60068dbc 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -292,7 +292,7 @@ def __init__(  # noqa
             df,
             query,
             duration,
-            status=QueryStatus.SUCCESS,
+            status=QueryStatus.SUCCESS.value,
             error_message=None):
         self.df = df
         self.query = query
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index 44b692b915..53069b3c9c 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -35,7 +35,7 @@ class Query(Model):
     # Store the tmp table into the DB only if the user asks for it.
     tmp_table_name = Column(String(256))
     user_id = Column(Integer, ForeignKey('ab_user.id'), nullable=True)
-    status = Column(String(16), default=QueryStatus.PENDING)
+    status = Column(String(16), default=QueryStatus.PENDING.value)
     tab_name = Column(String(256))
     sql_editor_id = Column(String(256))
     schema = Column(String(256))
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 63225f3e2d..f2faf8ae6f 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -100,7 +100,7 @@ def get_sql_results(
         sesh = get_session(not ctask.request.called_directly)
         query = get_query(query_id, sesh)
         query.error_message = str(e)
-        query.status = QueryStatus.FAILED
+        query.status = QueryStatus.FAILED.value
         query.tmp_table_name = None
         sesh.commit()
         raise
@@ -127,7 +127,7 @@ def handle_error(msg):
             resolutions at: {}'.format(msg, troubleshooting_link) \
             if troubleshooting_link else msg
         query.error_message = msg
-        query.status = QueryStatus.FAILED
+        query.status = QueryStatus.FAILED.value
         query.tmp_table_name = None
         session.commit()
         payload.update({
@@ -150,7 +150,6 @@ def handle_error(msg):
             return handle_error(
                 'Only `SELECT` statements can be used with the CREATE TABLE '
                 'feature.')
-            return
         if not query.tmp_table_name:
             start_dttm = datetime.fromtimestamp(query.start_time)
             query.tmp_table_name = 'tmp_{}_table_{}'.format(
@@ -173,7 +172,7 @@ def handle_error(msg):
         return handle_error(msg)
 
     query.executed_sql = executed_sql
-    query.status = QueryStatus.RUNNING
+    query.status = QueryStatus.RUNNING.value
     query.start_running_time = utils.now_as_float()
     session.merge(query)
     session.commit()
@@ -215,7 +214,7 @@ def handle_error(msg):
         conn.commit()
         conn.close()
 
-    if query.status == utils.QueryStatus.STOPPED:
+    if query.status == utils.QueryStatus.STOPPED.value:
         return json.dumps(
             {
                 'query_id': query.id,
@@ -232,7 +231,7 @@ def handle_error(msg):
 
     query.rows = cdf.size
     query.progress = 100
-    query.status = QueryStatus.SUCCESS
+    query.status = QueryStatus.SUCCESS.value
     if query.select_as_cta:
         query.select_sql = '{}'.format(
             database.select_star(
diff --git a/superset/utils.py b/superset/utils.py
index 8224843213..9c61ead450 100644
--- a/superset/utils.py
+++ b/superset/utils.py
@@ -11,6 +11,7 @@
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
 from email.utils import formatdate
+from enum import Enum, unique
 import functools
 import json
 import logging
@@ -532,7 +533,8 @@ def ping_connection(connection, branch):
             connection.should_close_with_result = save_should_close_with_result
 
 
-class QueryStatus(object):
+@unique
+class QueryStatus(Enum):
     """Enum-type class for query statuses"""
 
     STOPPED = 'stopped'
diff --git a/superset/views/core.py b/superset/views/core.py
index ec4cce1fb1..e31aa5eb3a 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -1041,7 +1041,7 @@ def generate_json(self, datasource_type, datasource_id, form_data,
             return json_error_response(utils.error_msg_from_exception(e))
 
         status = 200
-        if payload.get('status') == QueryStatus.FAILED:
+        if payload.get('status') == QueryStatus.FAILED.value:
             status = 400
 
         return json_success(viz_obj.json_dumps(payload), status=status)
@@ -1086,7 +1086,7 @@ def annotation_json(self, layer_id):
             logging.exception(e)
             return json_error_response(utils.error_msg_from_exception(e))
         status = 200
-        if payload.get('status') == QueryStatus.FAILED:
+        if payload.get('status') == QueryStatus.FAILED.value:
             status = 400
         return json_success(viz_obj.json_dumps(payload), status=status)
 
@@ -2214,7 +2214,7 @@ def stop_query(self):
                 db.session.query(Query)
                 .filter_by(client_id=client_id).one()
             )
-            query.status = utils.QueryStatus.STOPPED
+            query.status = utils.QueryStatus.STOPPED.value
             db.session.commit()
         except Exception:
             pass
@@ -2261,7 +2261,7 @@ def sql_json(self):
             select_as_cta=request.form.get('select_as_cta') == 'true',
             start_time=utils.now_as_float(),
             tab_name=request.form.get('tab'),
-            status=QueryStatus.PENDING if async else QueryStatus.RUNNING,
+            status=QueryStatus.PENDING.value if async else QueryStatus.RUNNING.value,
             sql_editor_id=request.form.get('sql_editor_id'),
             tmp_table_name=tmp_table_name,
             user_id=int(g.user.get_id()),
@@ -2292,7 +2292,7 @@ def sql_json(self):
                     'Tell your administrator to verify the availability of '
                     'the message queue.'
                 )
-                query.status = QueryStatus.FAILED
+                query.status = QueryStatus.FAILED.value
                 query.error_message = msg
                 session.commit()
                 return json_error_response('{}'.format(msg))
@@ -2320,7 +2320,7 @@ def sql_json(self):
         except Exception as e:
             logging.exception(e)
             return json_error_response('{}'.format(e))
-        if data.get('status') == QueryStatus.FAILED:
+        if data.get('status') == QueryStatus.FAILED.value:
             return json_error_response(payload=data)
         return json_success(payload)
 
diff --git a/superset/viz.py b/superset/viz.py
index bb0bcf604d..d0bcc5a16b 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -111,7 +111,7 @@ def get_df(self, query_obj=None):
         # If the datetime format is unix, the parse will use the corresponding
         # parsing logic.
         if df is None or df.empty:
-            self.status = utils.QueryStatus.FAILED
+            self.status = utils.QueryStatus.FAILED.value
             if not self.error_message:
                 self.error_message = 'No data.'
             return pd.DataFrame()
@@ -278,7 +278,7 @@ def get_payload(self, force=False):
                 logging.exception(e)
                 if not self.error_message:
                     self.error_message = str(e)
-                self.status = utils.QueryStatus.FAILED
+                self.status = utils.QueryStatus.FAILED.value
                 data = None
                 stacktrace = traceback.format_exc()
 
@@ -286,7 +286,7 @@ def get_payload(self, force=False):
                     data and
                     cache_key and
                     cache and
-                    self.status != utils.QueryStatus.FAILED):
+                    self.status != utils.QueryStatus.FAILED.value):
                 cached_dttm = datetime.utcnow().isoformat().split('.')[0]
                 try:
                     cache_value = self.json_dumps({
diff --git a/tests/celery_tests.py b/tests/celery_tests.py
index 591e793945..31fca878cc 100644
--- a/tests/celery_tests.py
+++ b/tests/celery_tests.py
@@ -187,7 +187,7 @@ def test_run_sync_query_cta(self):
             "SELECT name FROM ab_permission WHERE name='{}'".format(perm_name))
         result2 = self.run_sql(
             db_id, sql_where, '2', tmp_table='tmp_table_2', cta='true')
-        self.assertEqual(QueryStatus.SUCCESS, result2['query']['state'])
+        self.assertEqual(QueryStatus.SUCCESS.value, result2['query']['state'])
         self.assertEqual([], result2['data'])
         self.assertEqual([], result2['columns'])
         query2 = self.get_query_by_id(result2['query']['serverId'])
@@ -203,12 +203,12 @@ def test_run_sync_query_cta_no_data(self):
         sql_empty_result = 'SELECT * FROM ab_user WHERE id=666'
         result3 = self.run_sql(
             db_id, sql_empty_result, '3', tmp_table='tmp_table_3', cta='true')
-        self.assertEqual(QueryStatus.SUCCESS, result3['query']['state'])
+        self.assertEqual(QueryStatus.SUCCESS.value, result3['query']['state'])
         self.assertEqual([], result3['data'])
         self.assertEqual([], result3['columns'])
 
         query3 = self.get_query_by_id(result3['query']['serverId'])
-        self.assertEqual(QueryStatus.SUCCESS, query3.status)
+        self.assertEqual(QueryStatus.SUCCESS.value, query3.status)
 
     def test_run_async_query(self):
         main_db = self.get_main_database(db.session)
@@ -218,15 +218,15 @@ def test_run_async_query(self):
             main_db.id, sql_where, '4', async='true', tmp_table='tmp_async_1',
             cta='true')
         assert result['query']['state'] in (
-            QueryStatus.PENDING, QueryStatus.RUNNING, QueryStatus.SUCCESS)
+            QueryStatus.PENDING.value, QueryStatus.RUNNING.value, QueryStatus.SUCCESS.value)
 
         time.sleep(1)
 
         query = self.get_query_by_id(result['query']['serverId'])
         df = pd.read_sql_query(query.select_sql, con=eng)
-        self.assertEqual(QueryStatus.SUCCESS, query.status)
+        self.assertEqual(QueryStatus.SUCCESS.value, query.status)
         self.assertEqual([{'name': 'Admin'}], df.to_dict(orient='records'))
-        self.assertEqual(QueryStatus.SUCCESS, query.status)
+        self.assertEqual(QueryStatus.SUCCESS.value, query.status)
         self.assertTrue('FROM tmp_async_1' in query.select_sql)
         self.assertTrue('LIMIT 666' in query.select_sql)
         self.assertEqual(
diff --git a/tests/viz_tests.py b/tests/viz_tests.py
index abf29adb62..16addab28b 100644
--- a/tests/viz_tests.py
+++ b/tests/viz_tests.py
@@ -48,7 +48,7 @@ def test_get_df_returns_empty_df(self):
         self.assertEqual(type(result), pd.DataFrame)
         self.assertTrue(result.empty)
         self.assertEqual(test_viz.error_message, 'No data.')
-        self.assertEqual(test_viz.status, utils.QueryStatus.FAILED)
+        self.assertEqual(test_viz.status, utils.QueryStatus.FAILED.value)
 
     def test_get_df_handles_dttm_col(self):
         datasource = Mock()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services