You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by di...@apache.org on 2023/01/30 16:31:38 UTC

[superset] 01/01: chore: Migrate /superset/estimate_query_cost/<database_id>/<schema>/ to API v1

This is an automated email from the ASF dual-hosted git repository.

diegomedina24 pushed a commit to branch dm/migrate-estimate_query_cost-to-v1
in repository https://gitbox.apache.org/repos/asf/superset.git

commit f929d568f8ad535822ef6ff39bdfd9d6a394795a
Author: Diego Medina <di...@gmail.com>
AuthorDate: Mon Jan 30 13:12:56 2023 -0300

    chore: Migrate /superset/estimate_query_cost/<database_id>/<schema>/ to API v1
---
 superset-frontend/src/SqlLab/actions/sqlLab.js    |  20 ++--
 superset-frontend/src/SqlLab/reducers/sqlLab.js   |   2 +-
 superset/constants.py                             |   1 +
 superset/sqllab/api.py                            |  59 +++++++++++-
 superset/sqllab/commands/estimate.py              | 106 ++++++++++++++++++++++
 superset/sqllab/schemas.py                        |   7 ++
 tests/integration_tests/sql_lab/api_tests.py      |  67 ++++++++++++++
 tests/integration_tests/sql_lab/commands_tests.py |  78 +++++++++++++++-
 8 files changed, 325 insertions(+), 15 deletions(-)

diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.js b/superset-frontend/src/SqlLab/actions/sqlLab.js
index a331e462d7..559c9866d5 100644
--- a/superset-frontend/src/SqlLab/actions/sqlLab.js
+++ b/superset-frontend/src/SqlLab/actions/sqlLab.js
@@ -184,18 +184,20 @@ export function estimateQueryCost(queryEditor) {
     const { dbId, schema, sql, selectedText, templateParams } =
       getUpToDateQuery(getState(), queryEditor);
     const requestSql = selectedText || sql;
-    const endpoint =
-      schema === null
-        ? `/superset/estimate_query_cost/${dbId}/`
-        : `/superset/estimate_query_cost/${dbId}/${schema}/`;
+
+    const postPayload = {
+      database_id: dbId,
+      schema,
+      sql: requestSql,
+      template_params: JSON.parse(templateParams || '{}'),
+    };
+
     return Promise.all([
       dispatch({ type: COST_ESTIMATE_STARTED, query: queryEditor }),
       SupersetClient.post({
-        endpoint,
-        postPayload: {
-          sql: requestSql,
-          templateParams: JSON.parse(templateParams || '{}'),
-        },
+        endpoint: '/api/v1/sqllab/estimate/',
+        body: JSON.stringify(postPayload),
+        headers: { 'Content-Type': 'application/json' },
       })
         .then(({ json }) =>
           dispatch({ type: COST_ESTIMATE_RETURNED, query: queryEditor, json }),
diff --git a/superset-frontend/src/SqlLab/reducers/sqlLab.js b/superset-frontend/src/SqlLab/reducers/sqlLab.js
index e3bb196fbc..a110914b81 100644
--- a/superset-frontend/src/SqlLab/reducers/sqlLab.js
+++ b/superset-frontend/src/SqlLab/reducers/sqlLab.js
@@ -335,7 +335,7 @@ export default function sqlLabReducer(state = {}, action) {
           ...state.queryCostEstimates,
           [action.query.id]: {
             completed: true,
-            cost: action.json,
+            cost: action.json.result,
             error: null,
           },
         },
diff --git a/superset/constants.py b/superset/constants.py
index 3d2c5c470c..d4c3657600 100644
--- a/superset/constants.py
+++ b/superset/constants.py
@@ -142,6 +142,7 @@ MODEL_API_RW_METHOD_PERMISSION_MAP = {
     "delete_ssh_tunnel": "write",
     "get_updated_since": "read",
     "stop_query": "read",
+    "estimate_query_cost": "read",
 }
 
 EXTRA_FORM_DATA_APPEND_KEYS = {
diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py
index 283c3ab638..aa8a0f7c4a 100644
--- a/superset/sqllab/api.py
+++ b/superset/sqllab/api.py
@@ -18,12 +18,13 @@ import logging
 from typing import Any, cast, Dict, Optional
 
 import simplejson as json
-from flask import request
+from flask import request, Response
 from flask_appbuilder.api import expose, protect, rison
 from flask_appbuilder.models.sqla.interface import SQLAInterface
 from marshmallow import ValidationError
 
 from superset import app, is_feature_enabled
+from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP
 from superset.databases.dao import DatabaseDAO
 from superset.extensions import event_logger
 from superset.jinja_context import get_template_processor
@@ -31,6 +32,7 @@ from superset.models.sql_lab import Query
 from superset.queries.dao import QueryDAO
 from superset.sql_lab import get_sql_results
 from superset.sqllab.command_status import SqlJsonExecutionStatus
+from superset.sqllab.commands.estimate import QueryEstimationCommand
 from superset.sqllab.commands.execute import CommandResult, ExecuteSqlCommand
 from superset.sqllab.commands.results import SqlExecutionResultsCommand
 from superset.sqllab.exceptions import (
@@ -40,6 +42,7 @@ from superset.sqllab.exceptions import (
 from superset.sqllab.execution_context_convertor import ExecutionContextConvertor
 from superset.sqllab.query_render import SqlQueryRenderImpl
 from superset.sqllab.schemas import (
+    EstimateQueryCostSchema,
     ExecutePayloadSchema,
     QueryExecutionResponseSchema,
     sql_lab_get_results_schema,
@@ -68,6 +71,8 @@ class SqlLabRestApi(BaseSupersetApi):
 
     class_permission_name = "Query"
 
+    method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
+    estimate_model_schema = EstimateQueryCostSchema()
     execute_model_schema = ExecutePayloadSchema()
 
     apispec_parameter_schemas = {
@@ -79,6 +84,58 @@ class SqlLabRestApi(BaseSupersetApi):
         QueryExecutionResponseSchema,
     )
 
+    @expose("/estimate/", methods=["POST"])
+    @protect()
+    @statsd_metrics
+    @event_logger.log_this_with_context(
+        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
+        f".estimate_query_cost",
+        log_to_statsd=False,
+    )
+    @requires_json
+    def estimate_query_cost(self, **kwargs: Any) -> Response:
+        """Estimates the SQL query execution cost
+        ---
+        post:
+          summary: >-
+            Estimates the SQL query execution cost
+          requestBody:
+            description: SQL query and params
+            required: true
+            content:
+              application/json:
+                schema:
+                  $ref: '#/components/schemas/EstimateQueryCostSchema'
+          responses:
+            200:
+              description: Query estimation result
+              content:
+                application/json:
+                  schema:
+                    type: object
+                    properties:
+                      result:
+                        type: object
+            400:
+              $ref: '#/components/responses/400'
+            401:
+              $ref: '#/components/responses/401'
+            403:
+              $ref: '#/components/responses/403'
+            404:
+              $ref: '#/components/responses/404'
+            500:
+              $ref: '#/components/responses/500'
+        """
+        try:
+            model = self.estimate_model_schema.load(request.json)
+        except ValidationError as error:
+            return self.response_400(message=error.messages)
+
+        command = QueryEstimationCommand(model)
+        result = command.run()
+        return self.response(200, result=result)
+
     @expose("/results/")
     @protect()
     @statsd_metrics
diff --git a/superset/sqllab/commands/estimate.py b/superset/sqllab/commands/estimate.py
new file mode 100644
index 0000000000..ee0a084ac6
--- /dev/null
+++ b/superset/sqllab/commands/estimate.py
@@ -0,0 +1,106 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=too-few-public-methods, too-many-arguments
+from __future__ import annotations
+
+import logging
+from typing import Any, Dict, List
+
+from flask_babel import gettext as __, lazy_gettext as _
+
+from superset import app, db
+from superset.commands.base import BaseCommand
+from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from superset.exceptions import SupersetErrorException, SupersetTimeoutException
+from superset.jinja_context import get_template_processor
+from superset.models.core import Database
+from superset.sqllab.schemas import EstimateQueryCostSchema
+from superset.utils import core as utils
+
+config = app.config
+SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = config["SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT"]
+stats_logger = config["STATS_LOGGER"]
+
+logger = logging.getLogger(__name__)
+
+
+class QueryEstimationCommand(BaseCommand):
+    _database_id: int
+    _sql: str
+    _template_params: Dict[str, Any]
+    _schema: str
+    _database: Database
+
+    def __init__(self, params: EstimateQueryCostSchema) -> None:
+        self._database_id = params.get("database_id")
+        self._sql = params.get("sql", "")
+        self._template_params = params.get("template_params", {})
+        self._schema = params.get("schema", "")
+
+    def validate(self) -> None:
+        self._database = db.session.query(Database).get(self._database_id)
+        if not self._database:
+            raise SupersetErrorException(
+                SupersetError(
+                    message=__("The database could not be found"),
+                    error_type=SupersetErrorType.RESULTS_BACKEND_ERROR,
+                    level=ErrorLevel.ERROR,
+                ),
+                status=404,
+            )
+
+    def run(
+        self,
+    ) -> List[Dict[str, Any]]:
+        self.validate()
+
+        sql = self._sql
+        if self._template_params:
+            template_processor = get_template_processor(self._database)
+            sql = template_processor.process_template(sql, **self._template_params)
+
+        timeout = SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT
+        timeout_msg = f"The estimation exceeded the {timeout} seconds timeout."
+        try:
+            with utils.timeout(seconds=timeout, error_message=timeout_msg):
+                cost = self._database.db_engine_spec.estimate_query_cost(
+                    self._database, self._schema, sql, utils.QuerySource.SQL_LAB
+                )
+        except SupersetTimeoutException as ex:
+            logger.exception(ex)
+            raise SupersetErrorException(
+                SupersetError(
+                    message=__(
+                        "The query estimation was killed after %(sqllab_timeout)s seconds. It might "
+                        "be too complex, or the database might be under heavy load.",
+                        sqllab_timeout=SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT,
+                    ),
+                    error_type=SupersetErrorType.SQLLAB_TIMEOUT_ERROR,
+                    level=ErrorLevel.ERROR,
+                ),
+                status=500,
+            )
+
+        spec = self._database.db_engine_spec
+        query_cost_formatters: Dict[str, Any] = app.config[
+            "QUERY_COST_FORMATTERS_BY_ENGINE"
+        ]
+        query_cost_formatter = query_cost_formatters.get(
+            spec.engine, spec.query_cost_formatter
+        )
+        cost = query_cost_formatter(cost)
+        return cost
diff --git a/superset/sqllab/schemas.py b/superset/sqllab/schemas.py
index f238fda5c9..d146558c56 100644
--- a/superset/sqllab/schemas.py
+++ b/superset/sqllab/schemas.py
@@ -25,6 +25,13 @@ sql_lab_get_results_schema = {
 }
 
 
+class EstimateQueryCostSchema(Schema):
+    database_id = fields.Integer(required=True)
+    sql = fields.String(required=True)
+    template_params = fields.Dict(keys=fields.String())
+    schema = fields.String(allow_none=True)
+
+
 class ExecutePayloadSchema(Schema):
     database_id = fields.Integer(required=True)
     sql = fields.String(required=True)
diff --git a/tests/integration_tests/sql_lab/api_tests.py b/tests/integration_tests/sql_lab/api_tests.py
index 4c2080ad4c..8c34e0a79b 100644
--- a/tests/integration_tests/sql_lab/api_tests.py
+++ b/tests/integration_tests/sql_lab/api_tests.py
@@ -39,6 +39,73 @@ QUERIES_FIXTURE_COUNT = 10
 
 
 class TestSqlLabApi(SupersetTestCase):
+    def test_estimate_required_params(self):
+        self.login()
+
+        rv = self.client.post(
+            "/api/v1/sqllab/estimate/",
+            json={},
+        )
+        failed_resp = {
+            "message": {
+                "sql": ["Missing data for required field."],
+                "database_id": ["Missing data for required field."],
+            }
+        }
+        resp_data = json.loads(rv.data.decode("utf-8"))
+        self.assertDictEqual(resp_data, failed_resp)
+        self.assertEqual(rv.status_code, 400)
+
+        data = {"sql": "SELECT 1"}
+        rv = self.client.post(
+            "/api/v1/sqllab/estimate/",
+            json=data,
+        )
+        failed_resp = {"message": {"database_id": ["Missing data for required field."]}}
+        resp_data = json.loads(rv.data.decode("utf-8"))
+        self.assertDictEqual(resp_data, failed_resp)
+        self.assertEqual(rv.status_code, 400)
+
+        data = {"database_id": 1}
+        rv = self.client.post(
+            "/api/v1/sqllab/estimate/",
+            json=data,
+        )
+        failed_resp = {"message": {"sql": ["Missing data for required field."]}}
+        resp_data = json.loads(rv.data.decode("utf-8"))
+        self.assertDictEqual(resp_data, failed_resp)
+        self.assertEqual(rv.status_code, 400)
+
+    def test_estimate_valid_request(self):
+        self.login()
+
+        formatter_response = [
+            {
+                "value": 100,
+            }
+        ]
+
+        db_mock = mock.Mock()
+        db_mock.db_engine_spec = mock.Mock()
+        db_mock.db_engine_spec.estimate_query_cost = mock.Mock(return_value=100)
+        db_mock.db_engine_spec.query_cost_formatter = mock.Mock(
+            return_value=formatter_response
+        )
+
+        with mock.patch("superset.sqllab.commands.estimate.db") as mock_superset_db:
+            mock_superset_db.session.query().get.return_value = db_mock
+
+            data = {"database_id": 1, "sql": "SELECT 1"}
+            rv = self.client.post(
+                "/api/v1/sqllab/estimate/",
+                json=data,
+            )
+
+        success_resp = {"result": formatter_response}
+        resp_data = json.loads(rv.data.decode("utf-8"))
+        self.assertDictEqual(resp_data, success_resp)
+        self.assertEqual(rv.status_code, 200)
+
     @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False)
     def test_execute_required_params(self):
         self.login()
diff --git a/tests/integration_tests/sql_lab/commands_tests.py b/tests/integration_tests/sql_lab/commands_tests.py
index 74c1fe7082..84e2272947 100644
--- a/tests/integration_tests/sql_lab/commands_tests.py
+++ b/tests/integration_tests/sql_lab/commands_tests.py
@@ -18,18 +18,88 @@ from unittest import mock, skip
 from unittest.mock import patch
 
 import pytest
+from flask_babel import gettext as __
 
-from superset import db, sql_lab
+from superset import app, db, sql_lab
 from superset.common.db_query_status import QueryStatus
-from superset.errors import SupersetErrorType
-from superset.exceptions import SerializationError, SupersetErrorException
+from superset.errors import ErrorLevel, SupersetErrorType
+from superset.exceptions import (
+    SerializationError,
+    SupersetErrorException,
+    SupersetTimeoutException,
+)
 from superset.models.core import Database
 from superset.models.sql_lab import Query
-from superset.sqllab.commands import results
+from superset.sqllab.commands import estimate, results
 from superset.utils import core as utils
 from tests.integration_tests.base_tests import SupersetTestCase
 
 
+class TestQueryEstimationCommand(SupersetTestCase):
+    def test_validation_no_database(self) -> None:
+        params = {"database_id": 1, "sql": "SELECT 1"}
+        command = estimate.QueryEstimationCommand(params)
+
+        with mock.patch("superset.sqllab.commands.estimate.db") as mock_superset_db:
+            mock_superset_db.session.query().get.return_value = None
+            with pytest.raises(SupersetErrorException) as ex_info:
+                command.validate()
+            assert (
+                ex_info.value.error.error_type
+                == SupersetErrorType.RESULTS_BACKEND_ERROR
+            )
+
+    @patch("superset.tasks.scheduler.is_feature_enabled")
+    def test_run_timeout(self, is_feature_enabled) -> None:
+        params = {"database_id": 1, "sql": "SELECT 1", "template_params": {"temp": 123}}
+        command = estimate.QueryEstimationCommand(params)
+
+        db_mock = mock.Mock()
+        db_mock.db_engine_spec = mock.Mock()
+        db_mock.db_engine_spec.estimate_query_cost = mock.Mock(
+            side_effect=SupersetTimeoutException(
+                error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
+                message=(
+                    "Please check your connection details and database settings, "
+                    "and ensure that your database is accepting connections, "
+                    "then try connecting again."
+                ),
+                level=ErrorLevel.ERROR,
+            )
+        )
+        db_mock.db_engine_spec.query_cost_formatter = mock.Mock(return_value=None)
+        is_feature_enabled.return_value = False
+
+        with mock.patch("superset.sqllab.commands.estimate.db") as mock_superset_db:
+            mock_superset_db.session.query().get.return_value = db_mock
+            with pytest.raises(SupersetErrorException) as ex_info:
+                command.run()
+            assert (
+                ex_info.value.error.error_type == SupersetErrorType.SQLLAB_TIMEOUT_ERROR
+            )
+            assert ex_info.value.error.message == __(
+                "The query estimation was killed after %(sqllab_timeout)s seconds. It might "
+                "be too complex, or the database might be under heavy load.",
+                sqllab_timeout=app.config["SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT"],
+            )
+
+    def test_run_success(self) -> None:
+        params = {"database_id": 1, "sql": "SELECT 1"}
+        command = estimate.QueryEstimationCommand(params)
+
+        payload = {"value": 100}
+
+        db_mock = mock.Mock()
+        db_mock.db_engine_spec = mock.Mock()
+        db_mock.db_engine_spec.estimate_query_cost = mock.Mock(return_value=100)
+        db_mock.db_engine_spec.query_cost_formatter = mock.Mock(return_value=payload)
+
+        with mock.patch("superset.sqllab.commands.estimate.db") as mock_superset_db:
+            mock_superset_db.session.query().get.return_value = db_mock
+            result = command.run()
+            assert result == payload
+
+
 class TestSqlExecutionResultsCommand(SupersetTestCase):
     @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False)
     def test_validation_no_results_backend(self) -> None: