You are viewing a plain-text version of this content. The canonical link for it ("here") was a hyperlink that did not survive conversion to plain text.
Posted to commits@superset.apache.org by mi...@apache.org on 2024/02/13 16:02:14 UTC
(superset) 01/16: feat(embedded+async queries): support async queries to work with embedded guest user (#26332)
This is an automated email from the ASF dual-hosted git repository.
michaelsmolina pushed a commit to branch 3.1
in repository https://gitbox.apache.org/repos/asf/superset.git
commit dac73fe0cd0838576d08d4b6dbb4aaadc6f257a3
Author: Zef Lin <ze...@preset.io>
AuthorDate: Mon Jan 8 17:11:45 2024 -0800
feat(embedded+async queries): support async queries to work with embedded guest user (#26332)
(cherry picked from commit efdeb9df0550458363e1c84850770012f501c9fb)
---
superset/async_events/async_query_manager.py | 26 ++++++-
superset/common/query_context_processor.py | 10 ++-
superset/tasks/async_queries.py | 37 ++++++----
tests/integration_tests/query_context_tests.py | 1 +
.../async_events/async_query_manager_tests.py | 79 +++++++++++++++++++++-
5 files changed, 134 insertions(+), 19 deletions(-)
diff --git a/superset/async_events/async_query_manager.py b/superset/async_events/async_query_manager.py
index 94941541fb..32cf247cf3 100644
--- a/superset/async_events/async_query_manager.py
+++ b/superset/async_events/async_query_manager.py
@@ -191,9 +191,14 @@ class AsyncQueryManager:
force: Optional[bool] = False,
user_id: Optional[int] = None,
) -> dict[str, Any]:
+ # pylint: disable=import-outside-toplevel
+ from superset import security_manager
+
job_metadata = self.init_job(channel_id, user_id)
self._load_explore_json_into_cache_job.delay(
- job_metadata,
+ {**job_metadata, "guest_token": guest_user.guest_token}
+ if (guest_user := security_manager.get_current_guest_user_if_guest())
+ else job_metadata,
form_data,
response_type,
force,
@@ -201,10 +206,25 @@ class AsyncQueryManager:
return job_metadata
def submit_chart_data_job(
- self, channel_id: str, form_data: dict[str, Any], user_id: Optional[int]
+ self,
+ channel_id: str,
+ form_data: dict[str, Any],
+ user_id: Optional[int] = None,
) -> dict[str, Any]:
+ # pylint: disable=import-outside-toplevel
+ from superset import security_manager
+
+ # if it's guest user, we want to pass the guest token to the celery task
+ # chart data cache key is calculated based on the current user
+ # this way we can keep the cache key consistent between sync and async command
+ # so that it can be looked up consistently
job_metadata = self.init_job(channel_id, user_id)
- self._load_chart_data_into_cache_job.delay(job_metadata, form_data)
+ self._load_chart_data_into_cache_job.delay(
+ {**job_metadata, "guest_token": guest_user.guest_token}
+ if (guest_user := security_manager.get_current_guest_user_if_guest())
+ else job_metadata,
+ form_data,
+ )
return job_metadata
def read_events(
diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py
index 5b1414d53b..d8b5bea4bb 100644
--- a/superset/common/query_context_processor.py
+++ b/superset/common/query_context_processor.py
@@ -600,7 +600,15 @@ class QueryContextProcessor:
set_and_log_cache(
cache_manager.cache,
cache_key,
- {"data": self._query_context.cache_values},
+ {
+ "data": {
+ # setting form_data into query context cache value as well
+ # so that it can be used to reconstruct form_data field
+ # for query context object when reading from cache
+ "form_data": self._query_context.form_data,
+ **self._query_context.cache_values,
+ },
+ },
self.get_cache_timeout(),
)
return_value["cache_key"] = cache_key # type: ignore
diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py
index 61970ca1f3..b804847cd8 100644
--- a/superset/tasks/async_queries.py
+++ b/superset/tasks/async_queries.py
@@ -22,6 +22,7 @@ from typing import Any, cast, TYPE_CHECKING
from celery.exceptions import SoftTimeLimitExceeded
from flask import current_app, g
+from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError
from superset.charts.schemas import ChartDataQueryContextSchema
@@ -58,6 +59,20 @@ def _create_query_context_from_form(form_data: dict[str, Any]) -> QueryContext:
raise error
+def _load_user_from_job_metadata(job_metadata: dict[str, Any]) -> User:
+ if user_id := job_metadata.get("user_id"):
+ # logged in user
+ user = security_manager.get_user_by_id(user_id)
+ elif guest_token := job_metadata.get("guest_token"):
+ # embedded guest user
+ user = security_manager.get_guest_user_from_token(guest_token)
+ del job_metadata["guest_token"]
+ else:
+ # default to anonymous user if no user is found
+ user = security_manager.get_anonymous_user()
+ return user
+
+
@celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
def load_chart_data_into_cache(
job_metadata: dict[str, Any],
@@ -66,12 +81,7 @@ def load_chart_data_into_cache(
# pylint: disable=import-outside-toplevel
from superset.commands.chart.data.get_data_command import ChartDataCommand
- user = (
- security_manager.get_user_by_id(job_metadata.get("user_id"))
- or security_manager.get_anonymous_user()
- )
-
- with override_user(user, force=False):
+ with override_user(_load_user_from_job_metadata(job_metadata), force=False):
try:
set_form_data(form_data)
query_context = _create_query_context_from_form(form_data)
@@ -106,12 +116,7 @@ def load_explore_json_into_cache( # pylint: disable=too-many-locals
) -> None:
cache_key_prefix = "ejr-" # ejr: explore_json request
- user = (
- security_manager.get_user_by_id(job_metadata.get("user_id"))
- or security_manager.get_anonymous_user()
- )
-
- with override_user(user, force=False):
+ with override_user(_load_user_from_job_metadata(job_metadata), force=False):
try:
set_form_data(form_data)
datasource_id, datasource_type = get_datasource_info(None, None, form_data)
@@ -140,7 +145,13 @@ def load_explore_json_into_cache( # pylint: disable=too-many-locals
"response_type": response_type,
}
cache_key = generate_cache_key(cache_value, cache_key_prefix)
- set_and_log_cache(cache_manager.cache, cache_key, cache_value)
+ cache_instance = cache_manager.cache
+ cache_timeout = (
+ cache_instance.cache.default_timeout if cache_instance.cache else None
+ )
+ set_and_log_cache(
+ cache_instance, cache_key, cache_value, cache_timeout=cache_timeout
+ )
result_url = f"/superset/explore_json/data/{cache_key}"
async_query_manager.update_job(
job_metadata,
diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py
index 8c2082d1c4..30cd160d7e 100644
--- a/tests/integration_tests/query_context_tests.py
+++ b/tests/integration_tests/query_context_tests.py
@@ -121,6 +121,7 @@ class TestQueryContext(SupersetTestCase):
cached = cache_manager.cache.get(cache_key)
assert cached is not None
+ assert "form_data" in cached["data"]
rehydrated_qc = ChartDataQueryContextSchema().load(cached["data"])
rehydrated_qo = rehydrated_qc.queries[0]
diff --git a/tests/unit_tests/async_events/async_query_manager_tests.py b/tests/unit_tests/async_events/async_query_manager_tests.py
index b4ae06dfc3..85ea114201 100644
--- a/tests/unit_tests/async_events/async_query_manager_tests.py
+++ b/tests/unit_tests/async_events/async_query_manager_tests.py
@@ -14,12 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from unittest import mock
+from unittest.mock import ANY, Mock
-from unittest.mock import Mock
-
+from flask import g
from jwt import encode
from pytest import fixture, raises
+from superset import security_manager
from superset.async_events.async_query_manager import (
AsyncQueryManager,
AsyncQueryTokenException,
@@ -38,6 +40,12 @@ def async_query_manager():
return query_manager
+def set_current_as_guest_user():
+ g.user = security_manager.get_guest_user_from_token(
+ {"user": {}, "resources": [{"type": "dashboard", "id": "some-uuid"}]}
+ )
+
+
def test_parse_channel_id_from_request(async_query_manager):
encoded_token = encode(
{"channel": "test_channel_id"}, JWT_TOKEN_SECRET, algorithm="HS256"
@@ -65,3 +73,70 @@ def test_parse_channel_id_from_request_bad_jwt(async_query_manager):
with raises(AsyncQueryTokenException):
async_query_manager.parse_channel_id_from_request(request)
+
+
+@mock.patch("superset.is_feature_enabled")
+def test_submit_chart_data_job_as_guest_user(
+ is_feature_enabled_mock, async_query_manager
+):
+ is_feature_enabled_mock.return_value = True
+ set_current_as_guest_user()
+ job_mock = Mock()
+ async_query_manager._load_chart_data_into_cache_job = job_mock
+ job_meta = async_query_manager.submit_chart_data_job(
+ channel_id="test_channel_id",
+ form_data={},
+ )
+
+ job_mock.delay.assert_called_once_with(
+ {
+ "channel_id": "test_channel_id",
+ "errors": [],
+ "guest_token": {
+ "resources": [{"id": "some-uuid", "type": "dashboard"}],
+ "user": {},
+ },
+ "job_id": ANY,
+ "result_url": None,
+ "status": "pending",
+ "user_id": None,
+ },
+ {},
+ )
+
+ assert "guest_token" not in job_meta
+
+
+@mock.patch("superset.is_feature_enabled")
+def test_submit_explore_json_job_as_guest_user(
+ is_feature_enabled_mock, async_query_manager
+):
+ is_feature_enabled_mock.return_value = True
+ set_current_as_guest_user()
+ job_mock = Mock()
+ async_query_manager._load_explore_json_into_cache_job = job_mock
+ job_meta = async_query_manager.submit_explore_json_job(
+ channel_id="test_channel_id",
+ form_data={},
+ response_type="json",
+ )
+
+ job_mock.delay.assert_called_once_with(
+ {
+ "channel_id": "test_channel_id",
+ "errors": [],
+ "guest_token": {
+ "resources": [{"id": "some-uuid", "type": "dashboard"}],
+ "user": {},
+ },
+ "job_id": ANY,
+ "result_url": None,
+ "status": "pending",
+ "user_id": None,
+ },
+ {},
+ "json",
+ False,
+ )
+
+ assert "guest_token" not in job_meta