You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by mi...@apache.org on 2024/02/13 16:02:14 UTC

(superset) 01/16: feat(embedded+async queries): support async queries to work with embedded guest user (#26332)

This is an automated email from the ASF dual-hosted git repository.

michaelsmolina pushed a commit to branch 3.1
in repository https://gitbox.apache.org/repos/asf/superset.git

commit dac73fe0cd0838576d08d4b6dbb4aaadc6f257a3
Author: Zef Lin <ze...@preset.io>
AuthorDate: Mon Jan 8 17:11:45 2024 -0800

    feat(embedded+async queries): support async queries to work with embedded guest user (#26332)
    
    (cherry picked from commit efdeb9df0550458363e1c84850770012f501c9fb)
---
 superset/async_events/async_query_manager.py       | 26 ++++++-
 superset/common/query_context_processor.py         | 10 ++-
 superset/tasks/async_queries.py                    | 37 ++++++----
 tests/integration_tests/query_context_tests.py     |  1 +
 .../async_events/async_query_manager_tests.py      | 79 +++++++++++++++++++++-
 5 files changed, 134 insertions(+), 19 deletions(-)

diff --git a/superset/async_events/async_query_manager.py b/superset/async_events/async_query_manager.py
index 94941541fb..32cf247cf3 100644
--- a/superset/async_events/async_query_manager.py
+++ b/superset/async_events/async_query_manager.py
@@ -191,9 +191,14 @@ class AsyncQueryManager:
         force: Optional[bool] = False,
         user_id: Optional[int] = None,
     ) -> dict[str, Any]:
+        # pylint: disable=import-outside-toplevel
+        from superset import security_manager
+
         job_metadata = self.init_job(channel_id, user_id)
         self._load_explore_json_into_cache_job.delay(
-            job_metadata,
+            {**job_metadata, "guest_token": guest_user.guest_token}
+            if (guest_user := security_manager.get_current_guest_user_if_guest())
+            else job_metadata,
             form_data,
             response_type,
             force,
@@ -201,10 +206,25 @@ class AsyncQueryManager:
         return job_metadata
 
     def submit_chart_data_job(
-        self, channel_id: str, form_data: dict[str, Any], user_id: Optional[int]
+        self,
+        channel_id: str,
+        form_data: dict[str, Any],
+        user_id: Optional[int] = None,
     ) -> dict[str, Any]:
+        # pylint: disable=import-outside-toplevel
+        from superset import security_manager
+
+        # For an embedded guest user, forward the guest token to the Celery task.
+        # The chart data cache key is calculated based on the current user, so the
+        # task must run as the same guest for the sync and async commands to
+        # produce the same key and find the same cache entry.
         job_metadata = self.init_job(channel_id, user_id)
-        self._load_chart_data_into_cache_job.delay(job_metadata, form_data)
+        self._load_chart_data_into_cache_job.delay(
+            {**job_metadata, "guest_token": guest_user.guest_token}
+            if (guest_user := security_manager.get_current_guest_user_if_guest())
+            else job_metadata,
+            form_data,
+        )
         return job_metadata
 
     def read_events(
diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py
index 5b1414d53b..d8b5bea4bb 100644
--- a/superset/common/query_context_processor.py
+++ b/superset/common/query_context_processor.py
@@ -600,7 +600,15 @@ class QueryContextProcessor:
             set_and_log_cache(
                 cache_manager.cache,
                 cache_key,
-                {"data": self._query_context.cache_values},
+                {
+                    "data": {
+                        # setting form_data into query context cache value as well
+                        # so that it can be used to reconstruct form_data field
+                        # for query context object when reading from cache
+                        "form_data": self._query_context.form_data,
+                        **self._query_context.cache_values,
+                    },
+                },
                 self.get_cache_timeout(),
             )
             return_value["cache_key"] = cache_key  # type: ignore
diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py
index 61970ca1f3..b804847cd8 100644
--- a/superset/tasks/async_queries.py
+++ b/superset/tasks/async_queries.py
@@ -22,6 +22,7 @@ from typing import Any, cast, TYPE_CHECKING
 
 from celery.exceptions import SoftTimeLimitExceeded
 from flask import current_app, g
+from flask_appbuilder.security.sqla.models import User
 from marshmallow import ValidationError
 
 from superset.charts.schemas import ChartDataQueryContextSchema
@@ -58,6 +59,20 @@ def _create_query_context_from_form(form_data: dict[str, Any]) -> QueryContext:
         raise error
 
 
+def _load_user_from_job_metadata(job_metadata: dict[str, Any]) -> User:
+    if user_id := job_metadata.get("user_id"):  # a user id wins over a guest token
+        # logged-in user; NOTE(review): no anonymous fallback if the lookup is falsy
+        user = security_manager.get_user_by_id(user_id)
+    elif guest_token := job_metadata.get("guest_token"):
+        # embedded guest user: obtain the user object from the forwarded token
+        user = security_manager.get_guest_user_from_token(guest_token)
+        del job_metadata["guest_token"]  # removes the key from the caller's dict
+    else:
+        # neither a user id nor a guest token: run as the anonymous user
+        user = security_manager.get_anonymous_user()
+    return user
+
+
 @celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
 def load_chart_data_into_cache(
     job_metadata: dict[str, Any],
@@ -66,12 +81,7 @@ def load_chart_data_into_cache(
     # pylint: disable=import-outside-toplevel
     from superset.commands.chart.data.get_data_command import ChartDataCommand
 
-    user = (
-        security_manager.get_user_by_id(job_metadata.get("user_id"))
-        or security_manager.get_anonymous_user()
-    )
-
-    with override_user(user, force=False):
+    with override_user(_load_user_from_job_metadata(job_metadata), force=False):
         try:
             set_form_data(form_data)
             query_context = _create_query_context_from_form(form_data)
@@ -106,12 +116,7 @@ def load_explore_json_into_cache(  # pylint: disable=too-many-locals
 ) -> None:
     cache_key_prefix = "ejr-"  # ejr: explore_json request
 
-    user = (
-        security_manager.get_user_by_id(job_metadata.get("user_id"))
-        or security_manager.get_anonymous_user()
-    )
-
-    with override_user(user, force=False):
+    with override_user(_load_user_from_job_metadata(job_metadata), force=False):
         try:
             set_form_data(form_data)
             datasource_id, datasource_type = get_datasource_info(None, None, form_data)
@@ -140,7 +145,13 @@ def load_explore_json_into_cache(  # pylint: disable=too-many-locals
                 "response_type": response_type,
             }
             cache_key = generate_cache_key(cache_value, cache_key_prefix)
-            set_and_log_cache(cache_manager.cache, cache_key, cache_value)
+            cache_instance = cache_manager.cache
+            cache_timeout = (
+                cache_instance.cache.default_timeout if cache_instance.cache else None
+            )
+            set_and_log_cache(
+                cache_instance, cache_key, cache_value, cache_timeout=cache_timeout
+            )
             result_url = f"/superset/explore_json/data/{cache_key}"
             async_query_manager.update_job(
                 job_metadata,
diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py
index 8c2082d1c4..30cd160d7e 100644
--- a/tests/integration_tests/query_context_tests.py
+++ b/tests/integration_tests/query_context_tests.py
@@ -121,6 +121,7 @@ class TestQueryContext(SupersetTestCase):
 
         cached = cache_manager.cache.get(cache_key)
         assert cached is not None
+        assert "form_data" in cached["data"]
 
         rehydrated_qc = ChartDataQueryContextSchema().load(cached["data"])
         rehydrated_qo = rehydrated_qc.queries[0]
diff --git a/tests/unit_tests/async_events/async_query_manager_tests.py b/tests/unit_tests/async_events/async_query_manager_tests.py
index b4ae06dfc3..85ea114201 100644
--- a/tests/unit_tests/async_events/async_query_manager_tests.py
+++ b/tests/unit_tests/async_events/async_query_manager_tests.py
@@ -14,12 +14,14 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from unittest import mock
+from unittest.mock import ANY, Mock
 
-from unittest.mock import Mock
-
+from flask import g
 from jwt import encode
 from pytest import fixture, raises
 
+from superset import security_manager
 from superset.async_events.async_query_manager import (
     AsyncQueryManager,
     AsyncQueryTokenException,
@@ -38,6 +40,12 @@ def async_query_manager():
     return query_manager
 
 
+def set_current_as_guest_user():  # test helper: put a guest user on flask's g.user
+    g.user = security_manager.get_guest_user_from_token(
+        {"user": {}, "resources": [{"type": "dashboard", "id": "some-uuid"}]}
+    )
+
+
 def test_parse_channel_id_from_request(async_query_manager):
     encoded_token = encode(
         {"channel": "test_channel_id"}, JWT_TOKEN_SECRET, algorithm="HS256"
@@ -65,3 +73,70 @@ def test_parse_channel_id_from_request_bad_jwt(async_query_manager):
 
     with raises(AsyncQueryTokenException):
         async_query_manager.parse_channel_id_from_request(request)
+
+
+@mock.patch("superset.is_feature_enabled")
+def test_submit_chart_data_job_as_guest_user(
+    is_feature_enabled_mock, async_query_manager
+):
+    is_feature_enabled_mock.return_value = True  # every feature flag reads as enabled
+    set_current_as_guest_user()
+    job_mock = Mock()  # replaces the task object whose .delay() the manager calls
+    async_query_manager._load_chart_data_into_cache_job = job_mock
+    job_meta = async_query_manager.submit_chart_data_job(
+        channel_id="test_channel_id",
+        form_data={},
+    )
+
+    job_mock.delay.assert_called_once_with(
+        {
+            "channel_id": "test_channel_id",
+            "errors": [],
+            "guest_token": {  # same payload set_current_as_guest_user() passes in
+                "resources": [{"id": "some-uuid", "type": "dashboard"}],
+                "user": {},
+            },
+            "job_id": ANY,  # value is not pinned by this test
+            "result_url": None,
+            "status": "pending",
+            "user_id": None,
+        },
+        {},
+    )
+
+    assert "guest_token" not in job_meta  # returned metadata must not carry the token
+
+
+@mock.patch("superset.is_feature_enabled")
+def test_submit_explore_json_job_as_guest_user(
+    is_feature_enabled_mock, async_query_manager
+):
+    is_feature_enabled_mock.return_value = True  # every feature flag reads as enabled
+    set_current_as_guest_user()
+    job_mock = Mock()  # replaces the task object whose .delay() the manager calls
+    async_query_manager._load_explore_json_into_cache_job = job_mock
+    job_meta = async_query_manager.submit_explore_json_job(
+        channel_id="test_channel_id",
+        form_data={},
+        response_type="json",
+    )
+
+    job_mock.delay.assert_called_once_with(
+        {
+            "channel_id": "test_channel_id",
+            "errors": [],
+            "guest_token": {  # same payload set_current_as_guest_user() passes in
+                "resources": [{"id": "some-uuid", "type": "dashboard"}],
+                "user": {},
+            },
+            "job_id": ANY,  # value is not pinned by this test
+            "result_url": None,
+            "status": "pending",
+            "user_id": None,
+        },
+        {},
+        "json",
+        False,  # `force` was not passed, so its declared default is forwarded
+    )
+
+    assert "guest_token" not in job_meta  # returned metadata must not carry the token