You are viewing a plain-text version of this content. The canonical link for it (a hyperlink in the original page) was not preserved in this copy.
Posted to commits@superset.apache.org by be...@apache.org on 2023/03/22 22:33:33 UTC

[superset] 01/01: feat(bigquery): get_catalog_names

This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch bigquery_get_catalog_names
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 3b25de4c88461c0370f93b8ebf4d0691e19f36a8
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Wed Mar 22 15:33:13 2023 -0700

    feat(bigquery): get_catalog_names
---
 superset/db_engine_specs/bigquery.py | 85 ++++++++++++++++++++++--------------
 1 file changed, 53 insertions(+), 32 deletions(-)

diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py
index 171dad4732..c65d1c2f0a 100644
--- a/superset/db_engine_specs/bigquery.py
+++ b/superset/db_engine_specs/bigquery.py
@@ -28,6 +28,7 @@ from marshmallow import fields, Schema
 from marshmallow.exceptions import ValidationError
 from sqlalchemy import column, types
 from sqlalchemy.engine.base import Engine
+from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.sql import sqltypes
 from typing_extensions import TypedDict
 
@@ -42,6 +43,17 @@ from superset.sql_parse import Table
 from superset.utils import core as utils
 from superset.utils.hashing import md5_sha_from_str
 
+try:
+    from google.cloud import bigquery
+    from google.oauth2 import service_account
+except ImportError:
+    bigquery = service_account = None
+try:
+    import pandas_gbq  # only needed to upload dataframes to BigQuery
+except ImportError:
+    pandas_gbq = None
+Client = bigquery.Client if bigquery is not None else None  # for type checking
+
 if TYPE_CHECKING:
     from superset.models.core import Database  # pragma: no cover
 
@@ -327,17 +339,10 @@ class BigQueryEngineSpec(BaseEngineSpec):
         :param df: The dataframe with data to be uploaded
         :param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method
         """
-
-        try:
-            # pylint: disable=import-outside-toplevel
-            import pandas_gbq
-            from google.oauth2 import service_account
-        except ImportError as ex:
+        if pandas_gbq is None or service_account is None:
             raise Exception(
-                "Could not import libraries `pandas_gbq` or `google.oauth2`, which are "
-                "required to be installed in your environment in order "
-                "to upload data to BigQuery"
-            ) from ex
+                "Could not import libraries needed to upload data to BigQuery."
+            )
 
         if not table.schema:
             raise Exception("The table schema must be defined")
@@ -366,6 +371,21 @@ class BigQueryEngineSpec(BaseEngineSpec):
 
         pandas_gbq.to_gbq(df, **to_gbq_kwargs)
 
+    @classmethod
+    def _get_client(cls, engine: Engine) -> Client:
+        """
+        Return a BigQuery client authenticated with the engine's service account;
+        raise if the BigQuery client libraries are not installed.
+        """
+        if bigquery is None or service_account is None:
+            raise Exception(
+                "Could not import libraries needed to connect to BigQuery."
+            )
+        credentials = service_account.Credentials.from_service_account_info(
+            engine.dialect.credentials_info
+        )
+        return bigquery.Client(credentials=credentials)
+
     @classmethod
     def estimate_query_cost(
         cls,
@@ -395,35 +415,36 @@ class BigQueryEngineSpec(BaseEngineSpec):
             costs.append(cls.estimate_statement_cost(processed_statement, database))
         return costs
 
+    @classmethod
+    def get_catalog_names(  # pylint: disable=unused-argument
+        cls,
+        database: "Database",
+        inspector: Inspector,
+    ) -> List[str]:
+        """
+        List the catalogs available to this database, in sorted order.
+
+        BigQuery's equivalent of a catalog is a "project", so these are project IDs.
+        """
+        with database.get_sqla_engine_with_context() as engine:
+            bq_client = cls._get_client(engine)
+            project_iter = bq_client.list_projects()
+
+        return sorted(p.project_id for p in project_iter)
+
     @classmethod
     def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
         return True
 
     @classmethod
     def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
-        try:
-            # pylint: disable=import-outside-toplevel
-            # It's the only way to perfom a dry-run estimate cost
-            from google.cloud import bigquery
-            from google.oauth2 import service_account
-        except ImportError as ex:
-            raise Exception(
-                "Could not import libraries `pygibquery` or `google.oauth2`, which are "
-                "required to be installed in your environment in order "
-                "to upload data to BigQuery"
-            ) from ex
-
         with cls.get_engine(cursor) as engine:
-            creds = engine.dialect.credentials_info
-
-        creds = service_account.Credentials.from_service_account_info(creds)
-        client = bigquery.Client(credentials=creds)
-        job_config = bigquery.QueryJobConfig(dry_run=True)
-
-        query_job = client.query(
-            statement,
-            job_config=job_config,
-        )  # Make an API request.
+            client = cls._get_client(engine)
+            job_config = bigquery.QueryJobConfig(dry_run=True)
+            query_job = client.query(
+                statement,
+                job_config=job_config,
+            )  # Make an API request.
 
         # Format Bytes.
         # TODO: Humanize in case more db engine specs need to be added,