You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by be...@apache.org on 2021/01/04 17:23:13 UTC

[incubator-superset] branch master updated: fix: CTAS on multiple statements (#12188)

This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 164db3e  fix: CTAS on multiple statements (#12188)
164db3e is described below

commit 164db3e5a13c21137afb56a3044ef3f1aaf89e11
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Mon Jan 4 09:22:35 2021 -0800

    fix: CTAS on multiple statements (#12188)
    
    * WIP
    
    * Add unit tests for sql_parse
    
    * Add unit tests for sql_lab
---
 superset/sql_lab.py      |  57 ++++++++++++----
 superset/sql_parse.py    |  11 +++-
 tests/sql_parse_tests.py |  76 ++++++++++++++++++++++
 tests/sqllab_tests.py    | 164 ++++++++++++++++++++++++++++++++++++++++++++++-
 4 files changed, 295 insertions(+), 13 deletions(-)

diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 8001498..1153a2b 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -36,7 +36,7 @@ from superset.db_engine_specs import BaseEngineSpec
 from superset.extensions import celery_app
 from superset.models.sql_lab import Query
 from superset.result_set import SupersetResultSet
-from superset.sql_parse import ParsedQuery
+from superset.sql_parse import CtasMethod, ParsedQuery
 from superset.utils.celery import session_scope
 from superset.utils.core import (
     json_iso_dttm_ser,
@@ -160,6 +160,7 @@ def execute_sql_statement(
     session: Session,
     cursor: Any,
     log_params: Optional[Dict[str, Any]],
+    apply_ctas: bool = False,
 ) -> SupersetResultSet:
     """Executes a single SQL statement"""
     database = query.database
@@ -171,14 +172,7 @@ def execute_sql_statement(
         raise SqlLabSecurityException(
             _("Only `SELECT` statements are allowed against this database")
         )
-    if query.select_as_cta:
-        if not parsed_query.is_select():
-            raise SqlLabException(
-                _(
-                    "Only `SELECT` statements can be used with the CREATE TABLE "
-                    "feature."
-                )
-            )
+    if apply_ctas:
         if not query.tmp_table_name:
             start_dttm = datetime.fromtimestamp(query.start_time)
             query.tmp_table_name = "tmp_{}_table_{}".format(
@@ -322,8 +316,8 @@ def execute_sql_statements(  # pylint: disable=too-many-arguments, too-many-loca
         raise SqlLabException("Results backend isn't configured.")
 
     # Breaking down into multiple statements
+    parsed_query = ParsedQuery(rendered_query, strip_comments=True)
     if not db_engine_spec.run_multiple_statements_as_one:
-        parsed_query = ParsedQuery(rendered_query)
         statements = parsed_query.get_statements()
         logger.info(
             "Query %s: Executing %i statement(s)", str(query_id), len(statements)
@@ -337,6 +331,32 @@ def execute_sql_statements(  # pylint: disable=too-many-arguments, too-many-loca
     query.start_running_time = now_as_float()
     session.commit()
 
+    # Should we create a table or view from the select?
+    if (
+        query.select_as_cta
+        and query.ctas_method == CtasMethod.TABLE
+        and not parsed_query.is_valid_ctas()
+    ):
+        raise SqlLabException(
+            _(
+                "CTAS (create table as select) can only be run with a query where "
+                "the last statement is a SELECT. Please make sure your query has "
+                "a SELECT as its last statement. Then, try running your query again."
+            )
+        )
+    if (
+        query.select_as_cta
+        and query.ctas_method == CtasMethod.VIEW
+        and not parsed_query.is_valid_cvas()
+    ):
+        raise SqlLabException(
+            _(
+                "CVAS (create view as select) can only be run with a query with "
+                "a single SELECT statement. Please make sure your query has only "
+                "a SELECT statement. Then, try running your query again."
+            )
+        )
+
     engine = database.get_sqla_engine(
         schema=query.schema,
         nullpool=True,
@@ -354,6 +374,15 @@ def execute_sql_statements(  # pylint: disable=too-many-arguments, too-many-loca
                 if query.status == QueryStatus.STOPPED:
                     return None
 
+                # For CTAS we create the table only on the last statement
+                apply_ctas = query.select_as_cta and (
+                    query.ctas_method == CtasMethod.VIEW
+                    or (
+                        query.ctas_method == CtasMethod.TABLE
+                        and i == len(statements) - 1
+                    )
+                )
+
                 # Run statement
                 msg = f"Running statement {i+1} out of {statement_count}"
                 logger.info("Query %s: %s", str(query_id), msg)
@@ -361,7 +390,13 @@ def execute_sql_statements(  # pylint: disable=too-many-arguments, too-many-loca
                 session.commit()
                 try:
                     result_set = execute_sql_statement(
-                        statement, query, user_name, session, cursor, log_params
+                        statement,
+                        query,
+                        user_name,
+                        session,
+                        cursor,
+                        log_params,
+                        apply_ctas,
                     )
                 except Exception as ex:  # pylint: disable=broad-except
                     msg = str(ex)
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 8343f42..dd345db 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -81,7 +81,10 @@ class Table:  # pylint: disable=too-few-public-methods
 
 
 class ParsedQuery:
-    def __init__(self, sql_statement: str):
+    def __init__(self, sql_statement: str, strip_comments: bool = False):
+        if strip_comments:
+            sql_statement = sqlparse.format(sql_statement, strip_comments=True)
+
         self.sql: str = sql_statement
         self._tables: Set[Table] = set()
         self._alias_names: Set[str] = set()
@@ -110,6 +113,12 @@ class ParsedQuery:
     def is_select(self) -> bool:
         return self._parsed[0].get_type() == "SELECT"
 
+    def is_valid_ctas(self) -> bool:
+        return self._parsed[-1].get_type() == "SELECT"
+
+    def is_valid_cvas(self) -> bool:
+        return len(self._parsed) == 1 and self._parsed[0].get_type() == "SELECT"
+
     def is_explain(self) -> bool:
         # Remove comments
         statements_without_comments = sqlparse.format(
diff --git a/tests/sql_parse_tests.py b/tests/sql_parse_tests.py
index e46c2b8..b54a9ef 100644
--- a/tests/sql_parse_tests.py
+++ b/tests/sql_parse_tests.py
@@ -656,3 +656,79 @@ class TestSupersetSqlParse(unittest.TestCase):
         """
         parsed = ParsedQuery(query)
         self.assertEqual(parsed.is_explain(), False)
+
+    def test_is_valid_ctas(self):
+        """A valid CTAS has a SELECT as its last statement"""
+        query = "SELECT * FROM table"
+        parsed = ParsedQuery(query, strip_comments=True)
+        assert parsed.is_valid_ctas()
+
+        query = """
+            -- comment
+            SELECT * FROM table
+            -- comment 2
+        """
+        parsed = ParsedQuery(query, strip_comments=True)
+        assert parsed.is_valid_ctas()
+
+        query = """
+            -- comment
+            SET @value = 42;
+            SELECT @value as foo;
+            -- comment 2
+        """
+        parsed = ParsedQuery(query, strip_comments=True)
+        assert parsed.is_valid_ctas()
+
+        query = """
+            -- comment
+            EXPLAIN SELECT * FROM table
+            -- comment 2
+        """
+        parsed = ParsedQuery(query, strip_comments=True)
+        assert not parsed.is_valid_ctas()
+
+        query = """
+            SELECT * FROM table;
+            INSERT INTO TABLE (foo) VALUES (42);
+        """
+        parsed = ParsedQuery(query, strip_comments=True)
+        assert not parsed.is_valid_ctas()
+
+    def test_is_valid_cvas(self):
+        """A valid CVAS has a single SELECT statement"""
+        query = "SELECT * FROM table"
+        parsed = ParsedQuery(query, strip_comments=True)
+        assert parsed.is_valid_cvas()
+
+        query = """
+            -- comment
+            SELECT * FROM table
+            -- comment 2
+        """
+        parsed = ParsedQuery(query, strip_comments=True)
+        assert parsed.is_valid_cvas()
+
+        query = """
+            -- comment
+            SET @value = 42;
+            SELECT @value as foo;
+            -- comment 2
+        """
+        parsed = ParsedQuery(query, strip_comments=True)
+        assert not parsed.is_valid_cvas()
+
+        query = """
+            -- comment
+            EXPLAIN SELECT * FROM table
+            -- comment 2
+        """
+        parsed = ParsedQuery(query, strip_comments=True)
+        assert not parsed.is_valid_cvas()
+
+        query = """
+            SELECT * FROM table;
+            INSERT INTO TABLE (foo) VALUES (42);
+        """
+        parsed = ParsedQuery(query, strip_comments=True)
+        assert not parsed.is_valid_cvas()
diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py
index a10ef4f..8c2b10a 100644
--- a/tests/sqllab_tests.py
+++ b/tests/sqllab_tests.py
@@ -18,11 +18,12 @@
 """Unit tests for Sql Lab"""
 import json
 from datetime import datetime, timedelta
-from parameterized import parameterized
 from random import random
 from unittest import mock
 
+from parameterized import parameterized
 import prison
+import pytest
 
 from superset import db, security_manager
 from superset.connectors.sqla.models import SqlaTable
@@ -30,6 +31,7 @@ from superset.db_engine_specs import BaseEngineSpec
 from superset.errors import ErrorLevel, SupersetErrorType
 from superset.models.sql_lab import Query, SavedQuery
 from superset.result_set import SupersetResultSet
+from superset.sql_lab import execute_sql_statements, SqlLabException
 from superset.sql_parse import CtasMethod
 from superset.utils.core import (
     datetime_to_epoch,
@@ -618,3 +620,163 @@ class TestSqlLab(SupersetTestCase):
             "template_parameters": {"state": "CA"},
             "undefined_parameters": ["stat"],
         }
+
+    @mock.patch("superset.sql_lab.get_query")
+    @mock.patch("superset.sql_lab.execute_sql_statement")
+    def test_execute_sql_statements(self, mock_execute_sql_statement, mock_get_query):
+        sql = """
+            -- comment
+            SET @value = 42;
+            SELECT @value AS foo;
+            -- comment
+        """
+        mock_session = mock.MagicMock()
+        mock_query = mock.MagicMock()
+        mock_query.database.allow_run_async = False
+        mock_cursor = mock.MagicMock()
+        mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = (
+            mock_cursor
+        )
+        mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
+        mock_get_query.return_value = mock_query
+
+        execute_sql_statements(
+            query_id=1,
+            rendered_query=sql,
+            return_results=True,
+            store_results=False,
+            user_name="admin",
+            session=mock_session,
+            start_time=None,
+            expand_data=False,
+            log_params=None,
+        )
+        mock_execute_sql_statement.assert_has_calls(
+            [
+                mock.call(
+                    "SET @value = 42",
+                    mock_query,
+                    "admin",
+                    mock_session,
+                    mock_cursor,
+                    None,
+                    False,
+                ),
+                mock.call(
+                    "SELECT @value AS foo",
+                    mock_query,
+                    "admin",
+                    mock_session,
+                    mock_cursor,
+                    None,
+                    False,
+                ),
+            ]
+        )
+
+    @mock.patch("superset.sql_lab.get_query")
+    @mock.patch("superset.sql_lab.execute_sql_statement")
+    def test_execute_sql_statements_ctas(
+        self, mock_execute_sql_statement, mock_get_query
+    ):
+        sql = """
+            -- comment
+            SET @value = 42;
+            SELECT @value AS foo;
+            -- comment
+        """
+        mock_session = mock.MagicMock()
+        mock_query = mock.MagicMock()
+        mock_query.database.allow_run_async = False
+        mock_cursor = mock.MagicMock()
+        mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = (
+            mock_cursor
+        )
+        mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
+        mock_get_query.return_value = mock_query
+
+        # set the query to CTAS
+        mock_query.select_as_cta = True
+        mock_query.ctas_method = CtasMethod.TABLE
+
+        execute_sql_statements(
+            query_id=1,
+            rendered_query=sql,
+            return_results=True,
+            store_results=False,
+            user_name="admin",
+            session=mock_session,
+            start_time=None,
+            expand_data=False,
+            log_params=None,
+        )
+        mock_execute_sql_statement.assert_has_calls(
+            [
+                mock.call(
+                    "SET @value = 42",
+                    mock_query,
+                    "admin",
+                    mock_session,
+                    mock_cursor,
+                    None,
+                    False,
+                ),
+                mock.call(
+                    "SELECT @value AS foo",
+                    mock_query,
+                    "admin",
+                    mock_session,
+                    mock_cursor,
+                    None,
+                    True,  # apply_ctas
+                ),
+            ]
+        )
+
+        # try invalid CTAS
+        sql = "DROP TABLE my_table"
+        with pytest.raises(SqlLabException) as excinfo:
+            execute_sql_statements(
+                query_id=1,
+                rendered_query=sql,
+                return_results=True,
+                store_results=False,
+                user_name="admin",
+                session=mock_session,
+                start_time=None,
+                expand_data=False,
+                log_params=None,
+            )
+        assert str(excinfo.value) == (
+            "CTAS (create table as select) can only be run with "
+            "a query where the last statement is a SELECT. Please "
+            "make sure your query has a SELECT as its last "
+            "statement. Then, try running your query again."
+        )
+
+        # try invalid CVAS
+        mock_query.ctas_method = CtasMethod.VIEW
+        sql = """
+            -- comment
+            SET @value = 42;
+            SELECT @value AS foo;
+            -- comment
+        """
+        with pytest.raises(SqlLabException) as excinfo:
+            execute_sql_statements(
+                query_id=1,
+                rendered_query=sql,
+                return_results=True,
+                store_results=False,
+                user_name="admin",
+                session=mock_session,
+                start_time=None,
+                expand_data=False,
+                log_params=None,
+            )
+        assert str(excinfo.value) == (
+            "CVAS (create view as select) can only be run with a "
+            "query with a single SELECT statement. Please make "
+            "sure your query has only a SELECT statement. Then, "
+            "try running your query again."
+        )