You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by hu...@apache.org on 2023/03/18 00:02:34 UTC

[superset] branch master updated: chore: Hugh/migrate estimate query cost to v1 (#23226)

This is an automated email from the ASF dual-hosted git repository.

hugh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 8fa77adf9a chore: Hugh/migrate estimate query cost to v1 (#23226)
8fa77adf9a is described below

commit 8fa77adf9a810d622ea3033e621e3925e503b993
Author: Hugh A. Miles II <hu...@gmail.com>
AuthorDate: Fri Mar 17 18:02:25 2023 -0600

    chore: Hugh/migrate estimate query cost to v1 (#23226)
    
    Co-authored-by: Diego Medina <di...@gmail.com>
---
 UPDATING.md                                       |   1 +
 superset-frontend/src/SqlLab/actions/sqlLab.js    |  20 ++--
 superset-frontend/src/SqlLab/reducers/sqlLab.js   |   2 +-
 superset/security/manager.py                      |   1 +
 superset/sqllab/api.py                            |  56 +++++++++++-
 superset/sqllab/commands/estimate.py              | 106 ++++++++++++++++++++++
 superset/sqllab/schemas.py                        |   9 ++
 superset/views/core.py                            |   1 +
 tests/integration_tests/sql_lab/api_tests.py      |  67 ++++++++++++++
 tests/integration_tests/sql_lab/commands_tests.py |  81 ++++++++++++++++-
 10 files changed, 329 insertions(+), 15 deletions(-)

diff --git a/UPDATING.md b/UPDATING.md
index c29b7182a0..3496742bee 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -24,6 +24,7 @@ assists people when migrating to a new version.
 
 ## Next
 
+- [23226](https://github.com/apache/superset/pull/23226): Migrated endpoint `/superset/estimate_query_cost/<int:database_id>` to `/api/v1/sqllab/estimate/`. Corresponding permissions are `can estimate query cost on Superset` to `can estimate query cost on SQLLab`. Make sure you add/replace the necessary permissions on any custom roles you may have.
 - [22809](https://github.com/apache/superset/pull/22809): Migrated endpoint `/superset/sql_json` and `/superset/results/` to `/api/v1/sqllab/execute/` and `/api/v1/sqllab/results/` respectively. Corresponding permissions are `can sql_json on Superset` to `can execute on SQLLab`, `can results on Superset` to `can results on SQLLab`. Make sure you add/replace the necessary permissions on any custom roles you may have.
 - [22931](https://github.com/apache/superset/pull/22931): Migrated endpoint `/superset/get_or_create_table/` to `/api/v1/dataset/get_or_create/`. Corresponding permissions are `can get or create table on Superset` to `can get or create dataset on Dataset`. Make sure you add/replace the necessary permissions on any custom roles you may have.
 - [22882](https://github.com/apache/superset/pull/22882): Migrated endpoint `/superset/filter/<datasource_type>/<int:datasource_id>/<column>/` to `/api/v1/datasource/<datasource_type>/<datasource_id>/column/<column_name>/values/`. Corresponding permissions are `can filter on Superset` to `can get column values on Datasource`. Make sure you add/replace the necessary permissions on any custom roles you may have.
diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.js b/superset-frontend/src/SqlLab/actions/sqlLab.js
index ab8abe0edc..c27485879b 100644
--- a/superset-frontend/src/SqlLab/actions/sqlLab.js
+++ b/superset-frontend/src/SqlLab/actions/sqlLab.js
@@ -184,18 +184,20 @@ export function estimateQueryCost(queryEditor) {
     const { dbId, schema, sql, selectedText, templateParams } =
       getUpToDateQuery(getState(), queryEditor);
     const requestSql = selectedText || sql;
-    const endpoint =
-      schema === null
-        ? `/superset/estimate_query_cost/${dbId}/`
-        : `/superset/estimate_query_cost/${dbId}/${schema}/`;
+
+    const postPayload = {
+      database_id: dbId,
+      schema,
+      sql: requestSql,
+      template_params: JSON.parse(templateParams || '{}'),
+    };
+
     return Promise.all([
       dispatch({ type: COST_ESTIMATE_STARTED, query: queryEditor }),
       SupersetClient.post({
-        endpoint,
-        postPayload: {
-          sql: requestSql,
-          templateParams: JSON.parse(templateParams || '{}'),
-        },
+        endpoint: '/api/v1/sqllab/estimate/',
+        body: JSON.stringify(postPayload),
+        headers: { 'Content-Type': 'application/json' },
       })
         .then(({ json }) =>
           dispatch({ type: COST_ESTIMATE_RETURNED, query: queryEditor, json }),
diff --git a/superset-frontend/src/SqlLab/reducers/sqlLab.js b/superset-frontend/src/SqlLab/reducers/sqlLab.js
index e3bb196fbc..a110914b81 100644
--- a/superset-frontend/src/SqlLab/reducers/sqlLab.js
+++ b/superset-frontend/src/SqlLab/reducers/sqlLab.js
@@ -335,7 +335,7 @@ export default function sqlLabReducer(state = {}, action) {
           ...state.queryCostEstimates,
           [action.query.id]: {
             completed: true,
-            cost: action.json,
+            cost: action.json.result,
             error: null,
           },
         },
diff --git a/superset/security/manager.py b/superset/security/manager.py
index 5aa5080294..9e7a8bbd4b 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -250,6 +250,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         ("can_export_csv", "Query"),
         ("can_get_results", "SQLLab"),
         ("can_execute_sql_query", "SQLLab"),
+        ("can_estimate_query_cost", "SQLLab"),  # SqlLabRestApi.class_permission_name
         ("can_export_csv", "SQLLab"),
         ("can_sql_json", "Superset"),  # Deprecated permission remove on 3.0.0
         ("can_sqllab_history", "Superset"),
diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py
index 5915601c0d..8801cbc5f9 100644
--- a/superset/sqllab/api.py
+++ b/superset/sqllab/api.py
@@ -19,7 +19,7 @@ from typing import Any, cast, Dict, Optional
 from urllib import parse
 
 import simplejson as json
-from flask import request
+from flask import request, Response
 from flask_appbuilder.api import expose, protect, rison
 from flask_appbuilder.models.sqla.interface import SQLAInterface
 from marshmallow import ValidationError
@@ -32,6 +32,7 @@ from superset.models.sql_lab import Query
 from superset.queries.dao import QueryDAO
 from superset.sql_lab import get_sql_results
 from superset.sqllab.command_status import SqlJsonExecutionStatus
+from superset.sqllab.commands.estimate import QueryEstimationCommand
 from superset.sqllab.commands.execute import CommandResult, ExecuteSqlCommand
 from superset.sqllab.commands.export import SqlResultExportCommand
 from superset.sqllab.commands.results import SqlExecutionResultsCommand
@@ -42,6 +43,7 @@ from superset.sqllab.exceptions import (
 from superset.sqllab.execution_context_convertor import ExecutionContextConvertor
 from superset.sqllab.query_render import SqlQueryRenderImpl
 from superset.sqllab.schemas import (
+    EstimateQueryCostSchema,
     ExecutePayloadSchema,
     QueryExecutionResponseSchema,
     sql_lab_get_results_schema,
@@ -70,6 +72,7 @@ class SqlLabRestApi(BaseSupersetApi):
 
     class_permission_name = "SQLLab"
 
+    estimate_model_schema = EstimateQueryCostSchema()
     execute_model_schema = ExecutePayloadSchema()
 
     apispec_parameter_schemas = {
@@ -77,10 +80,61 @@ class SqlLabRestApi(BaseSupersetApi):
     }
     openapi_spec_tag = "SQL Lab"
     openapi_spec_component_schemas = (
+        EstimateQueryCostSchema,
         ExecutePayloadSchema,
         QueryExecutionResponseSchema,
     )
 
+    @expose("/estimate/", methods=["POST"])
+    @protect()
+    @statsd_metrics
+    @requires_json
+    @event_logger.log_this_with_context(
+        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
+        f".estimate_query_cost",
+        log_to_statsd=False,
+    )
+    def estimate_query_cost(self) -> Response:
+        """Estimates the SQL query execution cost
+        ---
+        post:
+          summary: >-
+            Estimates the SQL query execution cost
+          requestBody:
+            description: SQL query and params
+            required: true
+            content:
+              application/json:
+                schema:
+                  $ref: '#/components/schemas/EstimateQueryCostSchema'
+          responses:
+            200:
+              description: Query estimation result
+              content:
+                application/json:
+                  schema:
+                    type: object
+                    properties:
+                      result:
+                        type: object
+            400:
+              $ref: '#/components/responses/400'
+            401:
+              $ref: '#/components/responses/401'
+            403:
+              $ref: '#/components/responses/403'
+            500:
+              $ref: '#/components/responses/500'
+        """
+        try:
+            model = self.estimate_model_schema.load(request.json)
+        except ValidationError as error:
+            return self.response_400(message=error.messages)
+
+        command = QueryEstimationCommand(model)
+        result = command.run()
+        return self.response(200, result=result)
+
     @expose("/export/<string:client_id>/")
     @protect()
     @statsd_metrics
diff --git a/superset/sqllab/commands/estimate.py b/superset/sqllab/commands/estimate.py
new file mode 100644
index 0000000000..2b8c5814b9
--- /dev/null
+++ b/superset/sqllab/commands/estimate.py
@@ -0,0 +1,106 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from typing import Any, Dict, List
+
+from flask_babel import gettext as __
+
+from superset import app, db
+from superset.commands.base import BaseCommand
+from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from superset.exceptions import SupersetErrorException, SupersetTimeoutException
+from superset.jinja_context import get_template_processor
+from superset.models.core import Database
+from superset.sqllab.schemas import EstimateQueryCostSchema
+from superset.utils import core as utils
+
+config = app.config
+SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = config["SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT"]
+stats_logger = config["STATS_LOGGER"]
+
+logger = logging.getLogger(__name__)
+
+
+class QueryEstimationCommand(BaseCommand):
+    _database_id: int
+    _sql: str
+    _template_params: Dict[str, Any]
+    _schema: str
+    _database: Database
+
+    def __init__(self, params: EstimateQueryCostSchema) -> None:
+        self._database_id = params.get("database_id")
+        self._sql = params.get("sql", "")
+        self._template_params = params.get("template_params", {})
+        self._schema = params.get("schema", "")
+
+    def validate(self) -> None:
+        self._database = db.session.query(Database).get(self._database_id)
+        if not self._database:
+            raise SupersetErrorException(
+                SupersetError(
+                    message=__("The database could not be found"),
+                    error_type=SupersetErrorType.RESULTS_BACKEND_ERROR,
+                    level=ErrorLevel.ERROR,
+                ),
+                status=404,
+            )
+
+    def run(
+        self,
+    ) -> List[Dict[str, Any]]:
+        self.validate()
+
+        sql = self._sql
+        if self._template_params:
+            template_processor = get_template_processor(self._database)
+            sql = template_processor.process_template(sql, **self._template_params)
+
+        timeout = SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT
+        timeout_msg = f"The estimation exceeded the {timeout} seconds timeout."
+        try:
+            with utils.timeout(seconds=timeout, error_message=timeout_msg):
+                cost = self._database.db_engine_spec.estimate_query_cost(
+                    self._database, self._schema, sql, utils.QuerySource.SQL_LAB
+                )
+        except SupersetTimeoutException as ex:
+            logger.exception(ex)
+            raise SupersetErrorException(
+                SupersetError(
+                    message=__(
+                        "The query estimation was killed after %(sqllab_timeout)s "
+                        "seconds. It might be too complex, or the database might be "
+                        "under heavy load.",
+                        sqllab_timeout=SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT,
+                    ),
+                    error_type=SupersetErrorType.SQLLAB_TIMEOUT_ERROR,
+                    level=ErrorLevel.ERROR,
+                ),
+                status=500,
+            ) from ex
+
+        spec = self._database.db_engine_spec
+        query_cost_formatters: Dict[str, Any] = app.config[
+            "QUERY_COST_FORMATTERS_BY_ENGINE"
+        ]
+        query_cost_formatter = query_cost_formatters.get(
+            spec.engine, spec.query_cost_formatter
+        )
+        cost = query_cost_formatter(cost)
+        return cost
diff --git a/superset/sqllab/schemas.py b/superset/sqllab/schemas.py
index f238fda5c9..134b9ea7bb 100644
--- a/superset/sqllab/schemas.py
+++ b/superset/sqllab/schemas.py
@@ -25,6 +25,15 @@ sql_lab_get_results_schema = {
 }
 
 
+class EstimateQueryCostSchema(Schema):
+    database_id = fields.Integer(required=True, description="The database id")
+    sql = fields.String(required=True, description="The SQL query to estimate")
+    template_params = fields.Dict(
+        keys=fields.String(), description="The SQL query template params"
+    )
+    schema = fields.String(allow_none=True, description="The database schema")
+
+
 class ExecutePayloadSchema(Schema):
     database_id = fields.Integer(required=True)
     sql = fields.String(required=True)
diff --git a/superset/views/core.py b/superset/views/core.py
index 3bd0ec651e..44f1b78af0 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -2062,6 +2062,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
     @expose("/estimate_query_cost/<int:database_id>/", methods=["POST"])
     @expose("/estimate_query_cost/<int:database_id>/<schema>/", methods=["POST"])
     @event_logger.log_this
+    @deprecated()
     def estimate_query_cost(  # pylint: disable=no-self-use
         self, database_id: int, schema: Optional[str] = None
     ) -> FlaskResponse:
diff --git a/tests/integration_tests/sql_lab/api_tests.py b/tests/integration_tests/sql_lab/api_tests.py
index 93beb380f0..a57d24c3e4 100644
--- a/tests/integration_tests/sql_lab/api_tests.py
+++ b/tests/integration_tests/sql_lab/api_tests.py
@@ -42,6 +42,73 @@ QUERIES_FIXTURE_COUNT = 10
 
 
 class TestSqlLabApi(SupersetTestCase):
+    def test_estimate_required_params(self):
+        self.login()
+
+        rv = self.client.post(
+            "/api/v1/sqllab/estimate/",
+            json={},
+        )
+        failed_resp = {
+            "message": {
+                "sql": ["Missing data for required field."],
+                "database_id": ["Missing data for required field."],
+            }
+        }
+        resp_data = json.loads(rv.data.decode("utf-8"))
+        self.assertDictEqual(resp_data, failed_resp)
+        self.assertEqual(rv.status_code, 400)
+
+        data = {"sql": "SELECT 1"}
+        rv = self.client.post(
+            "/api/v1/sqllab/estimate/",
+            json=data,
+        )
+        failed_resp = {"message": {"database_id": ["Missing data for required field."]}}
+        resp_data = json.loads(rv.data.decode("utf-8"))
+        self.assertDictEqual(resp_data, failed_resp)
+        self.assertEqual(rv.status_code, 400)
+
+        data = {"database_id": 1}
+        rv = self.client.post(
+            "/api/v1/sqllab/estimate/",
+            json=data,
+        )
+        failed_resp = {"message": {"sql": ["Missing data for required field."]}}
+        resp_data = json.loads(rv.data.decode("utf-8"))
+        self.assertDictEqual(resp_data, failed_resp)
+        self.assertEqual(rv.status_code, 400)
+
+    def test_estimate_valid_request(self):
+        self.login()
+
+        formatter_response = [
+            {
+                "value": 100,
+            }
+        ]
+
+        db_mock = mock.Mock()
+        db_mock.db_engine_spec = mock.Mock()
+        db_mock.db_engine_spec.estimate_query_cost = mock.Mock(return_value=100)
+        db_mock.db_engine_spec.query_cost_formatter = mock.Mock(
+            return_value=formatter_response
+        )
+
+        with mock.patch("superset.sqllab.commands.estimate.db") as mock_superset_db:
+            mock_superset_db.session.query().get.return_value = db_mock
+
+            data = {"database_id": 1, "sql": "SELECT 1"}
+            rv = self.client.post(
+                "/api/v1/sqllab/estimate/",
+                json=data,
+            )
+
+        success_resp = {"result": formatter_response}
+        resp_data = json.loads(rv.data.decode("utf-8"))
+        self.assertDictEqual(resp_data, success_resp)
+        self.assertEqual(rv.status_code, 200)
+
     @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False)
     def test_execute_required_params(self):
         self.login()
diff --git a/tests/integration_tests/sql_lab/commands_tests.py b/tests/integration_tests/sql_lab/commands_tests.py
index cf0aebf001..3d505ee2f5 100644
--- a/tests/integration_tests/sql_lab/commands_tests.py
+++ b/tests/integration_tests/sql_lab/commands_tests.py
@@ -19,25 +19,98 @@ from unittest.mock import Mock, patch
 
 import pandas as pd
 import pytest
+from flask_babel import gettext as __
 
-from superset import db, sql_lab
+from superset import app, db, sql_lab
 from superset.common.db_query_status import QueryStatus
-from superset.errors import ErrorLevel, SupersetErrorType
+from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import (
     SerializationError,
-    SupersetError,
     SupersetErrorException,
     SupersetSecurityException,
+    SupersetTimeoutException,
 )
 from superset.models.core import Database
 from superset.models.sql_lab import Query
-from superset.sqllab.commands import export, results
+from superset.sqllab.commands import estimate, export, results
 from superset.sqllab.limiting_factor import LimitingFactor
+from superset.sqllab.schemas import EstimateQueryCostSchema
 from superset.utils import core as utils
 from superset.utils.database import get_example_database
 from tests.integration_tests.base_tests import SupersetTestCase
 
 
+class TestQueryEstimationCommand(SupersetTestCase):
+    def test_validation_no_database(self) -> None:
+        params = {"database_id": 1, "sql": "SELECT 1"}
+        schema = EstimateQueryCostSchema()
+        data: EstimateQueryCostSchema = schema.dump(params)
+        command = estimate.QueryEstimationCommand(data)
+
+        with mock.patch("superset.sqllab.commands.estimate.db") as mock_superset_db:
+            mock_superset_db.session.query().get.return_value = None
+            with pytest.raises(SupersetErrorException) as ex_info:
+                command.validate()
+            assert (
+                ex_info.value.error.error_type
+                == SupersetErrorType.RESULTS_BACKEND_ERROR
+            )
+
+    @patch("superset.tasks.scheduler.is_feature_enabled")
+    def test_run_timeout(self, is_feature_enabled) -> None:
+        params = {"database_id": 1, "sql": "SELECT 1", "template_params": {"temp": 123}}
+        schema = EstimateQueryCostSchema()
+        data: EstimateQueryCostSchema = schema.dump(params)
+        command = estimate.QueryEstimationCommand(data)
+
+        db_mock = mock.Mock()
+        db_mock.db_engine_spec = mock.Mock()
+        db_mock.db_engine_spec.estimate_query_cost = mock.Mock(
+            side_effect=SupersetTimeoutException(
+                error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
+                message=(
+                    "Please check your connection details and database settings, "
+                    "and ensure that your database is accepting connections, "
+                    "then try connecting again."
+                ),
+                level=ErrorLevel.ERROR,
+            )
+        )
+        db_mock.db_engine_spec.query_cost_formatter = mock.Mock(return_value=None)
+        is_feature_enabled.return_value = False
+
+        with mock.patch("superset.sqllab.commands.estimate.db") as mock_superset_db:
+            mock_superset_db.session.query().get.return_value = db_mock
+            with pytest.raises(SupersetErrorException) as ex_info:
+                command.run()
+            assert (
+                ex_info.value.error.error_type == SupersetErrorType.SQLLAB_TIMEOUT_ERROR
+            )
+            assert ex_info.value.error.message == __(
+                "The query estimation was killed after %(sqllab_timeout)s seconds. It might "
+                "be too complex, or the database might be under heavy load.",
+                sqllab_timeout=app.config["SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT"],
+            )
+
+    def test_run_success(self) -> None:
+        params = {"database_id": 1, "sql": "SELECT 1"}
+        schema = EstimateQueryCostSchema()
+        data: EstimateQueryCostSchema = schema.dump(params)
+        command = estimate.QueryEstimationCommand(data)
+
+        payload = {"value": 100}
+
+        db_mock = mock.Mock()
+        db_mock.db_engine_spec = mock.Mock()
+        db_mock.db_engine_spec.estimate_query_cost = mock.Mock(return_value=100)
+        db_mock.db_engine_spec.query_cost_formatter = mock.Mock(return_value=payload)
+
+        with mock.patch("superset.sqllab.commands.estimate.db") as mock_superset_db:
+            mock_superset_db.session.query().get.return_value = db_mock
+            result = command.run()
+            assert result == payload
+
+
 class TestSqlResultExportCommand(SupersetTestCase):
     @pytest.fixture()
     def create_database_and_query(self):