You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@superset.apache.org by mi...@apache.org on 2022/09/06 17:14:49 UTC

[superset] 01/04: fix(celery cache warmup): add auth and use warm_up_cache endpoint (#21076)

This is an automated email from the ASF dual-hosted git repository.

michaelsmolina pushed a commit to branch 1.5
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 9861c2b9e7fc5bf571c07511b4e7aa702808d7d4
Author: ʈᵃᵢ <td...@gmail.com>
AuthorDate: Tue Aug 30 09:24:24 2022 -0700

    fix(celery cache warmup): add auth and use warm_up_cache endpoint (#21076)
    
    (cherry picked from commit 04dd8d414db6a3cddcd073ad74acb2a4b7a53b0b)
---
 docker/pythonpath_dev/superset_config.py        |  10 ++
 superset/tasks/cache.py                         |  98 ++++++++--------
 tests/integration_tests/strategy_tests.py       | 141 +++---------------------
 tests/integration_tests/superset_test_config.py |   2 +
 4 files changed, 70 insertions(+), 181 deletions(-)

diff --git a/docker/pythonpath_dev/superset_config.py b/docker/pythonpath_dev/superset_config.py
index 6c58bec79c..1c78baf59f 100644
--- a/docker/pythonpath_dev/superset_config.py
+++ b/docker/pythonpath_dev/superset_config.py
@@ -69,6 +69,16 @@ REDIS_RESULTS_DB = get_env_variable("REDIS_RESULTS_DB", "1")
 
 RESULTS_BACKEND = FileSystemCache("/app/superset_home/sqllab")
 
+CACHE_CONFIG = {
+    "CACHE_TYPE": "redis",
+    "CACHE_DEFAULT_TIMEOUT": 300,
+    "CACHE_KEY_PREFIX": "superset_",
+    "CACHE_REDIS_HOST": REDIS_HOST,
+    "CACHE_REDIS_PORT": REDIS_PORT,
+    "CACHE_REDIS_DB": REDIS_RESULTS_DB,
+}
+DATA_CACHE_CONFIG = CACHE_CONFIG
+
 
 class CeleryConfig(object):
     BROKER_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}"
diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py
index ee73df5fde..137ec068e8 100644
--- a/superset/tasks/cache.py
+++ b/superset/tasks/cache.py
@@ -14,73 +14,36 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import json
 import logging
 from typing import Any, Dict, List, Optional, Union
 from urllib import request
 from urllib.error import URLError
 
+from celery.beat import SchedulingError
 from celery.utils.log import get_task_logger
 from sqlalchemy import and_, func
 
-from superset import app, db
+from superset import app, db, security_manager
 from superset.extensions import celery_app
 from superset.models.core import Log
 from superset.models.dashboard import Dashboard
 from superset.models.slice import Slice
 from superset.models.tags import Tag, TaggedObject
 from superset.utils.date_parser import parse_human_datetime
-from superset.views.utils import build_extra_filters
+from superset.utils.machine_auth import MachineAuthProvider
 
 logger = get_task_logger(__name__)
 logger.setLevel(logging.INFO)
 
 
-def get_form_data(
-    chart_id: int, dashboard: Optional[Dashboard] = None
-) -> Dict[str, Any]:
-    """
-    Build `form_data` for chart GET request from dashboard's `default_filters`.
-
-    When a dashboard has `default_filters` they need to be added  as extra
-    filters in the GET request for charts.
-
-    """
-    form_data: Dict[str, Any] = {"slice_id": chart_id}
-
-    if dashboard is None or not dashboard.json_metadata:
-        return form_data
-
-    json_metadata = json.loads(dashboard.json_metadata)
-    default_filters = json.loads(json_metadata.get("default_filters", "null"))
-    if not default_filters:
-        return form_data
-
-    filter_scopes = json_metadata.get("filter_scopes", {})
-    layout = json.loads(dashboard.position_json or "{}")
-    if (
-        isinstance(layout, dict)
-        and isinstance(filter_scopes, dict)
-        and isinstance(default_filters, dict)
-    ):
-        extra_filters = build_extra_filters(
-            layout, filter_scopes, default_filters, chart_id
-        )
-        if extra_filters:
-            form_data["extra_filters"] = extra_filters
-
-    return form_data
-
-
-def get_url(chart: Slice, extra_filters: Optional[Dict[str, Any]] = None) -> str:
+def get_url(chart: Slice, dashboard: Optional[Dashboard] = None) -> str:
     """Return external URL for warming up a given chart/table cache."""
     with app.test_request_context():
-        baseurl = (
-            "{SUPERSET_WEBSERVER_PROTOCOL}://"
-            "{SUPERSET_WEBSERVER_ADDRESS}:"
-            "{SUPERSET_WEBSERVER_PORT}".format(**app.config)
-        )
-        return f"{baseurl}{chart.get_explore_url(overrides=extra_filters)}"
+        baseurl = "{WEBDRIVER_BASEURL}".format(**app.config)
+        url = f"{baseurl}superset/warm_up_cache/?slice_id={chart.id}"
+        if dashboard:
+            url += f"&dashboard_id={dashboard.id}"
+        return url
 
 
 class Strategy:  # pylint: disable=too-few-public-methods
@@ -179,8 +142,7 @@ class TopNDashboardsStrategy(Strategy):  # pylint: disable=too-few-public-method
         dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
         for dashboard in dashboards:
             for chart in dashboard.slices:
-                form_data_with_filters = get_form_data(chart.id, dashboard)
-                urls.append(get_url(chart, form_data_with_filters))
+                urls.append(get_url(chart, dashboard))
 
         return urls
 
@@ -253,6 +215,30 @@ class DashboardTagsStrategy(Strategy):  # pylint: disable=too-few-public-methods
 strategies = [DummyStrategy, TopNDashboardsStrategy, DashboardTagsStrategy]
 
 
+@celery_app.task(name="fetch_url")
+def fetch_url(url: str, headers: Dict[str, str]) -> Dict[str, str]:
+    """
+    Celery job to fetch url
+    """
+    result = {}
+    try:
+        logger.info("Fetching %s", url)
+        req = request.Request(url, headers=headers)
+        response = request.urlopen(  # pylint: disable=consider-using-with
+            req, timeout=600
+        )
+        logger.info("Fetched %s, status code: %s", url, response.code)
+        if response.code == 200:
+            result = {"success": url, "response": response.read().decode("utf-8")}
+        else:
+            result = {"error": url, "status_code": response.code}
+            logger.error("Error fetching %s, status code: %s", url, response.code)
+    except URLError as err:
+        logger.exception("Error warming up cache!")
+        result = {"error": url, "exception": str(err)}
+    return result
+
+
 @celery_app.task(name="cache-warmup")
 def cache_warmup(
     strategy_name: str, *args: Any, **kwargs: Any
@@ -282,14 +268,18 @@ def cache_warmup(
         logger.exception(message)
         return message
 
-    results: Dict[str, List[str]] = {"success": [], "errors": []}
+    user = security_manager.get_user_by_username(app.config["THUMBNAIL_SELENIUM_USER"])
+    cookies = MachineAuthProvider.get_auth_cookies(user)
+    headers = {"Cookie": f"session={cookies.get('session', '')}"}
+
+    results: Dict[str, List[str]] = {"scheduled": [], "errors": []}
     for url in strategy.get_urls():
         try:
-            logger.info("Fetching %s", url)
-            request.urlopen(url)  # pylint: disable=consider-using-with
-            results["success"].append(url)
-        except URLError:
-            logger.exception("Error warming up cache!")
+            logger.info("Scheduling %s", url)
+            fetch_url.delay(url, headers)
+            results["scheduled"].append(url)
+        except SchedulingError:
+            logger.exception("Error scheduling fetch_url: %s", url)
             results["errors"].append(url)
 
     return results
diff --git a/tests/integration_tests/strategy_tests.py b/tests/integration_tests/strategy_tests.py
index aec73b1efe..f31489bb04 100644
--- a/tests/integration_tests/strategy_tests.py
+++ b/tests/integration_tests/strategy_tests.py
@@ -38,9 +38,9 @@ from superset.models.core import Log
 from superset.models.tags import get_tag, ObjectTypes, TaggedObject, TagTypes
 from superset.tasks.cache import (
     DashboardTagsStrategy,
-    get_form_data,
     TopNDashboardsStrategy,
 )
+from superset.utils.urls import get_url_host
 
 from .base_tests import SupersetTestCase
 from .dashboard_utils import create_dashboard, create_slice, create_table_metadata
@@ -49,7 +49,6 @@ from .fixtures.unicode_dashboard import (
     load_unicode_data,
 )
 
-URL_PREFIX = "http://0.0.0.0:8081"
 
 mock_positions = {
     "DASHBOARD_VERSION_KEY": "v2",
@@ -69,128 +68,6 @@ mock_positions = {
 
 
 class TestCacheWarmUp(SupersetTestCase):
-    def test_get_form_data_chart_only(self):
-        chart_id = 1
-        result = get_form_data(chart_id, None)
-        expected = {"slice_id": chart_id}
-        self.assertEqual(result, expected)
-
-    def test_get_form_data_no_dashboard_metadata(self):
-        chart_id = 1
-        dashboard = MagicMock()
-        dashboard.json_metadata = None
-        dashboard.position_json = json.dumps(mock_positions)
-        result = get_form_data(chart_id, dashboard)
-        expected = {"slice_id": chart_id}
-        self.assertEqual(result, expected)
-
-    def test_get_form_data_immune_slice(self):
-        chart_id = 1
-        filter_box_id = 2
-        dashboard = MagicMock()
-        dashboard.position_json = json.dumps(mock_positions)
-        dashboard.json_metadata = json.dumps(
-            {
-                "filter_scopes": {
-                    str(filter_box_id): {
-                        "name": {"scope": ["ROOT_ID"], "immune": [chart_id]}
-                    }
-                },
-                "default_filters": json.dumps(
-                    {str(filter_box_id): {"name": ["Alice", "Bob"]}}
-                ),
-            }
-        )
-        result = get_form_data(chart_id, dashboard)
-        expected = {"slice_id": chart_id}
-        self.assertEqual(result, expected)
-
-    def test_get_form_data_no_default_filters(self):
-        chart_id = 1
-        dashboard = MagicMock()
-        dashboard.json_metadata = json.dumps({})
-        dashboard.position_json = json.dumps(mock_positions)
-        result = get_form_data(chart_id, dashboard)
-        expected = {"slice_id": chart_id}
-        self.assertEqual(result, expected)
-
-    def test_get_form_data_immune_fields(self):
-        chart_id = 1
-        filter_box_id = 2
-        dashboard = MagicMock()
-        dashboard.position_json = json.dumps(mock_positions)
-        dashboard.json_metadata = json.dumps(
-            {
-                "default_filters": json.dumps(
-                    {
-                        str(filter_box_id): {
-                            "name": ["Alice", "Bob"],
-                            "__time_range": "100 years ago : today",
-                        }
-                    }
-                ),
-                "filter_scopes": {
-                    str(filter_box_id): {
-                        "__time_range": {"scope": ["ROOT_ID"], "immune": [chart_id]}
-                    }
-                },
-            }
-        )
-        result = get_form_data(chart_id, dashboard)
-        expected = {
-            "slice_id": chart_id,
-            "extra_filters": [{"col": "name", "op": "in", "val": ["Alice", "Bob"]}],
-        }
-        self.assertEqual(result, expected)
-
-    def test_get_form_data_no_extra_filters(self):
-        chart_id = 1
-        filter_box_id = 2
-        dashboard = MagicMock()
-        dashboard.position_json = json.dumps(mock_positions)
-        dashboard.json_metadata = json.dumps(
-            {
-                "default_filters": json.dumps(
-                    {str(filter_box_id): {"__time_range": "100 years ago : today"}}
-                ),
-                "filter_scopes": {
-                    str(filter_box_id): {
-                        "__time_range": {"scope": ["ROOT_ID"], "immune": [chart_id]}
-                    }
-                },
-            }
-        )
-        result = get_form_data(chart_id, dashboard)
-        expected = {"slice_id": chart_id}
-        self.assertEqual(result, expected)
-
-    def test_get_form_data(self):
-        chart_id = 1
-        filter_box_id = 2
-        dashboard = MagicMock()
-        dashboard.position_json = json.dumps(mock_positions)
-        dashboard.json_metadata = json.dumps(
-            {
-                "default_filters": json.dumps(
-                    {
-                        str(filter_box_id): {
-                            "name": ["Alice", "Bob"],
-                            "__time_range": "100 years ago : today",
-                        }
-                    }
-                )
-            }
-        )
-        result = get_form_data(chart_id, dashboard)
-        expected = {
-            "slice_id": chart_id,
-            "extra_filters": [
-                {"col": "name", "op": "in", "val": ["Alice", "Bob"]},
-                {"col": "__time_range", "op": "==", "val": "100 years ago : today"},
-            ],
-        }
-        self.assertEqual(result, expected)
-
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_top_n_dashboards_strategy(self):
         # create a top visited dashboard
@@ -202,7 +79,12 @@ class TestCacheWarmUp(SupersetTestCase):
 
         strategy = TopNDashboardsStrategy(1)
         result = sorted(strategy.get_urls())
-        expected = sorted([f"{URL_PREFIX}{slc.url}" for slc in dash.slices])
+        expected = sorted(
+            [
+                f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}&dashboard_id={dash.id}"
+                for slc in dash.slices
+            ]
+        )
         self.assertEqual(result, expected)
 
     def reset_tag(self, tag):
@@ -228,7 +110,12 @@ class TestCacheWarmUp(SupersetTestCase):
         # tag dashboard 'births' with `tag1`
         tag1 = get_tag("tag1", db.session, TagTypes.custom)
         dash = self.get_dash_by_slug("births")
-        tag1_urls = sorted([f"{URL_PREFIX}{slc.url}" for slc in dash.slices])
+        tag1_urls = sorted(
+            [
+                f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}"
+                for slc in dash.slices
+            ]
+        )
         tagged_object = TaggedObject(
             tag_id=tag1.id, object_id=dash.id, object_type=ObjectTypes.dashboard
         )
@@ -248,7 +135,7 @@ class TestCacheWarmUp(SupersetTestCase):
         # tag first slice
         dash = self.get_dash_by_slug("unicode-test")
         slc = dash.slices[0]
-        tag2_urls = [f"{URL_PREFIX}{slc.url}"]
+        tag2_urls = [f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}"]
         object_id = slc.id
         tagged_object = TaggedObject(
             tag_id=tag2.id, object_id=object_id, object_type=ObjectTypes.chart
diff --git a/tests/integration_tests/superset_test_config.py b/tests/integration_tests/superset_test_config.py
index 7c86232829..c3e80cb07a 100644
--- a/tests/integration_tests/superset_test_config.py
+++ b/tests/integration_tests/superset_test_config.py
@@ -66,6 +66,8 @@ FEATURE_FLAGS = {
     "DASHBOARD_NATIVE_FILTERS": True,
 }
 
+WEBDRIVER_BASEURL = "http://0.0.0.0:8081/"
+
 
 def GET_FEATURE_FLAGS_FUNC(ff):
     ff_copy = copy(ff)