You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@superset.apache.org by ro...@apache.org on 2021/03/31 18:23:30 UTC

[superset] branch master updated: fix(#13378): Ensure g.user is set for impersonation (#13878)

This is an automated email from the ASF dual-hosted git repository.

robdiciuccio pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new ca506e9  fix(#13378): Ensure g.user is set for impersonation (#13878)
ca506e9 is described below

commit ca506e939617f3cd1e3faaa2ed3ea9ab8bd3218e
Author: Ben Reinhart <be...@gmail.com>
AuthorDate: Wed Mar 31 11:22:56 2021 -0700

    fix(#13378): Ensure g.user is set for impersonation (#13878)
---
 superset/tasks/async_queries.py    | 17 ++++++++++--
 tests/tasks/async_queries_tests.py | 53 ++++++++++++++++++++++++++++----------
 2 files changed, 54 insertions(+), 16 deletions(-)

diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py
index b8db82b..f5f3c14 100644
--- a/superset/tasks/async_queries.py
+++ b/superset/tasks/async_queries.py
@@ -18,11 +18,16 @@
 import logging
 from typing import Any, cast, Dict, Optional
 
-from flask import current_app
+from flask import current_app, g
 
 from superset import app
 from superset.exceptions import SupersetVizException
-from superset.extensions import async_query_manager, cache_manager, celery_app
+from superset.extensions import (
+    async_query_manager,
+    cache_manager,
+    celery_app,
+    security_manager,
+)
 from superset.utils.cache import generate_cache_key, set_and_log_cache
 from superset.views.utils import get_datasource_info, get_viz
 
@@ -32,6 +37,12 @@ query_timeout = current_app.config[
 ]  # TODO: new config key
 
 
+def ensure_user_is_set(user_id: Optional[int]) -> None:
+    user_is_set = hasattr(g, "user") and g.user is not None
+    if not user_is_set and user_id is not None:
+        g.user = security_manager.get_user_by_id(user_id)
+
+
 @celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
 def load_chart_data_into_cache(
     job_metadata: Dict[str, Any], form_data: Dict[str, Any],
@@ -42,6 +53,7 @@ def load_chart_data_into_cache(
 
     with app.app_context():  # type: ignore
         try:
+            ensure_user_is_set(job_metadata.get("user_id"))
             command = ChartDataCommand()
             command.set_query_context(form_data)
             result = command.run(cache=True)
@@ -72,6 +84,7 @@ def load_explore_json_into_cache(
     with app.app_context():  # type: ignore
         cache_key_prefix = "ejr-"  # ejr: explore_json request
         try:
+            ensure_user_is_set(job_metadata.get("user_id"))
             datasource_id, datasource_type = get_datasource_info(None, None, form_data)
 
             viz_obj = get_viz(
diff --git a/tests/tasks/async_queries_tests.py b/tests/tasks/async_queries_tests.py
index 5a7b86a..cd4f0c0 100644
--- a/tests/tasks/async_queries_tests.py
+++ b/tests/tasks/async_queries_tests.py
@@ -26,7 +26,8 @@ from superset.charts.commands.data import ChartDataCommand
 from superset.charts.commands.exceptions import ChartDataQueryFailedError
 from superset.connectors.sqla.models import SqlaTable
 from superset.exceptions import SupersetException
-from superset.extensions import async_query_manager
+from superset.extensions import async_query_manager, security_manager
+from superset.tasks import async_queries
 from superset.tasks.async_queries import (
     load_chart_data_into_cache,
     load_explore_json_into_cache,
@@ -48,17 +49,24 @@ class TestAsyncQueries(SupersetTestCase):
     def test_load_chart_data_into_cache(self, mock_update_job):
         async_query_manager.init_app(app)
         query_context = get_query_context("birth_names")
+        user = security_manager.find_user("gamma")
         job_metadata = {
             "channel_id": str(uuid4()),
             "job_id": str(uuid4()),
-            "user_id": 1,
+            "user_id": user.id,
             "status": "pending",
             "errors": [],
         }
 
-        load_chart_data_into_cache(job_metadata, query_context)
+        with mock.patch.object(
+            async_queries, "ensure_user_is_set"
+        ) as ensure_user_is_set:
+            load_chart_data_into_cache(job_metadata, query_context)
 
-        mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY)
+        ensure_user_is_set.assert_called_once_with(user.id)
+        mock_update_job.assert_called_once_with(
+            job_metadata, "done", result_url=mock.ANY
+        )
 
     @mock.patch.object(
         ChartDataCommand, "run", side_effect=ChartDataQueryFailedError("Error: foo")
@@ -67,25 +75,31 @@ class TestAsyncQueries(SupersetTestCase):
     def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command):
         async_query_manager.init_app(app)
         query_context = get_query_context("birth_names")
+        user = security_manager.find_user("gamma")
         job_metadata = {
             "channel_id": str(uuid4()),
             "job_id": str(uuid4()),
-            "user_id": 1,
+            "user_id": user.id,
             "status": "pending",
             "errors": [],
         }
         with pytest.raises(ChartDataQueryFailedError):
-            load_chart_data_into_cache(job_metadata, query_context)
+            with mock.patch.object(
+                async_queries, "ensure_user_is_set"
+            ) as ensure_user_is_set:
+                load_chart_data_into_cache(job_metadata, query_context)
+            ensure_user_is_set.assert_called_once_with(user.id)
 
-        mock_run_command.assert_called_with(cache=True)
+        mock_run_command.assert_called_once_with(cache=True)
         errors = [{"message": "Error: foo"}]
-        mock_update_job.assert_called_with(job_metadata, "error", errors=errors)
+        mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)
 
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @mock.patch.object(async_query_manager, "update_job")
     def test_load_explore_json_into_cache(self, mock_update_job):
         async_query_manager.init_app(app)
         table = get_table_by_name("birth_names")
+        user = security_manager.find_user("gamma")
         form_data = {
             "datasource": f"{table.id}__table",
             "viz_type": "dist_bar",
@@ -100,29 +114,40 @@ class TestAsyncQueries(SupersetTestCase):
         job_metadata = {
             "channel_id": str(uuid4()),
             "job_id": str(uuid4()),
-            "user_id": 1,
+            "user_id": user.id,
             "status": "pending",
             "errors": [],
         }
 
-        load_explore_json_into_cache(job_metadata, form_data)
+        with mock.patch.object(
+            async_queries, "ensure_user_is_set"
+        ) as ensure_user_is_set:
+            load_explore_json_into_cache(job_metadata, form_data)
 
-        mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY)
+        ensure_user_is_set.assert_called_once_with(user.id)
+        mock_update_job.assert_called_once_with(
+            job_metadata, "done", result_url=mock.ANY
+        )
 
     @mock.patch.object(async_query_manager, "update_job")
     def test_load_explore_json_into_cache_error(self, mock_update_job):
         async_query_manager.init_app(app)
+        user = security_manager.find_user("gamma")
         form_data = {}
         job_metadata = {
             "channel_id": str(uuid4()),
             "job_id": str(uuid4()),
-            "user_id": 1,
+            "user_id": user.id,
             "status": "pending",
             "errors": [],
         }
 
         with pytest.raises(SupersetException):
-            load_explore_json_into_cache(job_metadata, form_data)
+            with mock.patch.object(
+                async_queries, "ensure_user_is_set"
+            ) as ensure_user_is_set:
+                load_explore_json_into_cache(job_metadata, form_data)
+            ensure_user_is_set.assert_called_once_with(user.id)
 
         errors = ["The dataset associated with this chart no longer exists"]
-        mock_update_job.assert_called_with(job_metadata, "error", errors=errors)
+        mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)