You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by be...@apache.org on 2024/03/20 16:53:17 UTC

(superset) 01/01: WIP

This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch sip-85
in repository https://gitbox.apache.org/repos/asf/superset.git

commit aac320c8c776fa866151ae7cfa36db019db5821e
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Wed Mar 20 12:53:07 2024 -0400

    WIP
---
 superset/config.py       | 10 ++++++++
 superset/models/core.py  | 33 +++++++++++++++++++++---
 superset/utils/oauth2.py | 67 ++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 107 insertions(+), 3 deletions(-)

diff --git a/superset/config.py b/superset/config.py
index 197e4bac42..0de3f284e5 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -1392,6 +1392,16 @@ PREFERRED_DATABASES: list[str] = [
 # one here.
 TEST_DATABASE_CONNECTION_TIMEOUT = timedelta(seconds=30)
 
+# OAuth2 client configuration for databases that allow users to authenticate via
+# personal OAuth2 tokens; see https://github.com/apache/superset/issues/20300.
+DATABASE_OAUTH2_CREDENTIALS = {
+    "GSheets": {  # NOTE(review): placeholder values; never commit real secrets
+        "CLIENT_ID": "XXX.apps.googleusercontent.com",
+        "CLIENT_SECRET": "GOCSPX-YYY",
+        "REDIRECT_URI": "http://localhost:8088/api/v1/database/oauth2/",
+    },
+}
+
 # Enable/disable CSP warning
 CONTENT_SECURITY_POLICY_WARNING = True
 
diff --git a/superset/models/core.py b/superset/models/core.py
index 71a6e9d042..27ee347139 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -75,6 +75,7 @@ from superset.superset_typing import ResultSetColumnType
 from superset.utils import cache as cache_util, core as utils
 from superset.utils.backports import StrEnum
 from superset.utils.core import get_username
+from superset.utils.oauth2 import get_oauth2_access_token
 
 config = app.config
 custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
@@ -461,6 +462,11 @@ class Database(
         )
 
         effective_username = self.get_effective_user(sqlalchemy_url)
+        access_token = (
+            get_oauth2_access_token(self.id, g.user.id, self.db_engine_spec)
+            if hasattr(g, "user")
+            else None
+        )
         # If using MySQL or Presto for example, will set url.username
         # If using Hive, will not do anything yet since that relies on a
         # configuration parameter instead.
@@ -468,6 +474,7 @@ class Database(
             sqlalchemy_url,
             self.impersonate_user,
             effective_username,
+            access_token,
         )
 
         masked_url = self.get_password_masked_url(sqlalchemy_url)
@@ -588,7 +595,7 @@ class Database(
                         database=None,
                     )
                 _log_query(sql_)
-                self.db_engine_spec.execute(cursor, sql_)
+                self.db_engine_spec.execute(cursor, sql_, self.id)
                 cursor.fetchall()
 
             if mutate_after_split:
@@ -598,10 +605,10 @@ class Database(
                     database=None,
                 )
                 _log_query(last_sql)
-                self.db_engine_spec.execute(cursor, last_sql)
+                self.db_engine_spec.execute(cursor, last_sql, self.id)
             else:
                 _log_query(sqls[-1])
-                self.db_engine_spec.execute(cursor, sqls[-1])
+                self.db_engine_spec.execute(cursor, sqls[-1], self.id)
 
             data = self.db_engine_spec.fetch_data(cursor)
             result_set = SupersetResultSet(
@@ -978,6 +985,26 @@ sqla.event.listen(Database, "after_update", security_manager.database_after_upda
 sqla.event.listen(Database, "after_delete", security_manager.database_after_delete)
 
 
+class DatabaseUserOAuth2Tokens(Model, AuditMixinNullable):
+    """
+    Store OAuth2 tokens for authenticating to databases with users' personal tokens.
+    """
+
+    __tablename__ = "database_user_oauth2_tokens"
+    # NOTE(review): nothing enforces one row per (user_id, database_id) pair yet
+    id = Column(Integer, primary_key=True)
+    # every token belongs to a user, so (like database_id) the owner is mandatory
+    user_id = Column(Integer, ForeignKey("ab_user.id"), nullable=False)
+    user = relationship(security_manager.user_model, foreign_keys=[user_id])
+
+    database_id = Column(Integer, ForeignKey("dbs.id"), nullable=False)
+    database = relationship("Database", foreign_keys=[database_id])
+    # token values are kept in encrypted columns; any of them may be missing
+    access_token = Column(encrypted_field_factory.create(Text), nullable=True)
+    access_token_expiration = Column(DateTime, nullable=True)
+    refresh_token = Column(encrypted_field_factory.create(Text), nullable=True)
+
+
 class Log(Model):  # pylint: disable=too-few-public-methods
     """ORM object used to log Superset actions to the database"""
 
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
new file mode 100644
index 0000000000..b004e4e02e
--- /dev/null
+++ b/superset/utils/oauth2.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime, timedelta
+from typing import Optional, Type
+
+from superset import db
+from superset.db_engine_specs.base import BaseEngineSpec
+
+
+def get_oauth2_access_token(
+    database_id: int,
+    user_id: int,
+    db_engine_spec: Type[BaseEngineSpec],
+) -> Optional[str]:
+    """
+    Return a valid OAuth2 access token for the user/database pair, or None.
+    A stored token is returned while unexpired; otherwise it is refreshed (and saved)
+    when a refresh token is available; entries that can't be refreshed are deleted.
+    """
+    # pylint: disable=import-outside-toplevel
+    from superset.models.core import DatabaseUserOAuth2Tokens
+
+    token = (
+        db.session.query(DatabaseUserOAuth2Tokens)
+        .filter_by(user_id=user_id, database_id=database_id)
+        .one_or_none()
+    )
+    if token is None:
+        return None
+
+    # the stored token is only usable while it has NOT expired; a missing expiration
+    # is treated as expired (comparing None to a datetime would raise TypeError)
+    expiration = token.access_token_expiration
+    if token.access_token and expiration and datetime.now() < expiration:
+        return token.access_token
+
+    if token.refresh_token:
+        token_response = db_engine_spec.get_oauth2_fresh_token(token.refresh_token)
+        # the refresh token might have been revoked, in which case the response has
+        # no access token and we fall through to deleting the entry below
+        if "access_token" in token_response:
+            token.access_token = token_response["access_token"]
+            token.access_token_expiration = datetime.now() + timedelta(
+                seconds=token_response["expires_in"]
+            )
+            db.session.add(token)
+            return token.access_token
+
+    # the token is expired and can't be refreshed, so drop the stale entry
+    db.session.delete(token)
+
+    return None