You are viewing a plain text version of this content. The link to the canonical version was not preserved in this plain text copy.
Posted to commits@superset.apache.org by yo...@apache.org on 2022/08/11 15:38:09 UTC

[superset] branch master updated: fix: Validate required fields in sql_json API (#21003)

This is an automated email from the ASF dual-hosted git repository.

yongjiezhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new a2b21b55be fix: Validate required fields in sql_json API (#21003)
a2b21b55be is described below

commit a2b21b55be8941e1756bd6c10f5b3dd063a20ee3
Author: EugeneTorap <ev...@gmail.com>
AuthorDate: Thu Aug 11 18:37:53 2022 +0300

    fix: Validate required fields in sql_json API (#21003)
    
    * fix: Validate required params for sql_json API
    
    * Test required params in sql_json API
    
    * Refactoring: use marshmallow Schema for validation sql_json API
    
    * Update SqlJsonPayloadSchema
    
    * Update SqlJsonPayloadSchema
    
    * Refactoring
    
    * Refactoring
    
    * Refactoring
---
 superset/initialization/__init__.py             |  2 +-
 superset/views/core.py                          |  5 +++
 superset/views/sql_lab/__init__.py              | 16 +++++++++
 superset/views/sql_lab/schemas.py               | 35 +++++++++++++++++++
 superset/views/{sql_lab.py => sql_lab/views.py} |  8 +++--
 tests/integration_tests/core_tests.py           | 46 +++++++++++++++++++++++++
 6 files changed, 109 insertions(+), 3 deletions(-)

diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py
index 8dfeff9942..2fe5591dac 100644
--- a/superset/initialization/__init__.py
+++ b/superset/initialization/__init__.py
@@ -176,7 +176,7 @@ class SupersetAppInitializer:  # pylint: disable=too-many-public-methods
         from superset.views.log.api import LogRestApi
         from superset.views.log.views import LogModelView
         from superset.views.redirects import R
-        from superset.views.sql_lab import (
+        from superset.views.sql_lab.views import (
             SavedQueryView,
             SavedQueryViewApi,
             SqlLab,
diff --git a/superset/views/core.py b/superset/views/core.py
index 24c56e61d5..4f39233790 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -152,6 +152,7 @@ from superset.views.base import (
     json_success,
     validate_sqlatable,
 )
+from superset.views.sql_lab.schemas import SqlJsonPayloadSchema
 from superset.views.utils import (
     _deserialize_results_payload,
     bootstrap_user_data,
@@ -2433,6 +2434,10 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
     @event_logger.log_this
     @expose("/sql_json/", methods=["POST"])
     def sql_json(self) -> FlaskResponse:
+        errors = SqlJsonPayloadSchema().validate(request.json)
+        if errors:
+            return json_error_response(status=400, payload=errors)
+
         try:
             log_params = {
                 "user_agent": cast(Optional[str], request.headers.get("USER_AGENT"))
diff --git a/superset/views/sql_lab/__init__.py b/superset/views/sql_lab/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/superset/views/sql_lab/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/superset/views/sql_lab/schemas.py b/superset/views/sql_lab/schemas.py
new file mode 100644
index 0000000000..399665afc1
--- /dev/null
+++ b/superset/views/sql_lab/schemas.py
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from marshmallow import fields, Schema
+
+
+class SqlJsonPayloadSchema(Schema):
+    database_id = fields.Integer(required=True)
+    sql = fields.String(required=True)
+    client_id = fields.String(allow_none=True)
+    queryLimit = fields.Integer(allow_none=True)
+    sql_editor_id = fields.String(allow_none=True)
+    schema = fields.String(allow_none=True)
+    tab = fields.String(allow_none=True)
+    ctas_method = fields.String(allow_none=True)
+    templateParams = fields.String(allow_none=True)
+    tmp_table_name = fields.String(allow_none=True)
+    select_as_cta = fields.Boolean(allow_none=True)
+    json = fields.Boolean(allow_none=True)
+    runAsync = fields.Boolean(allow_none=True)
+    expand_data = fields.Boolean(allow_none=True)
diff --git a/superset/views/sql_lab.py b/superset/views/sql_lab/views.py
similarity index 99%
rename from superset/views/sql_lab.py
rename to superset/views/sql_lab/views.py
index 1042b8f920..509ff4211a 100644
--- a/superset/views/sql_lab.py
+++ b/superset/views/sql_lab/views.py
@@ -30,8 +30,12 @@ from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
 from superset.superset_typing import FlaskResponse
 from superset.utils import core as utils
 from superset.utils.core import get_user_id
-
-from .base import BaseSupersetView, DeleteMixin, json_success, SupersetModelView
+from superset.views.base import (
+    BaseSupersetView,
+    DeleteMixin,
+    json_success,
+    SupersetModelView,
+)
 
 logger = logging.getLogger(__name__)
 
diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py
index 5c99a1e870..471926d6e2 100644
--- a/tests/integration_tests/core_tests.py
+++ b/tests/integration_tests/core_tests.py
@@ -763,6 +763,52 @@ class TestCore(SupersetTestCase):
             f"/superset/extra_table_metadata/{example_db.id}/birth_names/{schema}/"
         )
 
+    def test_required_params_in_sql_json(self):
+        self.login()
+        client_id = "{}".format(random.getrandbits(64))[:10]
+
+        data = {"client_id": client_id}
+        rv = self.client.post(
+            "/superset/sql_json/",
+            json=data,
+        )
+        failed_resp = {
+            "sql": ["Missing data for required field."],
+            "database_id": ["Missing data for required field."],
+        }
+        resp_data = json.loads(rv.data.decode("utf-8"))
+        self.assertDictEqual(resp_data, failed_resp)
+        self.assertEqual(rv.status_code, 400)
+
+        data = {"sql": "SELECT 1", "client_id": client_id}
+        rv = self.client.post(
+            "/superset/sql_json/",
+            json=data,
+        )
+        failed_resp = {"database_id": ["Missing data for required field."]}
+        resp_data = json.loads(rv.data.decode("utf-8"))
+        self.assertDictEqual(resp_data, failed_resp)
+        self.assertEqual(rv.status_code, 400)
+
+        data = {"database_id": 1, "client_id": client_id}
+        rv = self.client.post(
+            "/superset/sql_json/",
+            json=data,
+        )
+        failed_resp = {"sql": ["Missing data for required field."]}
+        resp_data = json.loads(rv.data.decode("utf-8"))
+        self.assertDictEqual(resp_data, failed_resp)
+        self.assertEqual(rv.status_code, 400)
+
+        data = {"sql": "SELECT 1", "database_id": 1, "client_id": client_id}
+        rv = self.client.post(
+            "/superset/sql_json/",
+            json=data,
+        )
+        resp_data = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(resp_data.get("status"), "success")
+        self.assertEqual(rv.status_code, 200)
+
     def test_templated_sql_json(self):
         if superset.utils.database.get_example_database().backend == "presto":
             # TODO: make it work for presto