You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text rendering.
Posted to commits@superset.apache.org by vi...@apache.org on 2022/12/24 04:31:57 UTC

[superset] branch master updated: feat(trino): support early cancellation of queries (#22498)

This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new b6d39d194c feat(trino): support early cancellation of queries (#22498)
b6d39d194c is described below

commit b6d39d194c90dbbf0050bb3d32d2e1a513dfc0a6
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Sat Dec 24 06:31:46 2022 +0200

    feat(trino): support early cancellation of queries (#22498)
---
 superset-frontend/src/SqlLab/actions/sqlLab.js     |  6 ++-
 .../src/SqlLab/actions/sqlLab.test.js              | 31 ++++++++++++
 .../src/SqlLab/components/ResultSet/index.tsx      | 26 +++++-----
 .../SqlLab/components/SqlEditorTabHeader/index.tsx |  4 +-
 superset-frontend/src/SqlLab/reducers/sqlLab.js    | 29 ++++++-----
 .../src/views/CRUD/data/query/QueryList.test.tsx   |  3 +-
 .../src/views/CRUD/data/query/QueryList.tsx        | 34 ++++++++++---
 .../CRUD/data/query/QueryPreviewModal.test.tsx     |  3 +-
 superset-frontend/src/views/CRUD/types.ts          | 10 +---
 superset/constants.py                              |  3 ++
 superset/db_engine_specs/base.py                   | 52 ++++++++++++-------
 superset/db_engine_specs/hive.py                   |  2 +-
 superset/db_engine_specs/presto.py                 |  2 +-
 superset/db_engine_specs/trino.py                  | 23 ++++++++-
 superset/sql_lab.py                                | 18 +++++--
 tests/unit_tests/db_engine_specs/test_trino.py     | 59 ++++++++++++++++++++++
 16 files changed, 231 insertions(+), 74 deletions(-)

diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.js b/superset-frontend/src/SqlLab/actions/sqlLab.js
index a58630cda1..12487d1a94 100644
--- a/superset-frontend/src/SqlLab/actions/sqlLab.js
+++ b/superset-frontend/src/SqlLab/actions/sqlLab.js
@@ -17,7 +17,7 @@
  * under the License.
  */
 import shortid from 'shortid';
-import { t, SupersetClient } from '@superset-ui/core';
+import { SupersetClient, t } from '@superset-ui/core';
 import invert from 'lodash/invert';
 import mapKeys from 'lodash/mapKeys';
 import { isFeatureEnabled, FeatureFlag } from 'src/featureFlags';
@@ -229,11 +229,13 @@ export function startQuery(query) {
 
 export function querySuccess(query, results) {
   return function (dispatch) {
+    const sqlEditorId = results?.query?.sqlEditorId;
     const sync =
+      sqlEditorId &&
       !query.isDataPreview &&
       isFeatureEnabled(FeatureFlag.SQLLAB_BACKEND_PERSISTENCE)
         ? SupersetClient.put({
-            endpoint: encodeURI(`/tabstateview/${results.query.sqlEditorId}`),
+            endpoint: encodeURI(`/tabstateview/${sqlEditorId}`),
             postPayload: { latest_query_id: query.id },
           })
         : Promise.resolve();
diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.test.js b/superset-frontend/src/SqlLab/actions/sqlLab.test.js
index 7792f1da8a..acc79031ed 100644
--- a/superset-frontend/src/SqlLab/actions/sqlLab.test.js
+++ b/superset-frontend/src/SqlLab/actions/sqlLab.test.js
@@ -30,6 +30,7 @@ import {
   initialState,
   queryId,
 } from 'src/SqlLab/fixtures';
+import { QueryState } from '@superset-ui/core';
 
 const middlewares = [thunk];
 const mockStore = configureMockStore(middlewares);
@@ -502,6 +503,7 @@ describe('async actions', () => {
         const results = {
           data: mockBigNumber,
           query: { sqlEditorId: 'abcd' },
+          status: QueryState.SUCCESS,
           query_id: 'efgh',
         };
         fetchMock.get(fetchQueryEndpoint, JSON.stringify(results), {
@@ -525,6 +527,35 @@ describe('async actions', () => {
           expect(fetchMock.calls(updateTabStateEndpoint)).toHaveLength(1);
         });
       });
+
+      it("doesn't update the tab state in the backend on stoppped query", () => {
+        expect.assertions(2);
+
+        const results = {
+          status: QueryState.STOPPED,
+          query_id: 'efgh',
+        };
+        fetchMock.get(fetchQueryEndpoint, JSON.stringify(results), {
+          overwriteRoutes: true,
+        });
+        const store = mockStore({});
+        const expectedActions = [
+          {
+            type: actions.REQUEST_QUERY_RESULTS,
+            query,
+          },
+          // missing below
+          {
+            type: actions.QUERY_SUCCESS,
+            query,
+            results,
+          },
+        ];
+        return store.dispatch(actions.fetchQueryResults(query)).then(() => {
+          expect(store.getActions()).toEqual(expectedActions);
+          expect(fetchMock.calls(updateTabStateEndpoint)).toHaveLength(0);
+        });
+      });
     });
 
     describe('addQueryEditor', () => {
diff --git a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx
index 0d61d8ddab..10cdd8a39e 100644
--- a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx
+++ b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx
@@ -16,13 +16,13 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-import React, { useState, useEffect, useCallback } from 'react';
+import React, { useCallback, useEffect, useState } from 'react';
 import { useDispatch } from 'react-redux';
 import ButtonGroup from 'src/components/ButtonGroup';
 import Alert from 'src/components/Alert';
 import Button from 'src/components/Button';
 import shortid from 'shortid';
-import { styled, t, QueryResponse } from '@superset-ui/core';
+import { QueryResponse, QueryState, styled, t } from '@superset-ui/core';
 import { usePrevious } from 'src/hooks/usePrevious';
 import ErrorMessageWithStackTrace from 'src/components/ErrorMessage/ErrorMessageWithStackTrace';
 import {
@@ -43,9 +43,9 @@ import CopyToClipboard from 'src/components/CopyToClipboard';
 import { addDangerToast } from 'src/components/MessageToasts/actions';
 import { prepareCopyToClipboardTabularData } from 'src/utils/common';
 import {
-  CtasEnum,
-  clearQueryResults,
   addQueryEditor,
+  clearQueryResults,
+  CtasEnum,
   fetchQueryResults,
   reFetchQueryResults,
   reRunQuery,
@@ -387,8 +387,8 @@ const ResultSet = ({
   let trackingUrl;
   if (
     query.trackingUrl &&
-    query.state !== 'success' &&
-    query.state !== 'fetching'
+    query.state !== QueryState.SUCCESS &&
+    query.state !== QueryState.FETCHING
   ) {
     trackingUrl = (
       <Button
@@ -397,7 +397,9 @@ const ResultSet = ({
         href={query.trackingUrl}
         target="_blank"
       >
-        {query.state === 'running' ? t('Track job') : t('See query details')}
+        {query.state === QueryState.RUNNING
+          ? t('Track job')
+          : t('See query details')}
       </Button>
     );
   }
@@ -406,11 +408,11 @@ const ResultSet = ({
     sql = <HighlightedSql sql={query.sql} />;
   }
 
-  if (query.state === 'stopped') {
+  if (query.state === QueryState.STOPPED) {
     return <Alert type="warning" message={t('Query was stopped')} />;
   }
 
-  if (query.state === 'failed') {
+  if (query.state === QueryState.FAILED) {
     return (
       <ResultlessStyles>
         <ErrorMessageWithStackTrace
@@ -426,7 +428,7 @@ const ResultSet = ({
     );
   }
 
-  if (query.state === 'success' && query.ctas) {
+  if (query.state === QueryState.SUCCESS && query.ctas) {
     const { tempSchema, tempTable } = query;
     let object = 'Table';
     if (query.ctas_method === CtasEnum.VIEW) {
@@ -465,7 +467,7 @@ const ResultSet = ({
     );
   }
 
-  if (query.state === 'success' && query.results) {
+  if (query.state === QueryState.SUCCESS && query.results) {
     const { results } = query;
     // Accounts for offset needed for height of ResultSetRowsReturned component if !limitReached
     const rowMessageHeight = !limitReached ? 32 : 0;
@@ -508,7 +510,7 @@ const ResultSet = ({
     }
   }
 
-  if (query.cached || (query.state === 'success' && !query.results)) {
+  if (query.cached || (query.state === QueryState.SUCCESS && !query.results)) {
     if (query.isDataPreview) {
       return (
         <Button
diff --git a/superset-frontend/src/SqlLab/components/SqlEditorTabHeader/index.tsx b/superset-frontend/src/SqlLab/components/SqlEditorTabHeader/index.tsx
index debacbb0d3..8e4372d109 100644
--- a/superset-frontend/src/SqlLab/components/SqlEditorTabHeader/index.tsx
+++ b/superset-frontend/src/SqlLab/components/SqlEditorTabHeader/index.tsx
@@ -53,7 +53,7 @@ const SqlEditorTabHeader: React.FC<Props> = ({ queryEditor }) => {
     }),
     shallowEqual,
   );
-  const queryStatus = useSelector<SqlLabRootState, QueryState>(
+  const queryState = useSelector<SqlLabRootState, QueryState>(
     ({ sqlLab }) => sqlLab.queries[qe.latestQueryId || '']?.state || '',
   );
   const dispatch = useDispatch();
@@ -139,7 +139,7 @@ const SqlEditorTabHeader: React.FC<Props> = ({ queryEditor }) => {
           </Menu>
         }
       />
-      <TabTitle>{qe.name}</TabTitle> <TabStatusIcon tabState={queryStatus} />{' '}
+      <TabTitle>{qe.name}</TabTitle> <TabStatusIcon tabState={queryState} />{' '}
     </TabTitleWrapper>
   );
 };
diff --git a/superset-frontend/src/SqlLab/reducers/sqlLab.js b/superset-frontend/src/SqlLab/reducers/sqlLab.js
index ed103a2afe..478487d6e2 100644
--- a/superset-frontend/src/SqlLab/reducers/sqlLab.js
+++ b/superset-frontend/src/SqlLab/reducers/sqlLab.js
@@ -16,8 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-import { t } from '@superset-ui/core';
-
+import { QueryState, t } from '@superset-ui/core';
 import getInitialState from './getInitialState';
 import * as actions from '../actions/sqlLab';
 import { now } from '../../utils/dates';
@@ -391,7 +390,7 @@ export default function sqlLabReducer(state = {}, action) {
     },
     [actions.STOP_QUERY]() {
       return alterInObject(state, 'queries', action.query, {
-        state: 'stopped',
+        state: QueryState.STOPPED,
         results: [],
       });
     },
@@ -405,12 +404,16 @@ export default function sqlLabReducer(state = {}, action) {
     },
     [actions.REQUEST_QUERY_RESULTS]() {
       return alterInObject(state, 'queries', action.query, {
-        state: 'fetching',
+        state: QueryState.FETCHING,
       });
     },
     [actions.QUERY_SUCCESS]() {
-      // prevent race condition were query succeeds shortly after being canceled
-      if (action.query.state === 'stopped') {
+      // prevent race condition where query succeeds shortly after being canceled
+      // or the final result was unsuccessful
+      if (
+        action.query.state === QueryState.STOPPED ||
+        action.results.status !== QueryState.SUCCESS
+      ) {
         return state;
       }
       const alts = {
@@ -418,7 +421,7 @@ export default function sqlLabReducer(state = {}, action) {
         progress: 100,
         results: action.results,
         rows: action?.results?.query?.rows || 0,
-        state: 'success',
+        state: QueryState.SUCCESS,
         limitingFactor: action?.results?.query?.limitingFactor,
         tempSchema: action?.results?.query?.tempSchema,
         tempTable: action?.results?.query?.tempTable,
@@ -434,11 +437,11 @@ export default function sqlLabReducer(state = {}, action) {
       return alterInObject(state, 'queries', action.query, alts);
     },
     [actions.QUERY_FAILED]() {
-      if (action.query.state === 'stopped') {
+      if (action.query.state === QueryState.STOPPED) {
         return state;
       }
       const alts = {
-        state: 'failed',
+        state: QueryState.FAILED,
         errors: action.errors,
         errorMessage: action.msg,
         endDttm: now(),
@@ -723,8 +726,8 @@ export default function sqlLabReducer(state = {}, action) {
       Object.entries(action.alteredQueries).forEach(([id, changedQuery]) => {
         if (
           !state.queries.hasOwnProperty(id) ||
-          (state.queries[id].state !== 'stopped' &&
-            state.queries[id].state !== 'failed')
+          (state.queries[id].state !== QueryState.STOPPED &&
+            state.queries[id].state !== QueryState.FAILED)
         ) {
           if (changedQuery.changedOn > queriesLastUpdate) {
             queriesLastUpdate = changedQuery.changedOn;
@@ -738,8 +741,8 @@ export default function sqlLabReducer(state = {}, action) {
             // because of async behavior, sql lab may still poll a couple of seconds
             // when it started fetching or finished rendering results
             state:
-              currentState === 'success' &&
-              ['fetching', 'success'].includes(prevState)
+              currentState === QueryState.SUCCESS &&
+              [QueryState.FETCHING, QueryState.SUCCESS].includes(prevState)
                 ? prevState
                 : currentState,
           };
diff --git a/superset-frontend/src/views/CRUD/data/query/QueryList.test.tsx b/superset-frontend/src/views/CRUD/data/query/QueryList.test.tsx
index eaaa75a1cb..be28d7e2df 100644
--- a/superset-frontend/src/views/CRUD/data/query/QueryList.test.tsx
+++ b/superset-frontend/src/views/CRUD/data/query/QueryList.test.tsx
@@ -33,6 +33,7 @@ import ListView from 'src/components/ListView';
 import Filters from 'src/components/ListView/Filters';
 import SyntaxHighlighter from 'react-syntax-highlighter/dist/cjs/light';
 import SubMenu from 'src/views/components/SubMenu';
+import { QueryState } from '@superset-ui/core';
 
 // store needed for withToasts
 const mockStore = configureStore([thunk]);
@@ -54,7 +55,7 @@ const mockQueries: QueryObject[] = [...new Array(3)].map((_, i) => ({
     { schema: 'foo', table: 'table' },
     { schema: 'bar', table: 'table_2' },
   ],
-  status: 'success',
+  status: QueryState.SUCCESS,
   tab_name: 'Main Tab',
   user: {
     first_name: 'cool',
diff --git a/superset-frontend/src/views/CRUD/data/query/QueryList.tsx b/superset-frontend/src/views/CRUD/data/query/QueryList.tsx
index bbee625092..dbe8e259da 100644
--- a/superset-frontend/src/views/CRUD/data/query/QueryList.tsx
+++ b/superset-frontend/src/views/CRUD/data/query/QueryList.tsx
@@ -17,7 +17,13 @@
  * under the License.
  */
 import React, { useMemo, useState, useCallback, ReactElement } from 'react';
-import { SupersetClient, t, styled, useTheme } from '@superset-ui/core';
+import {
+  QueryState,
+  styled,
+  SupersetClient,
+  t,
+  useTheme,
+} from '@superset-ui/core';
 import moment from 'moment';
 import {
   createFetchRelated,
@@ -127,7 +133,13 @@ function QueryList({ addDangerToast }: QueryListProps) {
           row: {
             original: { status },
           },
-        }: any) => {
+        }: {
+          row: {
+            original: {
+              status: QueryState;
+            };
+          };
+        }) => {
           const statusConfig: {
             name: ReactElement | null;
             label: string;
@@ -135,33 +147,39 @@ function QueryList({ addDangerToast }: QueryListProps) {
             name: null,
             label: '',
           };
-          if (status === 'success') {
+          if (status === QueryState.SUCCESS) {
             statusConfig.name = (
               <Icons.Check iconColor={theme.colors.success.base} />
             );
             statusConfig.label = t('Success');
-          } else if (status === 'failed' || status === 'stopped') {
+          } else if (
+            status === QueryState.FAILED ||
+            status === QueryState.STOPPED
+          ) {
             statusConfig.name = (
               <Icons.XSmall
                 iconColor={
-                  status === 'failed'
+                  status === QueryState.FAILED
                     ? theme.colors.error.base
                     : theme.colors.grayscale.base
                 }
               />
             );
             statusConfig.label = t('Failed');
-          } else if (status === 'running') {
+          } else if (status === QueryState.RUNNING) {
             statusConfig.name = (
               <Icons.Running iconColor={theme.colors.primary.base} />
             );
             statusConfig.label = t('Running');
-          } else if (status === 'timed_out') {
+          } else if (status === QueryState.TIMED_OUT) {
             statusConfig.name = (
               <Icons.Offline iconColor={theme.colors.grayscale.light1} />
             );
             statusConfig.label = t('Offline');
-          } else if (status === 'scheduled' || status === 'pending') {
+          } else if (
+            status === QueryState.SCHEDULED ||
+            status === QueryState.PENDING
+          ) {
             statusConfig.name = (
               <Icons.Queued iconColor={theme.colors.grayscale.base} />
             );
diff --git a/superset-frontend/src/views/CRUD/data/query/QueryPreviewModal.test.tsx b/superset-frontend/src/views/CRUD/data/query/QueryPreviewModal.test.tsx
index 7a85e4c292..96498f6e69 100644
--- a/superset-frontend/src/views/CRUD/data/query/QueryPreviewModal.test.tsx
+++ b/superset-frontend/src/views/CRUD/data/query/QueryPreviewModal.test.tsx
@@ -27,6 +27,7 @@ import QueryPreviewModal from 'src/views/CRUD/data/query/QueryPreviewModal';
 import { QueryObject } from 'src/views/CRUD/types';
 import SyntaxHighlighter from 'react-syntax-highlighter/dist/cjs/light';
 import { act } from 'react-dom/test-utils';
+import { QueryState } from '@superset-ui/core';
 
 // store needed for withToasts
 const mockStore = configureStore([thunk]);
@@ -46,7 +47,7 @@ const mockQueries: QueryObject[] = [...new Array(3)].map((_, i) => ({
     { schema: 'foo', table: 'table' },
     { schema: 'bar', table: 'table_2' },
   ],
-  status: 'success',
+  status: QueryState.SUCCESS,
   tab_name: 'Main Tab',
   user: {
     first_name: 'cool',
diff --git a/superset-frontend/src/views/CRUD/types.ts b/superset-frontend/src/views/CRUD/types.ts
index 86784c9704..441fc9bde4 100644
--- a/superset-frontend/src/views/CRUD/types.ts
+++ b/superset-frontend/src/views/CRUD/types.ts
@@ -16,6 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+import { QueryState } from '@superset-ui/core';
 import { User } from 'src/types/bootstrapTypes';
 import Database from 'src/types/Database';
 import Owner from 'src/types/Owner';
@@ -94,14 +95,7 @@ export interface QueryObject {
   sql: string;
   executed_sql: string | null;
   sql_tables?: { catalog?: string; schema: string; table: string }[];
-  status:
-    | 'success'
-    | 'failed'
-    | 'stopped'
-    | 'running'
-    | 'timed_out'
-    | 'scheduled'
-    | 'pending';
+  status: QueryState;
   tab_name: string;
   user: {
     first_name: string;
diff --git a/superset/constants.py b/superset/constants.py
index 7d759acf67..5091d65a43 100644
--- a/superset/constants.py
+++ b/superset/constants.py
@@ -34,6 +34,9 @@ PASSWORD_MASK = "X" * 10
 
 NO_TIME_RANGE = "No filter"
 
+QUERY_CANCEL_KEY = "cancel_query"
+QUERY_EARLY_CANCEL_KEY = "early_cancel_query"
+
 
 class RouteMethod:  # pylint: disable=too-few-public-methods
     """
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 3e2c0f56ba..43dd607876 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -15,6 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=too-many-lines
+
+from __future__ import annotations
+
 import json
 import logging
 import re
@@ -478,7 +481,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def get_engine(
         cls,
-        database: "Database",
+        database: Database,
         schema: Optional[str] = None,
         source: Optional[utils.QuerySource] = None,
     ) -> ContextManager[Engine]:
@@ -733,7 +736,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def extra_table_metadata(  # pylint: disable=unused-argument
         cls,
-        database: "Database",
+        database: Database,
         table_name: str,
         schema_name: Optional[str],
     ) -> Dict[str, Any]:
@@ -750,7 +753,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
 
     @classmethod
     def apply_limit_to_sql(
-        cls, sql: str, limit: int, database: "Database", force: bool = False
+        cls, sql: str, limit: int, database: Database, force: bool = False
     ) -> str:
         """
         Alters the SQL statement to apply a LIMIT clause
@@ -892,7 +895,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def df_to_sql(
         cls,
-        database: "Database",
+        database: Database,
         table: Table,
         df: pd.DataFrame,
         to_sql_kwargs: Dict[str, Any],
@@ -939,7 +942,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         return None
 
     @classmethod
-    def handle_cursor(cls, cursor: Any, query: "Query", session: Session) -> None:
+    def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
         """Handle a live cursor between the execute and fetchall calls
 
         The flow works without this method doing anything, but it allows
@@ -1031,7 +1034,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def get_table_names(  # pylint: disable=unused-argument
         cls,
-        database: "Database",
+        database: Database,
         inspector: Inspector,
         schema: Optional[str],
     ) -> Set[str]:
@@ -1059,7 +1062,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def get_view_names(  # pylint: disable=unused-argument
         cls,
-        database: "Database",
+        database: Database,
         inspector: Inspector,
         schema: Optional[str],
     ) -> Set[str]:
@@ -1125,7 +1128,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def get_metrics(  # pylint: disable=unused-argument
         cls,
-        database: "Database",
+        database: Database,
         inspector: Inspector,
         table_name: str,
         schema: Optional[str],
@@ -1147,7 +1150,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         cls,
         table_name: str,
         schema: Optional[str],
-        database: "Database",
+        database: Database,
         query: Select,
         columns: Optional[List[Dict[str, str]]] = None,
     ) -> Optional[Select]:
@@ -1172,7 +1175,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def select_star(  # pylint: disable=too-many-arguments,too-many-locals
         cls,
-        database: "Database",
+        database: Database,
         table_name: str,
         engine: Engine,
         schema: Optional[str] = None,
@@ -1251,7 +1254,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         raise Exception("Database does not support cost estimation")
 
     @classmethod
-    def process_statement(cls, statement: str, database: "Database") -> str:
+    def process_statement(cls, statement: str, database: Database) -> str:
         """
         Process a SQL statement by stripping and mutating it.
 
@@ -1275,7 +1278,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def estimate_query_cost(
         cls,
-        database: "Database",
+        database: Database,
         schema: str,
         sql: str,
         source: Optional[utils.QuerySource] = None,
@@ -1471,7 +1474,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def get_function_names(  # pylint: disable=unused-argument
         cls,
-        database: "Database",
+        database: Database,
     ) -> List[str]:
         """
         Get a list of function names that are able to be called on the database.
@@ -1496,7 +1499,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
 
     @staticmethod
     def mutate_db_for_connection_test(  # pylint: disable=unused-argument
-        database: "Database",
+        database: Database,
     ) -> None:
         """
         Some databases require passing additional parameters for validating database
@@ -1508,7 +1511,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         return None
 
     @staticmethod
-    def get_extra_params(database: "Database") -> Dict[str, Any]:
+    def get_extra_params(database: Database) -> Dict[str, Any]:
         """
         Some databases require adding elements to connection parameters,
         like passing certificates to `extra`. This can be done here.
@@ -1527,7 +1530,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
 
     @staticmethod
     def update_params_from_encrypted_extra(  # pylint: disable=invalid-name
-        database: "Database", params: Dict[str, Any]
+        database: Database, params: Dict[str, Any]
     ) -> None:
         """
         Some databases require some sensitive information which do not conform to
@@ -1589,11 +1592,22 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             )
         return None
 
+    # pylint: disable=unused-argument
+    @classmethod
+    def prepare_cancel_query(cls, query: Query, session: Session) -> None:
+        """
+        Some databases may acquire the query cancelation id after the query
+        cancelation request has been received. For those cases, the db engine spec
+        can record the cancelation intent so that the query can either be stopped
+        prior to execution, or canceled once the query id is acquired.
+        """
+        return None
+
     @classmethod
     def has_implicit_cancel(cls) -> bool:
         """
         Return True if the live cursor handles the implicit cancelation of the query,
-        False otherise.
+        False otherwise.
 
         :return: Whether the live cursor implicitly cancels the query
         :see: handle_cursor
@@ -1605,7 +1619,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     def get_cancel_query_id(  # pylint: disable=unused-argument
         cls,
         cursor: Any,
-        query: "Query",
+        query: Query,
     ) -> Optional[str]:
         """
         Select identifiers from the database engine that uniquely identifies the
@@ -1623,7 +1637,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     def cancel_query(  # pylint: disable=unused-argument
         cls,
         cursor: Any,
-        query: "Query",
+        query: Query,
         cancel_query_id: str,
     ) -> bool:
         """
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index 60786e417b..c699089767 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -559,7 +559,7 @@ class HiveEngineSpec(PrestoEngineSpec):
     def has_implicit_cancel(cls) -> bool:
         """
         Return True if the live cursor handles the implicit cancelation of the query,
-        False otherise.
+        False otherwise.
 
         :return: Whether the live cursor implicitly cancels the query
         :see: handle_cursor
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 2a3acb8bb5..ba5df3e28d 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -1307,7 +1307,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
     def has_implicit_cancel(cls) -> bool:
         """
         Return True if the live cursor handles the implicit cancelation of the query,
-        False otherise.
+        False otherwise.
 
         :return: Whether the live cursor implicitly cancels the query
         :see: handle_cursor
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index 2a1d8cc639..3b23f79873 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -26,7 +26,7 @@ from flask import current_app
 from sqlalchemy.engine.url import URL
 from sqlalchemy.orm import Session
 
-from superset.constants import USER_AGENT
+from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
 from superset.databases.utils import make_url_safe
 from superset.db_engine_specs.base import BaseEngineSpec
 from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
@@ -181,11 +181,30 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
             query.tracking_url = tracking_url
 
         # Adds the executed query id to the extra payload so the query can be cancelled
-        query.set_extra_json_key("cancel_query", cursor.stats["queryId"])
+        query.set_extra_json_key(
+            key=QUERY_CANCEL_KEY,
+            value=(cancel_query_id := cursor.stats["queryId"]),
+        )
 
         session.commit()
+
+        # if query cancelation was requested prior to the handle_cursor call, but
+        # the query was still executed, trigger the actual query cancelation now
+        if query.extra.get(QUERY_EARLY_CANCEL_KEY):
+            cls.cancel_query(
+                cursor=cursor,
+                query=query,
+                cancel_query_id=cancel_query_id,
+            )
+
         super().handle_cursor(cursor=cursor, query=query, session=session)
 
+    @classmethod
+    def prepare_cancel_query(cls, query: Query, session: Session) -> None:
+        if QUERY_CANCEL_KEY not in query.extra:
+            query.set_extra_json_key(QUERY_EARLY_CANCEL_KEY, True)
+            session.commit()
+
     @classmethod
     def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool:
         """
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 6d9903c8f0..143806c7f5 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -33,12 +33,14 @@ from sqlalchemy.orm import Session
 
 from superset import (
     app,
+    db,
     is_feature_enabled,
     results_backend,
     results_backend_use_msgpack,
     security_manager,
 )
 from superset.common.db_query_status import QueryStatus
+from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY
 from superset.dataframe import df_to_records
 from superset.db_engine_specs import BaseEngineSpec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
@@ -69,7 +71,6 @@ SQLLAB_CTAS_NO_LIMIT = config["SQLLAB_CTAS_NO_LIMIT"]
 SQL_QUERY_MUTATOR = config["SQL_QUERY_MUTATOR"]
 log_query = config["QUERY_LOGGER"]
 logger = logging.getLogger(__name__)
-cancel_query_key = "cancel_query"
 
 
 class SqlLabException(Exception):
@@ -473,7 +474,7 @@ def execute_sql_statements(  # pylint: disable=too-many-arguments, too-many-loca
             cursor = conn.cursor()
             cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
             if cancel_query_id is not None:
-                query.set_extra_json_key(cancel_query_key, cancel_query_id)
+                query.set_extra_json_key(QUERY_CANCEL_KEY, cancel_query_id)
                 session.commit()
             statement_count = len(statements)
             for i, statement in enumerate(statements):
@@ -613,7 +614,7 @@ def cancel_query(query: Query) -> bool:
     """
     Cancel a running query.
 
-    Note some engines implicitly handle the cancelation of a query and thus no expliicit
+    Note some engines implicitly handle the cancelation of a query and thus no explicit
     action is required.
 
     :param query: Query to cancel
@@ -623,7 +624,16 @@ def cancel_query(query: Query) -> bool:
     if query.database.db_engine_spec.has_implicit_cancel():
         return True
 
-    cancel_query_id = query.extra.get(cancel_query_key)
+    # Some databases may need to make preparations for query cancellation
+    query.database.db_engine_spec.prepare_cancel_query(query, db.session)
+
+    if query.extra.get(QUERY_EARLY_CANCEL_KEY):
+        # Query has been cancelled prior to being able to set the cancel key.
+        # This can happen if the query cancellation key can only be acquired after the
+        # query has been executed
+        return True
+
+    cancel_query_id = query.extra.get(QUERY_CANCEL_KEY)
     if cancel_query_id is None:
         return False
 
diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py
index 6a77e63236..382b65ce52 100644
--- a/tests/unit_tests/db_engine_specs/test_trino.py
+++ b/tests/unit_tests/db_engine_specs/test_trino.py
@@ -15,8 +15,15 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=unused-argument, import-outside-toplevel, protected-access
+import json
+from typing import Any, Dict
 from unittest import mock
 
+import pytest
+from pytest_mock import MockerFixture
+
+from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY
+
 
 @mock.patch("sqlalchemy.engine.Engine.connect")
 def test_cancel_query_success(engine_mock: mock.Mock) -> None:
@@ -36,3 +43,55 @@ def test_cancel_query_failed(engine_mock: mock.Mock) -> None:
     query = Query()
     cursor_mock = engine_mock.raiseError.side_effect = Exception()
     assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is False
+
+
+@pytest.mark.parametrize(
+    "initial_extra,final_extra",
+    [
+        ({}, {QUERY_EARLY_CANCEL_KEY: True}),
+        ({QUERY_CANCEL_KEY: "my_key"}, {QUERY_CANCEL_KEY: "my_key"}),
+    ],
+)
+def test_prepare_cancel_query(
+    initial_extra: Dict[str, Any],
+    final_extra: Dict[str, Any],
+    mocker: MockerFixture,
+) -> None:
+    from superset.db_engine_specs.trino import TrinoEngineSpec
+    from superset.models.sql_lab import Query
+
+    session_mock = mocker.MagicMock()
+    query = Query(extra_json=json.dumps(initial_extra))
+    TrinoEngineSpec.prepare_cancel_query(query=query, session=session_mock)
+    assert query.extra == final_extra
+
+
+@pytest.mark.parametrize("cancel_early", [True, False])
+@mock.patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
+@mock.patch("sqlalchemy.engine.Engine.connect")
+def test_handle_cursor_early_cancel(
+    engine_mock: mock.Mock,
+    cancel_query_mock: mock.Mock,
+    cancel_early: bool,
+    mocker: MockerFixture,
+) -> None:
+    from superset.db_engine_specs.trino import TrinoEngineSpec
+    from superset.models.sql_lab import Query
+
+    query_id = "myQueryId"
+
+    cursor_mock = engine_mock.return_value.__enter__.return_value
+    cursor_mock.stats = {"queryId": query_id}
+    session_mock = mocker.MagicMock()
+
+    query = Query()
+
+    if cancel_early:
+        TrinoEngineSpec.prepare_cancel_query(query=query, session=session_mock)
+
+    TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query, session=session_mock)
+
+    if cancel_early:
+        assert cancel_query_mock.call_args[1]["cancel_query_id"] == query_id
+    else:
+        assert cancel_query_mock.call_args is None