You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by mi...@apache.org on 2022/08/30 12:41:46 UTC

[superset] 08/13: fix(20428): Address-Presto/Trino-Poll-Issue-Refactor (#20434)

This is an automated email from the ASF dual-hosted git repository.

michaelsmolina pushed a commit to branch 1.5
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 094b17e8cc9f345395533aa85f1c8f4388074177
Author: Simon Thelin <si...@gmail.com>
AuthorDate: Mon Jun 20 00:28:59 2022 +0100

    fix(20428): Address-Presto/Trino-Poll-Issue-Refactor (#20434)
    
    * fix(20428)-Address-Presto/Trino-Poll-Issue-Refactor
    
    Update linter
    
    * Update to only use BaseEngineSpec handle_cursor
    
    * Fix CI
    
    Co-authored-by: John Bodley <45...@users.noreply.github.com>
---
 superset/db_engine_specs/presto.py |  4 --
 superset/db_engine_specs/trino.py  | 97 ++++++++++++--------------------------
 2 files changed, 30 insertions(+), 71 deletions(-)

diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 86b12b8538..49810cdd6c 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -946,11 +946,7 @@ class PrestoEngineSpec(BaseEngineSpec):  # pylint: disable=too-many-public-metho
             sql = f"SHOW CREATE VIEW {schema}.{table}"
             try:
                 cls.execute(cursor, sql)
-                polled = cursor.poll()
 
-                while polled:
-                    time.sleep(0.2)
-                    polled = cursor.poll()
             except DatabaseError:  # not a VIEW
                 return None
             rows = cls.fetch_data(cursor, 1)
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index 4a5e9af01c..1759345069 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -21,9 +21,13 @@ from urllib import parse
 
 import simplejson as json
 from flask import current_app
+from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.url import make_url, URL
+from sqlalchemy.orm import Session
 
 from superset.db_engine_specs.base import BaseEngineSpec
+from superset.db_engine_specs.presto import PrestoEngineSpec
+from superset.models.sql_lab import Query
 from superset.utils import core as utils
 
 if TYPE_CHECKING:
@@ -133,76 +137,35 @@ class TrinoEngineSpec(BaseEngineSpec):
         return True
 
     @classmethod
-    def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
-        """
-        Run a SQL query that estimates the cost of a given statement.
-
-        :param statement: A single SQL statement
-        :param cursor: Cursor instance
-        :return: JSON response from Trino
-        """
-        sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {statement}"
-        cursor.execute(sql)
-
-        # the output from Trino is a single column and a single row containing
-        # JSON:
-        #
-        #   {
-        #     ...
-        #     "estimate" : {
-        #       "outputRowCount" : 8.73265878E8,
-        #       "outputSizeInBytes" : 3.41425774958E11,
-        #       "cpuCost" : 3.41425774958E11,
-        #       "maxMemory" : 0.0,
-        #       "networkCost" : 3.41425774958E11
-        #     }
-        #   }
-        result = json.loads(cursor.fetchone()[0])
-        return result
+    def get_table_names(
+        cls,
+        database: "Database",
+        inspector: Inspector,
+        schema: Optional[str],
+    ) -> List[str]:
+        return BaseEngineSpec.get_table_names(
+            database=database,
+            inspector=inspector,
+            schema=schema,
+        )
 
     @classmethod
-    def query_cost_formatter(
-        cls, raw_cost: List[Dict[str, Any]]
-    ) -> List[Dict[str, str]]:
-        """
-        Format cost estimate.
-
-        :param raw_cost: JSON estimate from Trino
-        :return: Human readable cost estimate
-        """
-
-        def humanize(value: Any, suffix: str) -> str:
-            try:
-                value = int(value)
-            except ValueError:
-                return str(value)
-
-            prefixes = ["K", "M", "G", "T", "P", "E", "Z", "Y"]
-            prefix = ""
-            to_next_prefix = 1000
-            while value > to_next_prefix and prefixes:
-                prefix = prefixes.pop(0)
-                value //= to_next_prefix
-
-            return f"{value} {prefix}{suffix}"
-
-        cost = []
-        columns = [
-            ("outputRowCount", "Output count", " rows"),
-            ("outputSizeInBytes", "Output size", "B"),
-            ("cpuCost", "CPU cost", ""),
-            ("maxMemory", "Max memory", "B"),
-            ("networkCost", "Network cost", ""),
-        ]
-        for row in raw_cost:
-            estimate: Dict[str, float] = row.get("estimate", {})
-            statement_cost = {}
-            for key, label, suffix in columns:
-                if key in estimate:
-                    statement_cost[label] = humanize(estimate[key], suffix).strip()
-            cost.append(statement_cost)
+    def get_view_names(
+        cls,
+        database: "Database",
+        inspector: Inspector,
+        schema: Optional[str],
+    ) -> List[str]:
+        return BaseEngineSpec.get_view_names(
+            database=database,
+            inspector=inspector,
+            schema=schema,
+        )
 
-        return cost
+    @classmethod
+    def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
+        """Updates progress information"""
+        BaseEngineSpec.handle_cursor(cursor=cursor, query=query, session=session)
 
     @staticmethod
     def get_extra_params(database: "Database") -> Dict[str, Any]: