You are viewing a plain text version of this content. The canonical link for it was a hyperlink ("here") that is not preserved in this plain-text version.
Posted to commits@superset.apache.org by dp...@apache.org on 2023/06/20 11:08:38 UTC

[superset] branch master updated: chore: Migrate warm up cache endpoint to api v1 (#23853)

This is an automated email from the ASF dual-hosted git repository.

dpgaspar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 5af298e1f6 chore: Migrate warm up cache endpoint to api v1 (#23853)
5af298e1f6 is described below

commit 5af298e1f6526af9a43baf566330ef9108f4c0d4
Author: Jack Fragassi <jf...@gmail.com>
AuthorDate: Tue Jun 20 04:08:29 2023 -0700

    chore: Migrate warm up cache endpoint to api v1 (#23853)
---
 superset/charts/api.py                             |  61 +++++++++++
 superset/charts/commands/exceptions.py             |   5 +
 superset/charts/commands/warm_up_cache.py          |  84 +++++++++++++++
 superset/charts/schemas.py                         |  38 +++++++
 superset/datasets/api.py                           |  66 ++++++++++++
 superset/datasets/commands/exceptions.py           |   5 +
 superset/datasets/commands/warm_up_cache.py        |  69 ++++++++++++
 superset/datasets/schemas.py                       |  40 +++++++
 superset/tasks/cache.py                            |  90 +++++++++-------
 tests/integration_tests/charts/api_tests.py        |  90 +++++++++++++++-
 tests/integration_tests/charts/commands_tests.py   |  30 +++++-
 tests/integration_tests/datasets/api_tests.py      | 116 +++++++++++++++++++++
 tests/integration_tests/datasets/commands_tests.py |  32 ++++++
 tests/integration_tests/strategy_tests.py          |  44 ++++----
 14 files changed, 704 insertions(+), 66 deletions(-)

diff --git a/superset/charts/api.py b/superset/charts/api.py
index 39b0c2dbf8..c87b7bdda8 100644
--- a/superset/charts/api.py
+++ b/superset/charts/api.py
@@ -47,6 +47,7 @@ from superset.charts.commands.exceptions import (
 from superset.charts.commands.export import ExportChartsCommand
 from superset.charts.commands.importers.dispatcher import ImportChartsCommand
 from superset.charts.commands.update import UpdateChartCommand
+from superset.charts.commands.warm_up_cache import ChartWarmUpCacheCommand
 from superset.charts.filters import (
     ChartAllTextFilter,
     ChartCertifiedFilter,
@@ -59,6 +60,7 @@ from superset.charts.filters import (
 )
 from superset.charts.schemas import (
     CHART_SCHEMAS,
+    ChartCacheWarmUpRequestSchema,
     ChartPostSchema,
     ChartPutSchema,
     get_delete_ids_schema,
@@ -68,6 +70,7 @@ from superset.charts.schemas import (
     screenshot_query_schema,
     thumbnail_query_schema,
 )
+from superset.commands.exceptions import CommandException
 from superset.commands.importers.exceptions import (
     IncorrectFormatError,
     NoValidFilesFoundError,
@@ -118,6 +121,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
         "thumbnail",
         "screenshot",
         "cache_screenshot",
+        "warm_up_cache",
     }
     class_permission_name = "Chart"
     method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
@@ -942,6 +946,63 @@ class ChartRestApi(BaseSupersetModelRestApi):
         ChartDAO.remove_favorite(chart)
         return self.response(200, result="OK")
 
+    @expose("/warm_up_cache", methods=("PUT",))
+    @protect()
+    @safe
+    @statsd_metrics
+    @event_logger.log_this_with_context(
+        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
+        f".warm_up_cache",
+        log_to_statsd=False,
+    )
+    def warm_up_cache(self) -> Response:
+        """
+        ---
+        put:
+          summary: >-
+            Warms up the cache for the chart
+          description: >-
+            Warms up the cache for the chart.
+            Note for slices a force refresh occurs.
+            In terms of the `extra_filters` these can be obtained from records in the JSON
+            encoded `logs.json` column associated with the `explore_json` action.
+          requestBody:
+            description: >-
+              Identifies the chart to warm up cache for, and any additional dashboard or
+              filter context to use.
+            required: true
+            content:
+              application/json:
+                schema:
+                  $ref: "#/components/schemas/ChartCacheWarmUpRequestSchema"
+          responses:
+            200:
+              description: Each chart's warmup status
+              content:
+                application/json:
+                  schema:
+                    $ref: "#/components/schemas/ChartCacheWarmUpResponseSchema"
+            400:
+              $ref: '#/components/responses/400'
+            404:
+              $ref: '#/components/responses/404'
+            500:
+              $ref: '#/components/responses/500'
+        """
+        try:
+            body = ChartCacheWarmUpRequestSchema().load(request.json)
+        except ValidationError as error:
+            return self.response_400(message=error.messages)
+        try:
+            result = ChartWarmUpCacheCommand(
+                body["chart_id"],
+                body.get("dashboard_id"),
+                body.get("extra_filters"),
+            ).run()
+            return self.response(200, result=[result])
+        except CommandException as ex:
+            return self.response(ex.status, message=ex.message)
+
     @expose("/import/", methods=("POST",))
     @protect()
     @statsd_metrics
diff --git a/superset/charts/commands/exceptions.py b/superset/charts/commands/exceptions.py
index 6d5c078b12..1079cdca81 100644
--- a/superset/charts/commands/exceptions.py
+++ b/superset/charts/commands/exceptions.py
@@ -153,3 +153,8 @@ class ChartBulkDeleteFailedReportsExistError(ChartBulkDeleteFailedError):
 
 class ChartImportError(ImportFailedError):
     message = _("Import chart failed for an unknown reason")
+
+
+class WarmUpCacheChartNotFoundError(CommandException):
+    status = 404
+    message = _("Chart not found")
diff --git a/superset/charts/commands/warm_up_cache.py b/superset/charts/commands/warm_up_cache.py
new file mode 100644
index 0000000000..6fe9f94ffa
--- /dev/null
+++ b/superset/charts/commands/warm_up_cache.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from typing import Any, Optional, Union
+
+import simplejson as json
+from flask import g
+
+from superset.charts.commands.exceptions import WarmUpCacheChartNotFoundError
+from superset.commands.base import BaseCommand
+from superset.extensions import db
+from superset.models.slice import Slice
+from superset.utils.core import error_msg_from_exception
+from superset.views.utils import get_dashboard_extra_filters, get_form_data, get_viz
+
+
+class ChartWarmUpCacheCommand(BaseCommand):
+    # pylint: disable=too-many-arguments
+    def __init__(
+        self,
+        chart_or_id: Union[int, Slice],
+        dashboard_id: Optional[int],
+        extra_filters: Optional[str],
+    ):
+        self._chart_or_id = chart_or_id
+        self._dashboard_id = dashboard_id
+        self._extra_filters = extra_filters
+
+    def run(self) -> dict[str, Any]:
+        self.validate()
+        chart: Slice = self._chart_or_id  # type: ignore
+        try:
+            form_data = get_form_data(chart.id, use_slice_data=True)[0]
+            if self._dashboard_id:
+                form_data["extra_filters"] = (
+                    json.loads(self._extra_filters)
+                    if self._extra_filters
+                    else get_dashboard_extra_filters(chart.id, self._dashboard_id)
+                )
+
+            if not chart.datasource:
+                raise Exception("Chart's datasource does not exist")
+
+            obj = get_viz(
+                datasource_type=chart.datasource.type,
+                datasource_id=chart.datasource.id,
+                form_data=form_data,
+                force=True,
+            )
+
+            # pylint: disable=assigning-non-slot
+            g.form_data = form_data
+            payload = obj.get_payload()
+            delattr(g, "form_data")
+            error = payload["errors"] or None
+            status = payload["status"]
+        except Exception as ex:  # pylint: disable=broad-except
+            error = error_msg_from_exception(ex)
+            status = None
+
+        return {"chart_id": chart.id, "viz_error": error, "viz_status": status}
+
+    def validate(self) -> None:
+        if isinstance(self._chart_or_id, Slice):
+            return
+        chart = db.session.query(Slice).filter_by(id=self._chart_or_id).scalar()
+        if not chart:
+            raise WarmUpCacheChartNotFoundError()
+        self._chart_or_id = chart
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index a5e0a6c44c..1145d5be73 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -1557,7 +1557,45 @@ class ImportV1ChartSchema(Schema):
     external_url = fields.String(allow_none=True)
 
 
+class ChartCacheWarmUpRequestSchema(Schema):
+    chart_id = fields.Integer(
+        required=True,
+        metadata={"description": "The ID of the chart to warm up cache for"},
+    )
+    dashboard_id = fields.Integer(
+        metadata={
+            "description": "The ID of the dashboard to get filters for when warming cache"
+        }
+    )
+    extra_filters = fields.String(
+        metadata={"description": "Extra filters to apply when warming up cache"}
+    )
+
+
+class ChartCacheWarmUpResponseSingleSchema(Schema):
+    chart_id = fields.Integer(
+        metadata={"description": "The ID of the chart the status belongs to"}
+    )
+    viz_error = fields.String(
+        metadata={"description": "Error that occurred when warming cache for chart"}
+    )
+    viz_status = fields.String(
+        metadata={"description": "Status of the underlying query for the viz"}
+    )
+
+
+class ChartCacheWarmUpResponseSchema(Schema):
+    result = fields.List(
+        fields.Nested(ChartCacheWarmUpResponseSingleSchema),
+        metadata={
+            "description": "A list of each chart's warmup status and errors if any"
+        },
+    )
+
+
 CHART_SCHEMAS = (
+    ChartCacheWarmUpRequestSchema,
+    ChartCacheWarmUpResponseSchema,
     ChartDataQueryContextSchema,
     ChartDataResponseSchema,
     ChartDataAsyncResponseSchema,
diff --git a/superset/datasets/api.py b/superset/datasets/api.py
index b2457b066a..6e6cf38aad 100644
--- a/superset/datasets/api.py
+++ b/superset/datasets/api.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=too-many-lines
 import json
 import logging
 from datetime import datetime
@@ -29,6 +30,7 @@ from flask_babel import ngettext
 from marshmallow import ValidationError
 
 from superset import event_logger, is_feature_enabled
+from superset.commands.exceptions import CommandException
 from superset.commands.importers.exceptions import NoValidFilesFoundError
 from superset.commands.importers.v1.utils import get_contents_from_bundle
 from superset.connectors.sqla.models import SqlaTable
@@ -53,8 +55,11 @@ from superset.datasets.commands.export import ExportDatasetsCommand
 from superset.datasets.commands.importers.dispatcher import ImportDatasetsCommand
 from superset.datasets.commands.refresh import RefreshDatasetCommand
 from superset.datasets.commands.update import UpdateDatasetCommand
+from superset.datasets.commands.warm_up_cache import DatasetWarmUpCacheCommand
 from superset.datasets.filters import DatasetCertifiedFilter, DatasetIsNullOrEmptyFilter
 from superset.datasets.schemas import (
+    DatasetCacheWarmUpRequestSchema,
+    DatasetCacheWarmUpResponseSchema,
     DatasetDuplicateSchema,
     DatasetPostSchema,
     DatasetPutSchema,
@@ -95,6 +100,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
         "related_objects",
         "duplicate",
         "get_or_create_dataset",
+        "warm_up_cache",
     }
     list_columns = [
         "id",
@@ -244,6 +250,8 @@ class DatasetRestApi(BaseSupersetModelRestApi):
         "get_export_ids_schema": get_export_ids_schema,
     }
     openapi_spec_component_schemas = (
+        DatasetCacheWarmUpRequestSchema,
+        DatasetCacheWarmUpResponseSchema,
         DatasetRelatedObjectsResponse,
         DatasetDuplicateSchema,
         GetOrCreateDatasetSchema,
@@ -992,3 +1000,61 @@ class DatasetRestApi(BaseSupersetModelRestApi):
                 exc_info=True,
             )
             return self.response_422(message=ex.message)
+
+    @expose("/warm_up_cache", methods=("PUT",))
+    @protect()
+    @safe
+    @statsd_metrics
+    @event_logger.log_this_with_context(
+        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
+        f".warm_up_cache",
+        log_to_statsd=False,
+    )
+    def warm_up_cache(self) -> Response:
+        """
+        ---
+        put:
+          summary: >-
+            Warms up the cache for each chart powered by the given table
+          description: >-
+            Warms up the cache for the table.
+            Note for slices a force refresh occurs.
+            In terms of the `extra_filters` these can be obtained from records in the JSON
+            encoded `logs.json` column associated with the `explore_json` action.
+          requestBody:
+            description: >-
+              Identifies the database and table to warm up cache for, and any
+              additional dashboard or filter context to use.
+            required: true
+            content:
+              application/json:
+                schema:
+                  $ref: "#/components/schemas/DatasetCacheWarmUpRequestSchema"
+          responses:
+            200:
+              description: Each chart's warmup status
+              content:
+                application/json:
+                  schema:
+                    $ref: "#/components/schemas/DatasetCacheWarmUpResponseSchema"
+            400:
+              $ref: '#/components/responses/400'
+            404:
+              $ref: '#/components/responses/404'
+            500:
+              $ref: '#/components/responses/500'
+        """
+        try:
+            body = DatasetCacheWarmUpRequestSchema().load(request.json)
+        except ValidationError as error:
+            return self.response_400(message=error.messages)
+        try:
+            result = DatasetWarmUpCacheCommand(
+                body["db_name"],
+                body["table_name"],
+                body.get("dashboard_id"),
+                body.get("extra_filters"),
+            ).run()
+            return self.response(200, result=result)
+        except CommandException as ex:
+            return self.response(ex.status, message=ex.message)
diff --git a/superset/datasets/commands/exceptions.py b/superset/datasets/commands/exceptions.py
index e06e92802f..7c6ef86634 100644
--- a/superset/datasets/commands/exceptions.py
+++ b/superset/datasets/commands/exceptions.py
@@ -212,3 +212,8 @@ class DatasetDuplicateFailedError(CreateFailedError):
 
 class DatasetForbiddenDataURI(ImportFailedError):
     message = _("Data URI is not allowed.")
+
+
+class WarmUpCacheTableNotFoundError(CommandException):
+    status = 404
+    message = _("The provided table was not found in the provided database")
diff --git a/superset/datasets/commands/warm_up_cache.py b/superset/datasets/commands/warm_up_cache.py
new file mode 100644
index 0000000000..62044e7224
--- /dev/null
+++ b/superset/datasets/commands/warm_up_cache.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from typing import Any, Optional
+
+from superset.charts.commands.warm_up_cache import ChartWarmUpCacheCommand
+from superset.commands.base import BaseCommand
+from superset.connectors.sqla.models import SqlaTable
+from superset.datasets.commands.exceptions import WarmUpCacheTableNotFoundError
+from superset.extensions import db
+from superset.models.core import Database
+from superset.models.slice import Slice
+
+
+class DatasetWarmUpCacheCommand(BaseCommand):
+    # pylint: disable=too-many-arguments
+    def __init__(
+        self,
+        db_name: str,
+        table_name: str,
+        dashboard_id: Optional[int],
+        extra_filters: Optional[str],
+    ):
+        self._db_name = db_name
+        self._table_name = table_name
+        self._dashboard_id = dashboard_id
+        self._extra_filters = extra_filters
+        self._charts: list[Slice] = []
+
+    def run(self) -> list[dict[str, Any]]:
+        self.validate()
+        return [
+            ChartWarmUpCacheCommand(
+                chart, self._dashboard_id, self._extra_filters
+            ).run()
+            for chart in self._charts
+        ]
+
+    def validate(self) -> None:
+        table = (
+            db.session.query(SqlaTable)
+            .join(Database)
+            .filter(
+                Database.database_name == self._db_name,
+                SqlaTable.table_name == self._table_name,
+            )
+        ).one_or_none()
+        if not table:
+            raise WarmUpCacheTableNotFoundError()
+        self._charts = (
+            db.session.query(Slice)
+            .filter_by(datasource_id=table.id, datasource_type=table.type)
+            .all()
+        )
diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py
index 9a2af98066..f95897ce59 100644
--- a/superset/datasets/schemas.py
+++ b/superset/datasets/schemas.py
@@ -254,3 +254,43 @@ class DatasetSchema(SQLAlchemyAutoSchema):
         model = Dataset
         load_instance = True
         include_relationships = True
+
+
+class DatasetCacheWarmUpRequestSchema(Schema):
+    db_name = fields.String(
+        required=True,
+        metadata={"description": "The name of the database where the table is located"},
+    )
+    table_name = fields.String(
+        required=True,
+        metadata={"description": "The name of the table to warm up cache for"},
+    )
+    dashboard_id = fields.Integer(
+        metadata={
+            "description": "The ID of the dashboard to get filters for when warming cache"
+        }
+    )
+    extra_filters = fields.String(
+        metadata={"description": "Extra filters to apply when warming up cache"}
+    )
+
+
+class DatasetCacheWarmUpResponseSingleSchema(Schema):
+    chart_id = fields.Integer(
+        metadata={"description": "The ID of the chart the status belongs to"}
+    )
+    viz_error = fields.String(
+        metadata={"description": "Error that occurred when warming cache for chart"}
+    )
+    viz_status = fields.String(
+        metadata={"description": "Status of the underlying query for the viz"}
+    )
+
+
+class DatasetCacheWarmUpResponseSchema(Schema):
+    result = fields.List(
+        fields.Nested(DatasetCacheWarmUpResponseSingleSchema),
+        metadata={
+            "description": "A list of each chart's warmup status and errors if any"
+        },
+    )
diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py
index 448271269a..68b5657a22 100644
--- a/superset/tasks/cache.py
+++ b/superset/tasks/cache.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import json
 import logging
 from typing import Any, Optional, Union
 from urllib import request
@@ -36,22 +37,20 @@ logger = get_task_logger(__name__)
 logger.setLevel(logging.INFO)
 
 
-def get_url(chart: Slice, dashboard: Optional[Dashboard] = None) -> str:
-    """Return external URL for warming up a given chart/table cache."""
-    with app.test_request_context():
-        baseurl = "{WEBDRIVER_BASEURL}".format(**app.config)
-        url = f"{baseurl}superset/warm_up_cache/?slice_id={chart.id}"
-        if dashboard:
-            url += f"&dashboard_id={dashboard.id}"
-        return url
+def get_payload(chart: Slice, dashboard: Optional[Dashboard] = None) -> dict[str, int]:
+    """Return payload for warming up a given chart/table cache."""
+    payload = {"chart_id": chart.id}
+    if dashboard:
+        payload["dashboard_id"] = dashboard.id
+    return payload
 
 
 class Strategy:  # pylint: disable=too-few-public-methods
     """
     A cache warm up strategy.
 
-    Each strategy defines a `get_urls` method that returns a list of URLs to
-    be fetched from the `/superset/warm_up_cache/` endpoint.
+    Each strategy defines a `get_payloads` method that returns a list of payloads to
+    send to the `/api/v1/chart/warm_up_cache` endpoint.
 
     Strategies can be configured in `superset/config.py`:
 
@@ -72,8 +71,8 @@ class Strategy:  # pylint: disable=too-few-public-methods
     def __init__(self) -> None:
         pass
 
-    def get_urls(self) -> list[str]:
-        raise NotImplementedError("Subclasses must implement get_urls!")
+    def get_payloads(self) -> list[dict[str, int]]:
+        raise NotImplementedError("Subclasses must implement get_payloads!")
 
 
 class DummyStrategy(Strategy):  # pylint: disable=too-few-public-methods
@@ -94,11 +93,11 @@ class DummyStrategy(Strategy):  # pylint: disable=too-few-public-methods
 
     name = "dummy"
 
-    def get_urls(self) -> list[str]:
+    def get_payloads(self) -> list[dict[str, int]]:
         session = db.create_scoped_session()
         charts = session.query(Slice).all()
 
-        return [get_url(chart) for chart in charts]
+        return [get_payload(chart) for chart in charts]
 
 
 class TopNDashboardsStrategy(Strategy):  # pylint: disable=too-few-public-methods
@@ -126,8 +125,8 @@ class TopNDashboardsStrategy(Strategy):  # pylint: disable=too-few-public-method
         self.top_n = top_n
         self.since = parse_human_datetime(since) if since else None
 
-    def get_urls(self) -> list[str]:
-        urls = []
+    def get_payloads(self) -> list[dict[str, int]]:
+        payloads = []
         session = db.create_scoped_session()
 
         records = (
@@ -142,9 +141,9 @@ class TopNDashboardsStrategy(Strategy):  # pylint: disable=too-few-public-method
         dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
         for dashboard in dashboards:
             for chart in dashboard.slices:
-                urls.append(get_url(chart, dashboard))
+                payloads.append(get_payload(chart, dashboard))
 
-        return urls
+        return payloads
 
 
 class DashboardTagsStrategy(Strategy):  # pylint: disable=too-few-public-methods
@@ -169,8 +168,8 @@ class DashboardTagsStrategy(Strategy):  # pylint: disable=too-few-public-methods
         super().__init__()
         self.tags = tags or []
 
-    def get_urls(self) -> list[str]:
-        urls = []
+    def get_payloads(self) -> list[dict[str, int]]:
+        payloads = []
         session = db.create_scoped_session()
 
         tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all()
@@ -191,7 +190,7 @@ class DashboardTagsStrategy(Strategy):  # pylint: disable=too-few-public-methods
         tagged_dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids))
         for dashboard in tagged_dashboards:
             for chart in dashboard.slices:
-                urls.append(get_url(chart))
+                payloads.append(get_payload(chart))
 
         # add charts that are tagged
         tagged_objects = (
@@ -207,35 +206,46 @@ class DashboardTagsStrategy(Strategy):  # pylint: disable=too-few-public-methods
         chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
         tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
         for chart in tagged_charts:
-            urls.append(get_url(chart))
+            payloads.append(get_payload(chart))
 
-        return urls
+        return payloads
 
 
 strategies = [DummyStrategy, TopNDashboardsStrategy, DashboardTagsStrategy]
 
 
 @celery_app.task(name="fetch_url")
-def fetch_url(url: str, headers: dict[str, str]) -> dict[str, str]:
+def fetch_url(data: str, headers: dict[str, str]) -> dict[str, str]:
     """
     Celery job to fetch url
     """
     result = {}
     try:
-        logger.info("Fetching %s", url)
-        req = request.Request(url, headers=headers)
+        baseurl = "{WEBDRIVER_BASEURL}".format(**app.config)
+        url = f"{baseurl}api/v1/chart/warm_up_cache"
+        logger.info("Fetching %s with payload %s", url, data)
+        req = request.Request(
+            url, data=bytes(data, "utf-8"), headers=headers, method="PUT"
+        )
         response = request.urlopen(  # pylint: disable=consider-using-with
             req, timeout=600
         )
-        logger.info("Fetched %s, status code: %s", url, response.code)
+        logger.info(
+            "Fetched %s with payload %s, status code: %s", url, data, response.code
+        )
         if response.code == 200:
-            result = {"success": url, "response": response.read().decode("utf-8")}
+            result = {"success": data, "response": response.read().decode("utf-8")}
         else:
-            result = {"error": url, "status_code": response.code}
-            logger.error("Error fetching %s, status code: %s", url, response.code)
+            result = {"error": data, "status_code": response.code}
+            logger.error(
+                "Error fetching %s with payload %s, status code: %s",
+                url,
+                data,
+                response.code,
+            )
     except URLError as err:
         logger.exception("Error warming up cache!")
-        result = {"error": url, "exception": str(err)}
+        result = {"error": data, "exception": str(err)}
     return result
 
 
@@ -270,16 +280,20 @@ def cache_warmup(
 
     user = security_manager.get_user_by_username(app.config["THUMBNAIL_SELENIUM_USER"])
     cookies = MachineAuthProvider.get_auth_cookies(user)
-    headers = {"Cookie": f"session={cookies.get('session', '')}"}
+    headers = {
+        "Cookie": f"session={cookies.get('session', '')}",
+        "Content-Type": "application/json",
+    }
 
     results: dict[str, list[str]] = {"scheduled": [], "errors": []}
-    for url in strategy.get_urls():
+    for payload in strategy.get_payloads():
         try:
-            logger.info("Scheduling %s", url)
-            fetch_url.delay(url, headers)
-            results["scheduled"].append(url)
+            payload = json.dumps(payload)
+            logger.info("Scheduling %s", payload)
+            fetch_url.delay(payload, headers)
+            results["scheduled"].append(payload)
         except SchedulingError:
-            logger.exception("Error scheduling fetch_url: %s", url)
-            results["errors"].append(url)
+            logger.exception("Error scheduling fetch_url for payload: %s", payload)
+            results["errors"].append(payload)
 
     return results
diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py
index 60633c8894..69e99978e5 100644
--- a/tests/integration_tests/charts/api_tests.py
+++ b/tests/integration_tests/charts/api_tests.py
@@ -33,6 +33,7 @@ from superset.models.dashboard import Dashboard
 from superset.reports.models import ReportSchedule, ReportScheduleType
 from superset.models.slice import Slice
 from superset.utils.core import get_example_default_schema
+from superset.utils.database import get_example_database
 
 from tests.integration_tests.conftest import with_feature_flags
 from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin
@@ -199,7 +200,12 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
         rv = self.get_assert_metric(uri, "info")
         data = json.loads(rv.data.decode("utf-8"))
         assert rv.status_code == 200
-        assert set(data["permissions"]) == {"can_read", "can_write", "can_export"}
+        assert set(data["permissions"]) == {
+            "can_read",
+            "can_write",
+            "can_export",
+            "can_warm_up_cache",
+        }
 
     def create_chart_import(self):
         buf = BytesIO()
@@ -1682,3 +1688,85 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
 
         assert data["result"][0]["slice_name"] == "name0"
         assert data["result"][0]["datasource_id"] == 1
+
+    @pytest.mark.usefixtures(
+        "load_energy_table_with_slice", "load_birth_names_dashboard_with_slices"
+    )
+    def test_warm_up_cache(self):
+        self.login()
+        slc = self.get_slice("Girls", db.session)
+        rv = self.client.put("/api/v1/chart/warm_up_cache", json={"chart_id": slc.id})
+        self.assertEqual(rv.status_code, 200)
+        data = json.loads(rv.data.decode("utf-8"))
+
+        self.assertEqual(
+            data["result"],
+            [{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}],
+        )
+
+        dashboard = self.get_dash_by_slug("births")
+
+        rv = self.client.put(
+            "/api/v1/chart/warm_up_cache",
+            json={"chart_id": slc.id, "dashboard_id": dashboard.id},
+        )
+        self.assertEqual(rv.status_code, 200)
+        data = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(
+            data["result"],
+            [{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}],
+        )
+
+        rv = self.client.put(
+            "/api/v1/chart/warm_up_cache",
+            json={
+                "chart_id": slc.id,
+                "dashboard_id": dashboard.id,
+                "extra_filters": json.dumps(
+                    [{"col": "name", "op": "in", "val": ["Jennifer"]}]
+                ),
+            },
+        )
+        self.assertEqual(rv.status_code, 200)
+        data = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(
+            data["result"],
+            [{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}],
+        )
+
+    def test_warm_up_cache_chart_id_required(self):
+        self.login()
+        rv = self.client.put("/api/v1/chart/warm_up_cache", json={"dashboard_id": 1})
+        self.assertEqual(rv.status_code, 400)
+        data = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(
+            data,
+            {"message": {"chart_id": ["Missing data for required field."]}},
+        )
+
+    def test_warm_up_cache_chart_not_found(self):
+        self.login()
+        rv = self.client.put("/api/v1/chart/warm_up_cache", json={"chart_id": 99999})
+        self.assertEqual(rv.status_code, 404)
+        data = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(data, {"message": "Chart not found"})
+
+    def test_warm_up_cache_payload_validation(self):
+        self.login()
+        rv = self.client.put(
+            "/api/v1/chart/warm_up_cache",
+            json={"chart_id": "id", "dashboard_id": "id", "extra_filters": 4},
+        )
+        self.assertEqual(rv.status_code, 400)
+        data = json.loads(rv.data.decode("utf-8"))
+        print(data)
+        self.assertEqual(
+            data,
+            {
+                "message": {
+                    "chart_id": ["Not a valid integer."],
+                    "dashboard_id": ["Not a valid integer."],
+                    "extra_filters": ["Not a valid string."],
+                }
+            },
+        )
diff --git a/tests/integration_tests/charts/commands_tests.py b/tests/integration_tests/charts/commands_tests.py
index 4d365d56b5..217b1655a5 100644
--- a/tests/integration_tests/charts/commands_tests.py
+++ b/tests/integration_tests/charts/commands_tests.py
@@ -23,16 +23,24 @@ from flask import g
 
 from superset import db, security_manager
 from superset.charts.commands.create import CreateChartCommand
-from superset.charts.commands.exceptions import ChartNotFoundError
+from superset.charts.commands.exceptions import (
+    ChartNotFoundError,
+    WarmUpCacheChartNotFoundError,
+)
 from superset.charts.commands.export import ExportChartsCommand
 from superset.charts.commands.importers.v1 import ImportChartsCommand
 from superset.charts.commands.update import UpdateChartCommand
+from superset.charts.commands.warm_up_cache import ChartWarmUpCacheCommand
 from superset.commands.exceptions import CommandInvalidError
 from superset.commands.importers.exceptions import IncorrectVersionError
 from superset.connectors.sqla.models import SqlaTable
 from superset.models.core import Database
 from superset.models.slice import Slice
 from tests.integration_tests.base_tests import SupersetTestCase
+from tests.integration_tests.fixtures.birth_names_dashboard import (
+    load_birth_names_dashboard_with_slices,
+    load_birth_names_data,
+)
 from tests.integration_tests.fixtures.energy_dashboard import (
     load_energy_table_data,
     load_energy_table_with_slice,
@@ -442,3 +450,23 @@ class TestChartsUpdateCommand(SupersetTestCase):
         assert chart.query_context == query_context
         assert len(chart.owners) == 1
         assert chart.owners[0] == admin
+
+
+class TestChartWarmUpCacheCommand(SupersetTestCase):
+    def test_warm_up_cache_command_chart_not_found(self):
+        with self.assertRaises(WarmUpCacheChartNotFoundError):
+            ChartWarmUpCacheCommand(99999, None, None).run()
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_warm_up_cache(self):
+        slc = self.get_slice("Girls", db.session)
+        result = ChartWarmUpCacheCommand(slc.id, None, None).run()
+        self.assertEqual(
+            result, {"chart_id": slc.id, "viz_error": None, "viz_status": "success"}
+        )
+
+        # can just pass in chart as well
+        result = ChartWarmUpCacheCommand(slc, None, None).run()
+        self.assertEqual(
+            result, {"chart_id": slc.id, "viz_error": None, "viz_status": "success"}
+        )
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index 55fda1af65..2f55a1e978 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -39,6 +39,7 @@ from superset.datasets.commands.exceptions import DatasetCreateFailedError
 from superset.datasets.models import Dataset
 from superset.extensions import db, security_manager
 from superset.models.core import Database
+from superset.models.slice import Slice
 from superset.utils.core import backend, get_example_default_schema
 from superset.utils.database import get_example_database, get_main_database
 from superset.utils.dict_import_export import export_to_dict
@@ -514,6 +515,7 @@ class TestDatasetApi(SupersetTestCase):
             "can_export",
             "can_duplicate",
             "can_get_or_create_dataset",
+            "can_warm_up_cache",
         }
 
     def test_create_dataset_item(self):
@@ -2501,3 +2503,117 @@ class TestDatasetApi(SupersetTestCase):
         with examples_db.get_sqla_engine_with_context() as engine:
             engine.execute("DROP TABLE test_create_sqla_table_api")
         db.session.commit()
+
+    @pytest.mark.usefixtures(
+        "load_energy_table_with_slice", "load_birth_names_dashboard_with_slices"
+    )
+    def test_warm_up_cache(self):
+        """
+        Dataset API: Test warm up cache endpoint
+        """
+        self.login()
+        energy_table = self.get_energy_usage_dataset()
+        energy_charts = (
+            db.session.query(Slice)
+            .filter(
+                Slice.datasource_id == energy_table.id, Slice.datasource_type == "table"
+            )
+            .all()
+        )
+        rv = self.client.put(
+            "/api/v1/dataset/warm_up_cache",
+            json={
+                "table_name": "energy_usage",
+                "db_name": get_example_database().database_name,
+            },
+        )
+        self.assertEqual(rv.status_code, 200)
+        data = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(
+            len(data["result"]),
+            len(energy_charts),
+        )
+        for chart_result in data["result"]:
+            assert "chart_id" in chart_result
+            assert "viz_error" in chart_result
+            assert "viz_status" in chart_result
+
+        # With dashboard id
+        dashboard = self.get_dash_by_slug("births")
+        birth_table = self.get_birth_names_dataset()
+        birth_charts = (
+            db.session.query(Slice)
+            .filter(
+                Slice.datasource_id == birth_table.id, Slice.datasource_type == "table"
+            )
+            .all()
+        )
+        rv = self.client.put(
+            "/api/v1/dataset/warm_up_cache",
+            json={
+                "table_name": "birth_names",
+                "db_name": get_example_database().database_name,
+                "dashboard_id": dashboard.id,
+            },
+        )
+        self.assertEqual(rv.status_code, 200)
+        data = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(
+            len(data["result"]),
+            len(birth_charts),
+        )
+        for chart_result in data["result"]:
+            assert "chart_id" in chart_result
+            assert "viz_error" in chart_result
+            assert "viz_status" in chart_result
+
+        # With extra filters
+        rv = self.client.put(
+            "/api/v1/dataset/warm_up_cache",
+            json={
+                "table_name": "birth_names",
+                "db_name": get_example_database().database_name,
+                "dashboard_id": dashboard.id,
+                "extra_filters": json.dumps(
+                    [{"col": "name", "op": "in", "val": ["Jennifer"]}]
+                ),
+            },
+        )
+        self.assertEqual(rv.status_code, 200)
+        data = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(
+            len(data["result"]),
+            len(birth_charts),
+        )
+        for chart_result in data["result"]:
+            assert "chart_id" in chart_result
+            assert "viz_error" in chart_result
+            assert "viz_status" in chart_result
+
+    def test_warm_up_cache_db_and_table_name_required(self):
+        self.login()
+        rv = self.client.put("/api/v1/dataset/warm_up_cache", json={"dashboard_id": 1})
+        self.assertEqual(rv.status_code, 400)
+        data = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(
+            data,
+            {
+                "message": {
+                    "db_name": ["Missing data for required field."],
+                    "table_name": ["Missing data for required field."],
+                }
+            },
+        )
+
+    def test_warm_up_cache_table_not_found(self):
+        self.login()
+        rv = self.client.put(
+            "/api/v1/dataset/warm_up_cache",
+            json={"table_name": "not_here", "db_name": "abc"},
+        )
+        self.assertEqual(rv.status_code, 404)
+        data = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(
+            data,
+            {"message": "The provided table was not found in the provided database"},
+        )
diff --git a/tests/integration_tests/datasets/commands_tests.py b/tests/integration_tests/datasets/commands_tests.py
index 953c34059f..34a0625b36 100644
--- a/tests/integration_tests/datasets/commands_tests.py
+++ b/tests/integration_tests/datasets/commands_tests.py
@@ -31,13 +31,20 @@ from superset.datasets.commands.create import CreateDatasetCommand
 from superset.datasets.commands.exceptions import (
     DatasetInvalidError,
     DatasetNotFoundError,
+    WarmUpCacheTableNotFoundError,
 )
 from superset.datasets.commands.export import ExportDatasetsCommand
 from superset.datasets.commands.importers import v0, v1
+from superset.datasets.commands.warm_up_cache import DatasetWarmUpCacheCommand
 from superset.models.core import Database
+from superset.models.slice import Slice
 from superset.utils.core import get_example_default_schema
 from superset.utils.database import get_example_database
 from tests.integration_tests.base_tests import SupersetTestCase
+from tests.integration_tests.fixtures.birth_names_dashboard import (
+    load_birth_names_dashboard_with_slices,
+    load_birth_names_data,
+)
 from tests.integration_tests.fixtures.energy_dashboard import (
     load_energy_table_data,
     load_energy_table_with_slice,
@@ -575,3 +582,28 @@ class TestCreateDatasetCommand(SupersetTestCase):
         with examples_db.get_sqla_engine_with_context() as engine:
             engine.execute("DROP TABLE test_create_dataset_command")
         db.session.commit()
+
+
+class TestDatasetWarmUpCacheCommand(SupersetTestCase):
+    def test_warm_up_cache_command_table_not_found(self):
+        with self.assertRaises(WarmUpCacheTableNotFoundError):
+            DatasetWarmUpCacheCommand("not", "here", None, None).run()
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_warm_up_cache(self):
+        birth_table = self.get_birth_names_dataset()
+        birth_charts = (
+            db.session.query(Slice)
+            .filter(
+                Slice.datasource_id == birth_table.id, Slice.datasource_type == "table"
+            )
+            .all()
+        )
+        results = DatasetWarmUpCacheCommand(
+            get_example_database().database_name, "birth_names", None, None
+        ).run()
+        self.assertEqual(len(results), len(birth_charts))
+        for chart_result in results:
+            assert "chart_id" in chart_result
+            assert "viz_error" in chart_result
+            assert "viz_status" in chart_result
diff --git a/tests/integration_tests/strategy_tests.py b/tests/integration_tests/strategy_tests.py
index f6d664c649..6fec16ca74 100644
--- a/tests/integration_tests/strategy_tests.py
+++ b/tests/integration_tests/strategy_tests.py
@@ -76,14 +76,11 @@ class TestCacheWarmUp(SupersetTestCase):
             self.client.get(f"/superset/dashboard/{dash.id}/")
 
         strategy = TopNDashboardsStrategy(1)
-        result = sorted(strategy.get_urls())
-        expected = sorted(
-            [
-                f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}&dashboard_id={dash.id}"
-                for slc in dash.slices
-            ]
-        )
-        self.assertEqual(result, expected)
+        result = strategy.get_payloads()
+        expected = [
+            {"chart_id": chart.id, "dashboard_id": dash.id} for chart in dash.slices
+        ]
+        self.assertCountEqual(result, expected)
 
     def reset_tag(self, tag):
         """Remove associated object from tag, used to reset tests"""
@@ -95,57 +92,52 @@ class TestCacheWarmUp(SupersetTestCase):
     @pytest.mark.usefixtures(
         "load_unicode_dashboard_with_slice", "load_birth_names_dashboard_with_slices"
     )
-    def test_dashboard_tags(self):
+    def test_dashboard_tags_strategy(self):
         tag1 = get_tag("tag1", db.session, TagTypes.custom)
         # delete first to make test idempotent
         self.reset_tag(tag1)
 
         strategy = DashboardTagsStrategy(["tag1"])
-        result = sorted(strategy.get_urls())
+        result = strategy.get_payloads()
         expected = []
         self.assertEqual(result, expected)
 
         # tag dashboard 'births' with `tag1`
         tag1 = get_tag("tag1", db.session, TagTypes.custom)
         dash = self.get_dash_by_slug("births")
-        tag1_urls = sorted(
-            [
-                f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}"
-                for slc in dash.slices
-            ]
-        )
+        tag1_urls = [{"chart_id": chart.id} for chart in dash.slices]
         tagged_object = TaggedObject(
             tag_id=tag1.id, object_id=dash.id, object_type=ObjectTypes.dashboard
         )
         db.session.add(tagged_object)
         db.session.commit()
 
-        self.assertEqual(sorted(strategy.get_urls()), tag1_urls)
+        self.assertCountEqual(strategy.get_payloads(), tag1_urls)
 
         strategy = DashboardTagsStrategy(["tag2"])
         tag2 = get_tag("tag2", db.session, TagTypes.custom)
         self.reset_tag(tag2)
 
-        result = sorted(strategy.get_urls())
+        result = strategy.get_payloads()
         expected = []
         self.assertEqual(result, expected)
 
         # tag first slice
         dash = self.get_dash_by_slug("unicode-test")
-        slc = dash.slices[0]
-        tag2_urls = [f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}"]
-        object_id = slc.id
+        chart = dash.slices[0]
+        tag2_urls = [{"chart_id": chart.id}]
+        object_id = chart.id
         tagged_object = TaggedObject(
             tag_id=tag2.id, object_id=object_id, object_type=ObjectTypes.chart
         )
         db.session.add(tagged_object)
         db.session.commit()
 
-        result = sorted(strategy.get_urls())
-        self.assertEqual(result, tag2_urls)
+        result = strategy.get_payloads()
+        self.assertCountEqual(result, tag2_urls)
 
         strategy = DashboardTagsStrategy(["tag1", "tag2"])
 
-        result = sorted(strategy.get_urls())
-        expected = sorted(tag1_urls + tag2_urls)
-        self.assertEqual(result, expected)
+        result = strategy.get_payloads()
+        expected = tag1_urls + tag2_urls
+        self.assertCountEqual(result, expected)