You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by be...@apache.org on 2024/03/20 16:53:17 UTC
(superset) 01/01: WIP
This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch sip-85
in repository https://gitbox.apache.org/repos/asf/superset.git
commit aac320c8c776fa866151ae7cfa36db019db5821e
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Wed Mar 20 12:53:07 2024 -0400
WIP
---
superset/config.py | 10 ++++++++
superset/models/core.py | 33 +++++++++++++++++++++---
superset/utils/oauth2.py | 67 ++++++++++++++++++++++++++++++++++++++++++++++++
3 files changed, 107 insertions(+), 3 deletions(-)
diff --git a/superset/config.py b/superset/config.py
index 197e4bac42..0de3f284e5 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -1392,6 +1392,16 @@ PREFERRED_DATABASES: list[str] = [
# one here.
TEST_DATABASE_CONNECTION_TIMEOUT = timedelta(seconds=30)
+# Details needed for databases that allow users to authenticate via personal OAuth2
+# tokens; see https://github.com/apache/superset/issues/20300. Empty by default.
+DATABASE_OAUTH2_CREDENTIALS: dict[str, dict[str, str]] = {
+    # "GSheets": {  # example only; keep real secrets out of version control
+    #     "CLIENT_ID": "XXX.apps.googleusercontent.com",
+    #     "CLIENT_SECRET": "GOCSPX-YYY",
+    #     "REDIRECT_URI": "http://localhost:8088/api/v1/database/oauth2/",
+    # },
+}
+
# Enable/disable CSP warning
CONTENT_SECURITY_POLICY_WARNING = True
diff --git a/superset/models/core.py b/superset/models/core.py
index 71a6e9d042..27ee347139 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -75,6 +75,7 @@ from superset.superset_typing import ResultSetColumnType
from superset.utils import cache as cache_util, core as utils
from superset.utils.backports import StrEnum
from superset.utils.core import get_username
+from superset.utils.oauth2 import get_oauth2_access_token
config = app.config
custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
@@ -461,6 +462,11 @@ class Database(
)
effective_username = self.get_effective_user(sqlalchemy_url)
+        access_token = (
+            get_oauth2_access_token(self.id, g.user.id, self.db_engine_spec)
+            if hasattr(g, "user") and hasattr(g.user, "id")  # g.user may lack an id
+            else None
+        )
# If using MySQL or Presto for example, will set url.username
# If using Hive, will not do anything yet since that relies on a
# configuration parameter instead.
@@ -468,6 +474,7 @@ class Database(
sqlalchemy_url,
self.impersonate_user,
effective_username,
+ access_token,
)
masked_url = self.get_password_masked_url(sqlalchemy_url)
@@ -588,7 +595,7 @@ class Database(
database=None,
)
_log_query(sql_)
- self.db_engine_spec.execute(cursor, sql_)
+ self.db_engine_spec.execute(cursor, sql_, self.id)
cursor.fetchall()
if mutate_after_split:
@@ -598,10 +605,10 @@ class Database(
database=None,
)
_log_query(last_sql)
- self.db_engine_spec.execute(cursor, last_sql)
+ self.db_engine_spec.execute(cursor, last_sql, self.id)
else:
_log_query(sqls[-1])
- self.db_engine_spec.execute(cursor, sqls[-1])
+ self.db_engine_spec.execute(cursor, sqls[-1], self.id)
data = self.db_engine_spec.fetch_data(cursor)
result_set = SupersetResultSet(
@@ -978,6 +985,26 @@ sqla.event.listen(Database, "after_update", security_manager.database_after_upda
sqla.event.listen(Database, "after_delete", security_manager.database_after_delete)
+class DatabaseUserOAuth2Tokens(Model, AuditMixinNullable):
+    """
+    Store OAuth2 tokens, for authenticating to DBs using user personal tokens.
+    At most one row per (user, database) pair; token lookups use ``one_or_none()``.
+    """
+
+    __tablename__ = "database_user_oauth2_tokens"
+    __table_args__ = (sqla.UniqueConstraint("user_id", "database_id"),)
+
+    id = Column(Integer, primary_key=True)
+    user_id = Column(Integer, ForeignKey("ab_user.id"), nullable=False)
+    user = relationship(security_manager.user_model, foreign_keys=[user_id])
+    database_id = Column(Integer, ForeignKey("dbs.id"), nullable=False)
+    database = relationship("Database", foreign_keys=[database_id])
+
+    access_token = Column(encrypted_field_factory.create(Text), nullable=True)
+    access_token_expiration = Column(DateTime, nullable=True)
+    refresh_token = Column(encrypted_field_factory.create(Text), nullable=True)
+
+
class Log(Model): # pylint: disable=too-few-public-methods
"""ORM object used to log Superset actions to the database"""
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
new file mode 100644
index 0000000000..b004e4e02e
--- /dev/null
+++ b/superset/utils/oauth2.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime, timedelta
+from typing import Optional, Type
+
+from superset import db
+from superset.db_engine_specs.base import BaseEngineSpec
+
+
+def get_oauth2_access_token(
+    database_id: int,
+    user_id: int,
+    db_engine_spec: Type[BaseEngineSpec],
+) -> Optional[str]:
+    """
+    Return a valid OAuth2 access token for the user/database pair, or ``None``.
+    An unexpired stored token is returned as is. Otherwise, if a refresh token is
+    available, a fresh access token is requested and written back to the session
+    (not committed here); without a refresh token the useless entry is deleted.
+    """
+    # pylint: disable=import-outside-toplevel
+    from superset.models.core import DatabaseUserOAuth2Tokens
+
+    token = (
+        db.session.query(DatabaseUserOAuth2Tokens)
+        .filter_by(user_id=user_id, database_id=database_id)
+        .one_or_none()
+    )
+    if token is None:
+        return None
+
+    # the expiration column is nullable; treat a missing value as expired
+    expiration = token.access_token_expiration
+    if token.access_token and expiration is not None and expiration > datetime.now():
+        return token.access_token
+
+    if token.refresh_token:
+        token_response = db_engine_spec.get_oauth2_fresh_token(token.refresh_token)
+        # the refresh token might be revoked, in which case there is no access token
+        # in the response; never hand back the stale, expired token in that case
+        if "access_token" not in token_response:
+            return None
+        token.access_token = token_response["access_token"]
+        token.access_token_expiration = datetime.now() + timedelta(
+            seconds=token_response["expires_in"]
+        )
+        db.session.add(token)
+        return token.access_token
+
+    # the access token is expired (or missing) and there's no refresh token
+    db.session.delete(token)
+    return None