You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@superset.apache.org by GitBox <gi...@apache.org> on 2018/02/08 17:43:26 UTC
[GitHub] mistercrunch closed pull request #4351: use enum.Enum to rewrite querystatus
mistercrunch closed pull request #4351: use enum.Enum to rewrite querystatus
URL: https://github.com/apache/incubator-superset/pull/4351
This is a PR merged from a forked repository.
Because GitHub hides the original diff of a foreign pull request (one from
a fork) once it has been merged, the diff is reproduced below for the sake
of provenance:
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 6ccddbe79c..b42f060315 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -42,11 +42,11 @@ def query(self, query_obj):
qry = qry.filter(Annotation.layer_id == query_obj['filter'][0]['val'])
qry = qry.filter(Annotation.start_dttm >= query_obj['from_dttm'])
qry = qry.filter(Annotation.end_dttm <= query_obj['to_dttm'])
- status = QueryStatus.SUCCESS
+ status = QueryStatus.SUCCESS.value
try:
df = pd.read_sql_query(qry.statement, db.engine)
except Exception as e:
- status = QueryStatus.FAILED
+ status = QueryStatus.FAILED.value
logging.exception(e)
error_message = (
utils.error_msg_from_exception(e))
@@ -679,13 +679,13 @@ def _get_top_groups(self, df, dimensions):
def query(self, query_obj):
qry_start_dttm = datetime.now()
sql = self.get_query_str(query_obj)
- status = QueryStatus.SUCCESS
+ status = QueryStatus.SUCCESS.value
error_message = None
df = None
try:
df = self.database.get_df(sql, self.schema)
except Exception as e:
- status = QueryStatus.FAILED
+ status = QueryStatus.FAILED.value
logging.exception(e)
error_message = (
self.database.db_engine_spec.extract_error_message(e))
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index d26f633bbd..1d73fdd626 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -585,7 +585,7 @@ def handle_cursor(cls, cursor, query, session):
stats = polled.get('stats', {})
query = session.query(type(query)).filter_by(id=query.id).one()
- if query.status == QueryStatus.STOPPED:
+ if query.status == QueryStatus.STOPPED.value:
cursor.cancel()
break
@@ -914,7 +914,7 @@ def handle_cursor(cls, cursor, query, session):
job_id = None
while polled.operationState in unfinished_states:
query = session.query(type(query)).filter_by(id=query.id).one()
- if query.status == QueryStatus.STOPPED:
+ if query.status == QueryStatus.STOPPED.value:
cursor.cancel()
break
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 948cf0d49e..5e60068dbc 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -292,7 +292,7 @@ def __init__( # noqa
df,
query,
duration,
- status=QueryStatus.SUCCESS,
+ status=QueryStatus.SUCCESS.value,
error_message=None):
self.df = df
self.query = query
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index 44b692b915..53069b3c9c 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -35,7 +35,7 @@ class Query(Model):
# Store the tmp table into the DB only if the user asks for it.
tmp_table_name = Column(String(256))
user_id = Column(Integer, ForeignKey('ab_user.id'), nullable=True)
- status = Column(String(16), default=QueryStatus.PENDING)
+ status = Column(String(16), default=QueryStatus.PENDING.value)
tab_name = Column(String(256))
sql_editor_id = Column(String(256))
schema = Column(String(256))
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 63225f3e2d..f2faf8ae6f 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -100,7 +100,7 @@ def get_sql_results(
sesh = get_session(not ctask.request.called_directly)
query = get_query(query_id, sesh)
query.error_message = str(e)
- query.status = QueryStatus.FAILED
+ query.status = QueryStatus.FAILED.value
query.tmp_table_name = None
sesh.commit()
raise
@@ -127,7 +127,7 @@ def handle_error(msg):
resolutions at: {}'.format(msg, troubleshooting_link) \
if troubleshooting_link else msg
query.error_message = msg
- query.status = QueryStatus.FAILED
+ query.status = QueryStatus.FAILED.value
query.tmp_table_name = None
session.commit()
payload.update({
@@ -150,7 +150,6 @@ def handle_error(msg):
return handle_error(
'Only `SELECT` statements can be used with the CREATE TABLE '
'feature.')
- return
if not query.tmp_table_name:
start_dttm = datetime.fromtimestamp(query.start_time)
query.tmp_table_name = 'tmp_{}_table_{}'.format(
@@ -173,7 +172,7 @@ def handle_error(msg):
return handle_error(msg)
query.executed_sql = executed_sql
- query.status = QueryStatus.RUNNING
+ query.status = QueryStatus.RUNNING.value
query.start_running_time = utils.now_as_float()
session.merge(query)
session.commit()
@@ -215,7 +214,7 @@ def handle_error(msg):
conn.commit()
conn.close()
- if query.status == utils.QueryStatus.STOPPED:
+ if query.status == utils.QueryStatus.STOPPED.value:
return json.dumps(
{
'query_id': query.id,
@@ -232,7 +231,7 @@ def handle_error(msg):
query.rows = cdf.size
query.progress = 100
- query.status = QueryStatus.SUCCESS
+ query.status = QueryStatus.SUCCESS.value
if query.select_as_cta:
query.select_sql = '{}'.format(
database.select_star(
diff --git a/superset/utils.py b/superset/utils.py
index 8224843213..9c61ead450 100644
--- a/superset/utils.py
+++ b/superset/utils.py
@@ -11,6 +11,7 @@
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.utils import formatdate
+from enum import Enum, unique
import functools
import json
import logging
@@ -532,7 +533,8 @@ def ping_connection(connection, branch):
connection.should_close_with_result = save_should_close_with_result
-class QueryStatus(object):
+@unique
+class QueryStatus(Enum):
"""Enum-type class for query statuses"""
STOPPED = 'stopped'
diff --git a/superset/views/core.py b/superset/views/core.py
index ec4cce1fb1..e31aa5eb3a 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -1041,7 +1041,7 @@ def generate_json(self, datasource_type, datasource_id, form_data,
return json_error_response(utils.error_msg_from_exception(e))
status = 200
- if payload.get('status') == QueryStatus.FAILED:
+ if payload.get('status') == QueryStatus.FAILED.value:
status = 400
return json_success(viz_obj.json_dumps(payload), status=status)
@@ -1086,7 +1086,7 @@ def annotation_json(self, layer_id):
logging.exception(e)
return json_error_response(utils.error_msg_from_exception(e))
status = 200
- if payload.get('status') == QueryStatus.FAILED:
+ if payload.get('status') == QueryStatus.FAILED.value:
status = 400
return json_success(viz_obj.json_dumps(payload), status=status)
@@ -2214,7 +2214,7 @@ def stop_query(self):
db.session.query(Query)
.filter_by(client_id=client_id).one()
)
- query.status = utils.QueryStatus.STOPPED
+ query.status = utils.QueryStatus.STOPPED.value
db.session.commit()
except Exception:
pass
@@ -2261,7 +2261,7 @@ def sql_json(self):
select_as_cta=request.form.get('select_as_cta') == 'true',
start_time=utils.now_as_float(),
tab_name=request.form.get('tab'),
- status=QueryStatus.PENDING if async else QueryStatus.RUNNING,
+ status=QueryStatus.PENDING.value if async else QueryStatus.RUNNING.value,
sql_editor_id=request.form.get('sql_editor_id'),
tmp_table_name=tmp_table_name,
user_id=int(g.user.get_id()),
@@ -2292,7 +2292,7 @@ def sql_json(self):
'Tell your administrator to verify the availability of '
'the message queue.'
)
- query.status = QueryStatus.FAILED
+ query.status = QueryStatus.FAILED.value
query.error_message = msg
session.commit()
return json_error_response('{}'.format(msg))
@@ -2320,7 +2320,7 @@ def sql_json(self):
except Exception as e:
logging.exception(e)
return json_error_response('{}'.format(e))
- if data.get('status') == QueryStatus.FAILED:
+ if data.get('status') == QueryStatus.FAILED.value:
return json_error_response(payload=data)
return json_success(payload)
diff --git a/superset/viz.py b/superset/viz.py
index bb0bcf604d..d0bcc5a16b 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -111,7 +111,7 @@ def get_df(self, query_obj=None):
# If the datetime format is unix, the parse will use the corresponding
# parsing logic.
if df is None or df.empty:
- self.status = utils.QueryStatus.FAILED
+ self.status = utils.QueryStatus.FAILED.value
if not self.error_message:
self.error_message = 'No data.'
return pd.DataFrame()
@@ -278,7 +278,7 @@ def get_payload(self, force=False):
logging.exception(e)
if not self.error_message:
self.error_message = str(e)
- self.status = utils.QueryStatus.FAILED
+ self.status = utils.QueryStatus.FAILED.value
data = None
stacktrace = traceback.format_exc()
@@ -286,7 +286,7 @@ def get_payload(self, force=False):
data and
cache_key and
cache and
- self.status != utils.QueryStatus.FAILED):
+ self.status != utils.QueryStatus.FAILED.value):
cached_dttm = datetime.utcnow().isoformat().split('.')[0]
try:
cache_value = self.json_dumps({
diff --git a/tests/celery_tests.py b/tests/celery_tests.py
index 591e793945..31fca878cc 100644
--- a/tests/celery_tests.py
+++ b/tests/celery_tests.py
@@ -187,7 +187,7 @@ def test_run_sync_query_cta(self):
"SELECT name FROM ab_permission WHERE name='{}'".format(perm_name))
result2 = self.run_sql(
db_id, sql_where, '2', tmp_table='tmp_table_2', cta='true')
- self.assertEqual(QueryStatus.SUCCESS, result2['query']['state'])
+ self.assertEqual(QueryStatus.SUCCESS.value, result2['query']['state'])
self.assertEqual([], result2['data'])
self.assertEqual([], result2['columns'])
query2 = self.get_query_by_id(result2['query']['serverId'])
@@ -203,12 +203,12 @@ def test_run_sync_query_cta_no_data(self):
sql_empty_result = 'SELECT * FROM ab_user WHERE id=666'
result3 = self.run_sql(
db_id, sql_empty_result, '3', tmp_table='tmp_table_3', cta='true')
- self.assertEqual(QueryStatus.SUCCESS, result3['query']['state'])
+ self.assertEqual(QueryStatus.SUCCESS.value, result3['query']['state'])
self.assertEqual([], result3['data'])
self.assertEqual([], result3['columns'])
query3 = self.get_query_by_id(result3['query']['serverId'])
- self.assertEqual(QueryStatus.SUCCESS, query3.status)
+ self.assertEqual(QueryStatus.SUCCESS.value, query3.status)
def test_run_async_query(self):
main_db = self.get_main_database(db.session)
@@ -218,15 +218,15 @@ def test_run_async_query(self):
main_db.id, sql_where, '4', async='true', tmp_table='tmp_async_1',
cta='true')
assert result['query']['state'] in (
- QueryStatus.PENDING, QueryStatus.RUNNING, QueryStatus.SUCCESS)
+ QueryStatus.PENDING.value, QueryStatus.RUNNING.value, QueryStatus.SUCCESS.value)
time.sleep(1)
query = self.get_query_by_id(result['query']['serverId'])
df = pd.read_sql_query(query.select_sql, con=eng)
- self.assertEqual(QueryStatus.SUCCESS, query.status)
+ self.assertEqual(QueryStatus.SUCCESS.value, query.status)
self.assertEqual([{'name': 'Admin'}], df.to_dict(orient='records'))
- self.assertEqual(QueryStatus.SUCCESS, query.status)
+ self.assertEqual(QueryStatus.SUCCESS.value, query.status)
self.assertTrue('FROM tmp_async_1' in query.select_sql)
self.assertTrue('LIMIT 666' in query.select_sql)
self.assertEqual(
diff --git a/tests/viz_tests.py b/tests/viz_tests.py
index abf29adb62..16addab28b 100644
--- a/tests/viz_tests.py
+++ b/tests/viz_tests.py
@@ -48,7 +48,7 @@ def test_get_df_returns_empty_df(self):
self.assertEqual(type(result), pd.DataFrame)
self.assertTrue(result.empty)
self.assertEqual(test_viz.error_message, 'No data.')
- self.assertEqual(test_viz.status, utils.QueryStatus.FAILED)
+ self.assertEqual(test_viz.status, utils.QueryStatus.FAILED.value)
def test_get_df_handles_dttm_col(self):
datasource = Mock()
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services