You are viewing a plain text version of this content. The hyperlink to its canonical version (originally on the word "here") was not preserved in this copy.
Posted to commits@superset.apache.org by hu...@apache.org on 2022/10/07 17:14:20 UTC

[superset] 11/13: fix: add `get_column` function for Query obj (#21691)

This is an automated email from the ASF dual-hosted git repository.

hugh pushed a commit to branch 2022.39.1
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 42304a4302d46a8f9e7a9c397cf96d77fdfa3d87
Author: Hugh A. Miles II <hu...@gmail.com>
AuthorDate: Wed Oct 5 18:25:44 2022 -0400

    fix: add `get_column` function for Query obj (#21691)
    
    (cherry picked from commit 51c54b3c9bc69273bb5da004b8f9a7ae202de8fd)
---
 superset/common/query_context_processor.py |  3 ++-
 superset/models/sql_lab.py                 | 15 +++++++++++----
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py
index 98aaecebd9..01259ede1d 100644
--- a/superset/common/query_context_processor.py
+++ b/superset/common/query_context_processor.py
@@ -266,7 +266,8 @@ class QueryContextProcessor:
             # Query datasource didn't support `get_column`
             and hasattr(datasource, "get_column")
             and (col := datasource.get_column(label))
-            and col.is_dttm
+            # todo(hugh) standardize column object in Query datasource
+            and (col.get("is_dttm") if isinstance(col, dict) else col.is_dttm)
         )
         dttm_cols = [
             DateColumn(
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index 408bc708df..f75973ad17 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -51,7 +51,6 @@ from superset.models.helpers import (
 )
 from superset.sql_parse import CtasMethod, ParsedQuery, Table
 from superset.sqllab.limiting_factor import LimitingFactor
-from superset.superset_typing import ResultSetColumnType
 from superset.utils.core import GenericDataType, QueryStatus, user_label
 
 if TYPE_CHECKING:
@@ -183,7 +182,7 @@ class Query(
         return list(ParsedQuery(self.sql).tables)
 
     @property
-    def columns(self) -> List[ResultSetColumnType]:
+    def columns(self) -> List[Dict[str, Any]]:
         bool_types = ("BOOL",)
         num_types = (
             "DOUBLE",
@@ -217,7 +216,7 @@ class Query(
             computed_column["column_name"] = col.get("name")
             computed_column["groupby"] = True
             columns.append(computed_column)
-        return columns  # type: ignore
+        return columns
 
     @property
     def data(self) -> Dict[str, Any]:
@@ -288,7 +287,7 @@ class Query(
     def main_dttm_col(self) -> Optional[str]:
         for col in self.columns:
             if col.get("is_dttm"):
-                return col.get("column_name")  # type: ignore
+                return col.get("column_name")
         return None
 
     @property
@@ -332,6 +331,14 @@ class Query(
     def tracking_url(self, value: str) -> None:
         self.tracking_url_raw = value
 
+    def get_column(self, column_name: Optional[str]) -> Optional[Dict[str, Any]]:
+        if not column_name:
+            return None
+        for col in self.columns:
+            if col.get("column_name") == column_name:
+                return col
+        return None
+
 
 class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
     """ORM model for SQL query"""