You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by dp...@apache.org on 2020/11/25 08:51:04 UTC
[incubator-superset] branch master updated: feat: new reports
scheduler (#11711)
This is an automated email from the ASF dual-hosted git repository.
dpgaspar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new f27ebc4 feat: new reports scheduler (#11711)
f27ebc4 is described below
commit f27ebc4be5932ee9733a7635390eabdf4452f67a
Author: Daniel Vaz Gaspar <da...@gmail.com>
AuthorDate: Wed Nov 25 08:50:30 2020 +0000
feat: new reports scheduler (#11711)
* feat(reports): scheduler and delivery system
* working version
* improvements and fix grace_period
* add tests and fix bugs
* fix report API test
* test MySQL test fail
* delete-orphans
* fix MySQL tests
* address comments
* lint
---
superset/dao/base.py | 8 +-
superset/exceptions.py | 4 +-
superset/models/reports.py | 4 +-
superset/reports/commands/alert.py | 101 ++++
superset/reports/commands/exceptions.py | 37 ++
superset/reports/commands/execute.py | 256 ++++++++++
superset/reports/commands/log_prune.py | 48 ++
superset/reports/dao.py | 55 ++-
.../notifications/__init__.py} | 33 +-
superset/reports/notifications/base.py | 62 +++
superset/reports/notifications/email.py | 98 ++++
.../notifications/exceptions.py} | 14 +-
superset/reports/notifications/slack.py | 89 ++++
superset/tasks/celery_app.py | 2 +-
superset/tasks/scheduler.py | 69 +++
superset/utils/urls.py | 12 +-
tests/reports/api_tests.py | 60 +--
tests/reports/commands_tests.py | 531 +++++++++++++++++++++
tests/reports/utils.py | 68 +++
19 files changed, 1463 insertions(+), 88 deletions(-)
diff --git a/superset/dao/base.py b/superset/dao/base.py
index 6b33c4e..abfa4ac 100644
--- a/superset/dao/base.py
+++ b/superset/dao/base.py
@@ -20,6 +20,7 @@ from flask_appbuilder.models.filters import BaseFilter
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.models.sqla.interface import SQLAInterface
from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm import Session
from superset.dao.exceptions import (
DAOConfigError,
@@ -46,13 +47,14 @@ class BaseDAO:
"""
@classmethod
- def find_by_id(cls, model_id: int) -> Model:
+ def find_by_id(cls, model_id: int, session: Session = None) -> Model:
"""
Find a model by id, if defined applies `base_filter`
"""
- query = db.session.query(cls.model_cls)
+ session = session or db.session
+ query = session.query(cls.model_cls)
if cls.base_filter:
- data_model = SQLAInterface(cls.model_cls, db.session)
+ data_model = SQLAInterface(cls.model_cls, session)
query = cls.base_filter( # pylint: disable=not-callable
"id", data_model
).apply(query, None)
diff --git a/superset/exceptions.py b/superset/exceptions.py
index c0d55f8..fd95a59 100644
--- a/superset/exceptions.py
+++ b/superset/exceptions.py
@@ -25,7 +25,9 @@ class SupersetException(Exception):
status = 500
message = ""
- def __init__(self, message: str = "", exception: Optional[Exception] = None):
+ def __init__(
+ self, message: str = "", exception: Optional[Exception] = None,
+ ) -> None:
if message:
self.message = message
self._exception = exception
diff --git a/superset/models/reports.py b/superset/models/reports.py
index 731d1f9..7b1f183 100644
--- a/superset/models/reports.py
+++ b/superset/models/reports.py
@@ -60,7 +60,9 @@ class ReportRecipientType(str, enum.Enum):
class ReportLogState(str, enum.Enum):
SUCCESS = "Success"
+ WORKING = "Working"
ERROR = "Error"
+ NOOP = "Not triggered"
class ReportEmailFormat(str, enum.Enum):
@@ -175,6 +177,6 @@ class ReportExecutionLog(Model): # pylint: disable=too-few-public-methods
)
report_schedule = relationship(
ReportSchedule,
- backref=backref("logs", cascade="all,delete"),
+ backref=backref("logs", cascade="all,delete,delete-orphan"),
foreign_keys=[report_schedule_id],
)
diff --git a/superset/reports/commands/alert.py b/superset/reports/commands/alert.py
new file mode 100644
index 0000000..cab294c
--- /dev/null
+++ b/superset/reports/commands/alert.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import json
+import logging
+from operator import eq, ge, gt, le, lt, ne
+from typing import Optional
+
+import numpy as np
+from flask_babel import lazy_gettext as _
+
+from superset import jinja_context
+from superset.commands.base import BaseCommand
+from superset.models.reports import ReportSchedule, ReportScheduleValidatorType
+from superset.reports.commands.exceptions import (
+ AlertQueryInvalidTypeError,
+ AlertQueryMultipleColumnsError,
+ AlertQueryMultipleRowsError,
+)
+
+logger = logging.getLogger(__name__)
+
+
+OPERATOR_FUNCTIONS = {">=": ge, ">": gt, "<=": le, "<": lt, "==": eq, "!=": ne}
+
+
+class AlertCommand(BaseCommand):
+    """
+    Run an alert's SQL query and evaluate its validator against the result.
+    `run` returns True when the alert is triggered.
+    """
+
+    def __init__(self, report_schedule: ReportSchedule):
+        self._report_schedule = report_schedule
+        self._result: Optional[float] = None
+
+    def run(self) -> bool:
+        """
+        Validate the query result, store it on the report schedule and evaluate
+        the validator configured for it.
+
+        :return: True if the alert is triggered
+        :raises: AlertQueryMultipleRowsError, AlertQueryMultipleColumnsError,
+            AlertQueryInvalidTypeError
+        """
+        self.validate()
+
+        if self._report_schedule.validator_type == ReportScheduleValidatorType.NOT_NULL:
+            self._report_schedule.last_value_row_json = self._result
+            # NaN != NaN, so the self-comparison also treats a NaN result as null
+            return self._result not in (0, None) and self._result == self._result
+        self._report_schedule.last_value = self._result
+        if self._result is None:
+            # Empty result or NULL value: there is nothing to compare, and the
+            # operator functions below would raise TypeError on None
+            return False
+        config = json.loads(self._report_schedule.validator_config_json)
+        return OPERATOR_FUNCTIONS[config["op"]](self._result, config["threshold"])
+
+    def _validate_not_null(self, rows: np.recarray) -> None:
+        # Field 0 of `to_records()` is the DataFrame index; 1 is the first column
+        self._result = rows[0][1]
+
+    def _validate_operator(self, rows: np.recarray) -> None:
+        # check if query returned more than one row
+        if len(rows) > 1:
+            raise AlertQueryMultipleRowsError(
+                message=_(
+                    "Alert query returned more than one row. %(num)d rows returned",
+                    num=len(rows),
+                )
+            )
+        # check if query returned more than one column (field 0 is the index)
+        if len(rows[0]) > 2:
+            raise AlertQueryMultipleColumnsError(
+                message=_(
+                    "Alert query returned more than one column. %(num)d columns returned",
+                    num=len(rows[0]) - 1,
+                )
+            )
+        if rows[0][1] is None:
+            return
+        try:
+            # Check if it's float or if we can convert it
+            result = float(rows[0][1])
+        except (OverflowError, TypeError, ValueError) as ex:
+            raise AlertQueryInvalidTypeError() from ex
+        # pandas surfaces SQL NULL in numeric columns as NaN; treat it like None
+        self._result = None if np.isnan(result) else result
+
+    def validate(self) -> None:
+        """
+        Render the alert SQL with the database's Jinja template processor, run it
+        and store the single result value in `self._result` (left as None when
+        the query returns no rows).
+
+        :raises: AlertQueryMultipleRowsError, AlertQueryMultipleColumnsError,
+            AlertQueryInvalidTypeError
+        """
+        sql_template = jinja_context.get_template_processor(
+            database=self._report_schedule.database
+        )
+        rendered_sql = sql_template.process_template(self._report_schedule.sql)
+        df = self._report_schedule.database.get_df(rendered_sql)
+
+        if df.empty:
+            return
+        rows = df.to_records()
+        if self._report_schedule.validator_type == ReportScheduleValidatorType.NOT_NULL:
+            self._validate_not_null(rows)
+            return
+        self._validate_operator(rows)
diff --git a/superset/reports/commands/exceptions.py b/superset/reports/commands/exceptions.py
index 23a2142..3a56a49 100644
--- a/superset/reports/commands/exceptions.py
+++ b/superset/reports/commands/exceptions.py
@@ -103,6 +103,22 @@ class ReportScheduleDeleteFailedError(CommandException):
message = _("Report Schedule delete failed.")
+class PruneReportScheduleLogFailedError(CommandException):
+ message = _("Report Schedule log prune failed.")
+
+
+class ReportScheduleScreenshotFailedError(CommandException):
+ message = _("Report Schedule execution failed when generating a screenshot.")
+
+
+class ReportScheduleExecuteUnexpectedError(CommandException):
+ message = _("Report Schedule execution got an unexpected error.")
+
+
+class ReportSchedulePreviousWorkingError(CommandException):
+ message = _("Report Schedule is still working, refusing to re-compute.")
+
+
class ReportScheduleNameUniquenessValidationError(ValidationError):
"""
Marshmallow validation error for Report Schedule name already exists
@@ -110,3 +126,24 @@ class ReportScheduleNameUniquenessValidationError(ValidationError):
def __init__(self) -> None:
super().__init__([_("Name must be unique")], field_name="name")
+
+
+class AlertQueryMultipleRowsError(CommandException):
+    """Raised when an alert's SQL query returns more than one row."""
+
+    message = _("Alert query returned more than one row.")
+
+
+class AlertQueryMultipleColumnsError(CommandException):
+    """Raised when an alert's SQL query returns more than one column."""
+
+    message = _("Alert query returned more than one column.")
+
+
+class AlertQueryInvalidTypeError(CommandException):
+    """Raised when an alert's query value cannot be converted to a number."""
+
+    message = _("Alert query returned a non-number value.")
+
+
+class ReportScheduleAlertGracePeriodError(CommandException):
+    """Raised when an alert is evaluated while still inside its grace period."""
+
+    message = _("Alert fired during grace period.")
+
+
+class ReportScheduleNotificationError(CommandException):
+    """Raised when one or more notifications for a report schedule fail."""
+
+    message = _("Report Schedule notification failed.")
diff --git a/superset/reports/commands/execute.py b/superset/reports/commands/execute.py
new file mode 100644
index 0000000..bb33847
--- /dev/null
+++ b/superset/reports/commands/execute.py
@@ -0,0 +1,256 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import logging
+from datetime import datetime, timedelta
+from typing import Optional
+
+from sqlalchemy.orm import Session
+
+from superset import app, thumbnail_cache
+from superset.commands.base import BaseCommand
+from superset.commands.exceptions import CommandException
+from superset.extensions import security_manager
+from superset.models.reports import (
+ ReportExecutionLog,
+ ReportLogState,
+ ReportSchedule,
+ ReportScheduleType,
+)
+from superset.reports.commands.alert import AlertCommand
+from superset.reports.commands.exceptions import (
+ ReportScheduleAlertGracePeriodError,
+ ReportScheduleExecuteUnexpectedError,
+ ReportScheduleNotFoundError,
+ ReportScheduleNotificationError,
+ ReportSchedulePreviousWorkingError,
+ ReportScheduleScreenshotFailedError,
+)
+from superset.reports.dao import ReportScheduleDAO
+from superset.reports.notifications import create_notification
+from superset.reports.notifications.base import NotificationContent, ScreenshotData
+from superset.reports.notifications.exceptions import NotificationError
+from superset.utils.celery import session_scope
+from superset.utils.screenshots import (
+ BaseScreenshot,
+ ChartScreenshot,
+ DashboardScreenshot,
+)
+from superset.utils.urls import get_url_path
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncExecuteReportScheduleCommand(BaseCommand):
+ """
+ Execute all types of report schedules.
+ - On reports takes chart or dashboard screenshots and sends configured notifications
+ - On Alerts uses related Command AlertCommand and sends configured notifications
+ """
+
+ def __init__(self, model_id: int, scheduled_dttm: datetime):
+ self._model_id = model_id
+ self._model: Optional[ReportSchedule] = None
+ self._scheduled_dttm = scheduled_dttm
+
+ def set_state_and_log(
+ self,
+ session: Session,
+ start_dttm: datetime,
+ state: ReportLogState,
+ error_message: Optional[str] = None,
+ ) -> None:
+ """
+ Updates current ReportSchedule state and TS. If on final state writes the log
+ for this execution
+ """
+ now_dttm = datetime.utcnow()
+ if state == ReportLogState.WORKING:
+ self.set_state(session, state, now_dttm)
+ return
+ self.set_state(session, state, now_dttm)
+ self.create_log(
+ session, start_dttm, now_dttm, state, error_message=error_message,
+ )
+
+ def set_state(
+ self, session: Session, state: ReportLogState, dttm: datetime
+ ) -> None:
+ """
+ Set the current report schedule state, on this case we want to
+ commit immediately
+ """
+ if self._model:
+ self._model.last_state = state
+ self._model.last_eval_dttm = dttm
+ session.commit()
+
+ def create_log( # pylint: disable=too-many-arguments
+ self,
+ session: Session,
+ start_dttm: datetime,
+ end_dttm: datetime,
+ state: ReportLogState,
+ error_message: Optional[str] = None,
+ ) -> None:
+ """
+ Creates a Report execution log, uses the current computed last_value for Alerts
+ """
+ if self._model:
+ log = ReportExecutionLog(
+ scheduled_dttm=self._scheduled_dttm,
+ start_dttm=start_dttm,
+ end_dttm=end_dttm,
+ value=self._model.last_value,
+ value_row_json=self._model.last_value_row_json,
+ state=state,
+ error_message=error_message,
+ report_schedule=self._model,
+ )
+ session.add(log)
+
+ @staticmethod
+ def _get_url(report_schedule: ReportSchedule, user_friendly: bool = False) -> str:
+ """
+ Get the url for this report schedule: chart or dashboard
+ """
+ if report_schedule.chart:
+ return get_url_path(
+ "Superset.slice",
+ user_friendly=user_friendly,
+ slice_id=report_schedule.chart_id,
+ standalone="true",
+ )
+ return get_url_path(
+ "Superset.dashboard",
+ user_friendly=user_friendly,
+ dashboard_id_or_slug=report_schedule.dashboard_id,
+ )
+
+ def _get_screenshot(self, report_schedule: ReportSchedule) -> ScreenshotData:
+ """
+ Get a chart or dashboard screenshot
+ :raises: ReportScheduleScreenshotFailedError
+ """
+ url = self._get_url(report_schedule)
+ screenshot: Optional[BaseScreenshot] = None
+ if report_schedule.chart:
+ screenshot = ChartScreenshot(url, report_schedule.chart.digest)
+ else:
+ screenshot = DashboardScreenshot(url, report_schedule.dashboard.digest)
+ image_url = self._get_url(report_schedule, user_friendly=True)
+ user = security_manager.find_user(app.config["THUMBNAIL_SELENIUM_USER"])
+ image_data = screenshot.compute_and_cache(
+ user=user, cache=thumbnail_cache, force=True,
+ )
+ if not image_data:
+ raise ReportScheduleScreenshotFailedError()
+ return ScreenshotData(url=image_url, image=image_data)
+
+ def _get_notification_content(
+ self, report_schedule: ReportSchedule
+ ) -> NotificationContent:
+ """
+ Gets a notification content, this is composed by a title and a screenshot
+ :raises: ReportScheduleScreenshotFailedError
+ """
+ screenshot_data = self._get_screenshot(report_schedule)
+ if report_schedule.chart:
+ name = report_schedule.chart.slice_name
+ else:
+ name = report_schedule.dashboard.dashboard_title
+ return NotificationContent(name=name, screenshot=screenshot_data)
+
+ def _send(self, report_schedule: ReportSchedule) -> None:
+ """
+ Creates the notification content and sends them to all recipients
+
+ :raises: ReportScheduleNotificationError
+ """
+ notification_errors = []
+ notification_content = self._get_notification_content(report_schedule)
+ for recipient in report_schedule.recipients:
+ notification = create_notification(recipient, notification_content)
+ try:
+ notification.send()
+ except NotificationError as ex:
+ # collect notification errors but keep processing them
+ notification_errors.append(str(ex))
+ if notification_errors:
+ raise ReportScheduleNotificationError(";".join(notification_errors))
+
+    def run(self) -> None:
+        """
+        Execute the report schedule: validate it, mark it WORKING, evaluate the
+        alert condition (alerts only), send notifications and record the final
+        state (SUCCESS, NOOP or ERROR) together with an execution log.
+
+        :raises: ReportSchedulePreviousWorkingError, other CommandException
+            subclasses, or any unexpected error (after recording an ERROR state)
+        """
+        with session_scope(nullpool=True) as session:
+            start_dttm = datetime.utcnow()
+            try:
+                self.validate(session=session)
+                if not self._model:
+                    raise ReportScheduleExecuteUnexpectedError()
+                self.set_state_and_log(session, start_dttm, ReportLogState.WORKING)
+                # If it's an alert check if the alert is triggered
+                if self._model.type == ReportScheduleType.ALERT:
+                    if not AlertCommand(self._model).run():
+                        self.set_state_and_log(session, start_dttm, ReportLogState.NOOP)
+                        return
+
+                self._send(self._model)
+
+                # Log, state and TS
+                self.set_state_and_log(session, start_dttm, ReportLogState.SUCCESS)
+            except ReportScheduleAlertGracePeriodError as ex:
+                self.set_state_and_log(
+                    session, start_dttm, ReportLogState.NOOP, error_message=str(ex)
+                )
+            except ReportSchedulePreviousWorkingError as ex:
+                # Another execution owns the WORKING state: only log, keep the state
+                self.create_log(
+                    session,
+                    start_dttm,
+                    datetime.utcnow(),
+                    state=ReportLogState.ERROR,
+                    error_message=str(ex),
+                )
+                session.commit()
+                raise
+            except CommandException as ex:
+                self.set_state_and_log(
+                    session, start_dttm, ReportLogState.ERROR, error_message=str(ex)
+                )
+                # We want to actually commit the state and log inside the scope
+                session.commit()
+                raise
+            except Exception as ex: # pylint: disable=broad-except
+                # Any other failure (e.g. the alert query erroring on the database)
+                # must not leave the schedule on WORKING, or every later run would
+                # be refused with ReportSchedulePreviousWorkingError. Discard the
+                # pending changes, record the error, then re-raise it unchanged.
+                session.rollback()
+                self.set_state_and_log(
+                    session, start_dttm, ReportLogState.ERROR, error_message=str(ex)
+                )
+                session.commit()
+                raise
+
+    def validate( # pylint: disable=arguments-differ
+        self, session: Session = None
+    ) -> None:
+        """
+        Load the report schedule and check that it may run now.
+
+        :raises: ReportScheduleNotFoundError, ReportSchedulePreviousWorkingError,
+            ReportScheduleAlertGracePeriodError
+        """
+        # Validate/populate model exists
+        self._model = ReportScheduleDAO.find_by_id(self._model_id, session=session)
+        if not self._model:
+            raise ReportScheduleNotFoundError()
+        # Avoid overlap processing
+        if self._model.last_state == ReportLogState.WORKING:
+            raise ReportSchedulePreviousWorkingError()
+        # Check grace period. Only this alert's own successful executions count;
+        # the latest success of some other schedule must not silence this one.
+        if self._model.type == ReportScheduleType.ALERT:
+            success_end_dttms = [
+                log.end_dttm
+                for log in self._model.logs
+                if log.state == ReportLogState.SUCCESS and log.end_dttm
+            ]
+            if (
+                success_end_dttms
+                and self._model.last_state
+                in (ReportLogState.SUCCESS, ReportLogState.NOOP)
+                and self._model.grace_period
+                and datetime.utcnow() - timedelta(seconds=self._model.grace_period)
+                < max(success_end_dttms)
+            ):
+                raise ReportScheduleAlertGracePeriodError()
diff --git a/superset/reports/commands/log_prune.py b/superset/reports/commands/log_prune.py
new file mode 100644
index 0000000..9825a35
--- /dev/null
+++ b/superset/reports/commands/log_prune.py
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import logging
+from datetime import datetime, timedelta
+
+from superset.commands.base import BaseCommand
+from superset.models.reports import ReportSchedule
+from superset.reports.dao import ReportScheduleDAO
+from superset.utils.celery import session_scope
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncPruneReportScheduleLogCommand(BaseCommand):
+ """
+ Prunes logs from all report schedules
+ """
+
+ def __init__(self, worker_context: bool = True):
+ self._worker_context = worker_context
+
+ def run(self) -> None:
+ with session_scope(nullpool=True) as session:
+ self.validate()
+ for report_schedule in session.query(ReportSchedule).all():
+ from_date = datetime.utcnow() - timedelta(
+ days=report_schedule.log_retention
+ )
+ ReportScheduleDAO.bulk_delete_logs(
+ report_schedule, from_date, session=session, commit=False
+ )
+
+ def validate(self) -> None:
+ pass
diff --git a/superset/reports/dao.py b/superset/reports/dao.py
index e02770a..6081fc8 100644
--- a/superset/reports/dao.py
+++ b/superset/reports/dao.py
@@ -15,15 +15,22 @@
# specific language governing permissions and limitations
# under the License.
import logging
+from datetime import datetime
from typing import Any, Dict, List, Optional
from flask_appbuilder import Model
from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm import Session
from superset.dao.base import BaseDAO
from superset.dao.exceptions import DAOCreateFailedError, DAODeleteFailedError
from superset.extensions import db
-from superset.models.reports import ReportRecipients, ReportSchedule
+from superset.models.reports import (
+ ReportExecutionLog,
+ ReportLogState,
+ ReportRecipients,
+ ReportSchedule,
+)
logger = logging.getLogger(__name__)
@@ -135,3 +142,49 @@ class ReportScheduleDAO(BaseDAO):
except SQLAlchemyError:
db.session.rollback()
raise DAOCreateFailedError
+
+ @staticmethod
+ def find_active(session: Optional[Session] = None) -> List[ReportSchedule]:
+ """
+ Find all active reports. If session is passed it will be used instead of the
+ default `db.session`, this is useful when on a celery worker session context
+ """
+ session = session or db.session
+ return (
+ session.query(ReportSchedule).filter(ReportSchedule.active.is_(True)).all()
+ )
+
+ @staticmethod
+ def find_last_success_log(
+ session: Optional[Session] = None,
+ ) -> Optional[ReportExecutionLog]:
+ """
+ Finds last success execution log
+ """
+ session = session or db.session
+ return (
+ session.query(ReportExecutionLog)
+ .filter(ReportExecutionLog.state == ReportLogState.SUCCESS)
+ .order_by(ReportExecutionLog.end_dttm.desc())
+ .first()
+ )
+
+ @staticmethod
+ def bulk_delete_logs(
+ model: ReportSchedule,
+ from_date: datetime,
+ session: Optional[Session] = None,
+ commit: bool = True,
+ ) -> None:
+ session = session or db.session
+ try:
+ session.query(ReportExecutionLog).filter(
+ ReportExecutionLog.report_schedule == model,
+ ReportExecutionLog.end_dttm < from_date,
+ ).delete(synchronize_session="fetch")
+ if commit:
+ session.commit()
+ except SQLAlchemyError as ex:
+ if commit:
+ session.rollback()
+ raise ex
diff --git a/superset/tasks/celery_app.py b/superset/reports/notifications/__init__.py
similarity index 52%
copy from superset/tasks/celery_app.py
copy to superset/reports/notifications/__init__.py
index 0f3cd0e..2553053 100644
--- a/superset/tasks/celery_app.py
+++ b/superset/reports/notifications/__init__.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -14,22 +15,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from superset.models.reports import ReportRecipients
+from superset.reports.notifications.base import BaseNotification, NotificationContent
+from superset.reports.notifications.email import EmailNotification
+from superset.reports.notifications.slack import SlackNotification
-"""
-This is the main entrypoint used by Celery workers. As such,
-it needs to call create_app() in order to initialize things properly
-"""
-# Superset framework imports
-from superset import create_app
-from superset.extensions import celery_app
-
-# Init the Flask app / configure everything
-create_app()
-
-# Need to import late, as the celery_app will have been setup by "create_app()"
-# pylint: disable=wrong-import-position, unused-import
-from . import cache, schedules # isort:skip
-
-# Export the celery app globally for Celery (as run on the cmd line) to find
-app = celery_app
+def create_notification(
+    recipient: ReportRecipients, screenshot_data: NotificationContent
+) -> BaseNotification:
+    """
+    Notification polymorphic factory.
+    Returns an instance of the registered notification plugin whose `type`
+    matches `recipient.type`.
+
+    :param recipient: the report recipient to notify
+    :param screenshot_data: the notification content (name and screenshot)
+    :raises Exception: if no registered plugin handles the recipient type
+    """
+    # BaseNotification.plugins is filled by __init_subclass__ as each
+    # notification module (email, slack) is imported
+    for plugin in BaseNotification.plugins:
+        if plugin.type == recipient.type:
+            return plugin(recipient, screenshot_data)
+    raise Exception("Recipient type not supported")
diff --git a/superset/reports/notifications/base.py b/superset/reports/notifications/base.py
new file mode 100644
index 0000000..f55154c
--- /dev/null
+++ b/superset/reports/notifications/base.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from dataclasses import dataclass
+from typing import Any, List, Optional, Type
+
+from superset.models.reports import ReportRecipients, ReportRecipientType
+
+
+@dataclass
+class ScreenshotData:
+ url: str # url to chart/dashboard for this screenshot
+ image: bytes # bytes for the screenshot
+
+
+@dataclass
+class NotificationContent:
+ name: str
+ screenshot: ScreenshotData
+
+
+class BaseNotification: # pylint: disable=too-few-public-methods
+    """
+    Serves as base for all notifications and creates a simple plugin system
+    for extending future implementations.
+    Child implementations get automatically registered and should identify the
+    notification type
+    """
+
+    # Shared registry of every subclass, appended to by __init_subclass__
+    plugins: List[Type["BaseNotification"]] = []
+    type: Optional[ReportRecipientType] = None
+    """
+    Child classes set their notification type ex: `type = "email"` this string will be
+    used by ReportRecipients.type to map to the correct implementation
+    """
+
+    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
+        # Register each subclass at class-creation time
+        super().__init_subclass__(*args, **kwargs) # type: ignore
+        cls.plugins.append(cls)
+
+    def __init__(
+        self, recipient: ReportRecipients, content: NotificationContent
+    ) -> None:
+        """
+        :param recipient: the report recipient this notification is sent to
+        :param content: the name and screenshot to deliver
+        """
+        self._recipient = recipient
+        self._content = content
+
+    def send(self) -> None:
+        """
+        Deliver the notification. Subclasses must override this.
+        """
+        raise NotImplementedError()
diff --git a/superset/reports/notifications/email.py b/superset/reports/notifications/email.py
new file mode 100644
index 0000000..e99a7f4
--- /dev/null
+++ b/superset/reports/notifications/email.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import json
+import logging
+from dataclasses import dataclass
+from email.utils import make_msgid, parseaddr
+from typing import Dict
+
+from flask_babel import gettext as __
+
+from superset import app
+from superset.models.reports import ReportRecipientType
+from superset.reports.notifications.base import BaseNotification
+from superset.reports.notifications.exceptions import NotificationError
+from superset.utils.core import send_email_smtp
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class EmailContent:
+ body: str
+ images: Dict[str, bytes]
+
+
+class EmailNotification(BaseNotification): # pylint: disable=too-few-public-methods
+ """
+ Sends an email notification for a report recipient
+ """
+
+ type = ReportRecipientType.EMAIL
+
+ @staticmethod
+ def _get_smtp_domain() -> str:
+ return parseaddr(app.config["SMTP_MAIL_FROM"])[1].split("@")[1]
+
+ def _get_content(self) -> EmailContent:
+ # Get the domain from the 'From' address ..
+ # and make a message id without the < > in the ends
+ domain = self._get_smtp_domain()
+ msgid = make_msgid(domain)[1:-1]
+
+ image = {msgid: self._content.screenshot.image}
+ body = __(
+ """
+ <b><a href="%(url)s">Explore in Superset</a></b><p></p>
+ <img src="cid:%(msgid)s">
+ """,
+ url=self._content.screenshot.url,
+ msgid=msgid,
+ )
+ return EmailContent(body=body, images=image)
+
+ def _get_subject(self) -> str:
+ return __(
+ "%(prefix)s %(title)s",
+ prefix=app.config["EMAIL_REPORTS_SUBJECT_PREFIX"],
+ title=self._content.name,
+ )
+
+ def _get_to(self) -> str:
+ return json.loads(self._recipient.recipient_config_json)["target"]
+
+    def send(self) -> None:
+        """
+        Email the report content (title link plus inline screenshot) to the
+        recipient's configured target.
+
+        :raises NotificationError: if sending the email fails for any reason
+        """
+        subject = self._get_subject()
+        content = self._get_content()
+        to = self._get_to()
+        try:
+            send_email_smtp(
+                to,
+                subject,
+                content.body,
+                app.config,
+                files=[],
+                data=None,
+                images=content.images,
+                bcc="",
+                mime_subtype="related",
+                dryrun=False,
+            )
+        except Exception as ex:
+            # Wrap any delivery failure in NotificationError, keeping the
+            # original exception as the cause for debugging
+            raise NotificationError(ex) from ex
+        logger.info("Report sent to email")
diff --git a/superset/utils/urls.py b/superset/reports/notifications/exceptions.py
similarity index 67%
copy from superset/utils/urls.py
copy to superset/reports/notifications/exceptions.py
index 9053769..749a91f 100644
--- a/superset/utils/urls.py
+++ b/superset/reports/notifications/exceptions.py
@@ -14,17 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import urllib
-from typing import Any
-from flask import current_app, url_for
-
-def headless_url(path: str) -> str:
- base_url = current_app.config.get("WEBDRIVER_BASEURL", "")
- return urllib.parse.urljoin(base_url, path)
-
-
-def get_url_path(view: str, **kwargs: Any) -> str:
- with current_app.test_request_context():
- return headless_url(url_for(view, **kwargs))
+class NotificationError(Exception):
+ pass
diff --git a/superset/reports/notifications/slack.py b/superset/reports/notifications/slack.py
new file mode 100644
index 0000000..8e859ff
--- /dev/null
+++ b/superset/reports/notifications/slack.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import json
+import logging
+from io import IOBase
+from typing import cast, Optional, Union
+
+from flask_babel import gettext as __
+from retry.api import retry
+from slack import WebClient
+from slack.errors import SlackApiError, SlackClientError
+from slack.web.slack_response import SlackResponse
+
+from superset import app
+from superset.models.reports import ReportRecipientType
+from superset.reports.notifications.base import BaseNotification
+from superset.reports.notifications.exceptions import NotificationError
+
+logger = logging.getLogger(__name__)
+
+
class SlackNotification(BaseNotification):  # pylint: disable=too-few-public-methods
    """
    Sends a slack notification for a report recipient
    """

    type = ReportRecipientType.SLACK

    def _get_channel(self) -> str:
        """Return the ``target`` entry of the recipient's JSON configuration."""
        return json.loads(self._recipient.recipient_config_json)["target"]

    def _get_body(self) -> str:
        """Return the message text: the bold report name and an Explore link."""
        return __(
            """
            *%(name)s*\n
            <%(url)s|Explore in Superset>
            """,
            name=self._content.name,
            url=self._content.screenshot.url,
        )

    def _get_inline_screenshot(self) -> Optional[Union[str, IOBase, bytes]]:
        """Return the screenshot payload to upload (may be falsy for none)."""
        return self._content.screenshot.image

    @retry(SlackApiError, delay=10, backoff=2, tries=5)
    def _post_to_slack(
        self, channel: str, body: str, file: Optional[Union[str, IOBase, bytes]]
    ) -> None:
        """
        Post ``body`` to ``channel``, uploading ``file`` with it when present.

        Retried on ``SlackApiError``. The retry lives here rather than on
        ``send`` because ``send`` converts every ``SlackClientError`` (the
        base class of ``SlackApiError``) into ``NotificationError``; with the
        decorator on ``send`` the conversion happened first and nothing was
        ever retried.

        :raises SlackClientError: on Slack client/API failures
        :raises NotificationError: if Slack's reply lacks the expected payload
        """
        client = WebClient(
            token=app.config["SLACK_API_TOKEN"], proxy=app.config["SLACK_PROXY"]
        )
        # files_upload / chat_postMessage return SlackResponse as we run in sync mode.
        if file:
            response = cast(
                SlackResponse,
                client.files_upload(
                    channels=channel,
                    file=file,
                    initial_comment=body,
                    title=self._content.name,
                ),
            )
            # Explicit check (not ``assert``) so it still runs under ``python -O``.
            if not response.get("file"):
                raise NotificationError(str(response))
        else:
            response = cast(
                SlackResponse, client.chat_postMessage(channel=channel, text=body),
            )
            if not (response.get("message") or {}).get("text"):
                raise NotificationError(str(response))

    def send(self) -> None:
        """
        Deliver the report to the configured Slack channel.

        :raises NotificationError: if the Slack client still fails after the
            retries in ``_post_to_slack``, or Slack's reply is malformed. The
            underlying exception is chained as ``__cause__``.
        """
        file = self._get_inline_screenshot()
        channel = self._get_channel()
        body = self._get_body()
        try:
            self._post_to_slack(channel, body, file)
        except SlackClientError as ex:
            raise NotificationError(ex) from ex
        logger.info("Report sent to slack")
diff --git a/superset/tasks/celery_app.py b/superset/tasks/celery_app.py
index 0f3cd0e..d84273f 100644
--- a/superset/tasks/celery_app.py
+++ b/superset/tasks/celery_app.py
@@ -29,7 +29,7 @@ create_app()
# Need to import late, as the celery_app will have been setup by "create_app()"
# pylint: disable=wrong-import-position, unused-import
-from . import cache, schedules # isort:skip
+from . import cache, schedules, scheduler # isort:skip
# Export the celery app globally for Celery (as run on the cmd line) to find
app = celery_app
diff --git a/superset/tasks/scheduler.py b/superset/tasks/scheduler.py
new file mode 100644
index 0000000..62398f0
--- /dev/null
+++ b/superset/tasks/scheduler.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
import logging
from datetime import datetime, timedelta
from typing import Iterator, Union

import croniter

from superset.commands.exceptions import CommandException
from superset.extensions import celery_app
from superset.reports.commands.execute import AsyncExecuteReportScheduleCommand
from superset.reports.commands.log_prune import AsyncPruneReportScheduleLogCommand
from superset.reports.dao import ReportScheduleDAO
from superset.utils.celery import session_scope
+
+logger = logging.getLogger(__name__)
+
+
def cron_schedule_window(cron: str, window_size: int = 10) -> Iterator[datetime]:
    """
    Yield the fire times of ``cron`` inside a short window around "now".

    "Now" is ``datetime.utcnow()``. The window is opened one second before it
    (presumably so a fire time equal to now is not skipped) and closes,
    exclusively, ``window_size`` seconds after it.
    """
    now = datetime.utcnow()
    window_end = now + timedelta(seconds=window_size)
    fire_times = croniter.croniter(cron, now - timedelta(seconds=1)).all_next(
        datetime
    )
    for fire_at in fire_times:
        if window_end <= fire_at:
            return
        yield fire_at
+
+
@celery_app.task(name="reports.scheduler")
def scheduler() -> None:
    """
    Celery beat main scheduler for reports.

    For every active report schedule, enqueue one ``reports.execute`` task per
    cron fire time in the upcoming window, with the task ETA set to that time.
    """
    with session_scope(nullpool=True) as session:
        for report_schedule in ReportScheduleDAO.find_active(session):
            fire_times = cron_schedule_window(report_schedule.crontab)
            for eta in fire_times:
                execute.apply_async((report_schedule.id, eta), eta=eta)
+
+
def _to_datetime(scheduled_dttm: Union[str, datetime]) -> datetime:
    """
    Normalize a scheduled fire time to a naive ``datetime``.

    ``scheduler`` enqueues a ``datetime``, but with Celery's JSON serializer
    (its default) it is expected to reach the worker as the ISO-8601 string
    produced by ``datetime.isoformat()``.
    NOTE(review): confirm against the deployed Celery serializer settings.
    """
    if isinstance(scheduled_dttm, datetime):
        return scheduled_dttm
    # isoformat() omits the fractional part when microseconds are zero.
    iso_format = (
        "%Y-%m-%dT%H:%M:%S.%f" if "." in scheduled_dttm else "%Y-%m-%dT%H:%M:%S"
    )
    return datetime.strptime(scheduled_dttm, iso_format)


@celery_app.task(name="reports.execute")
def execute(report_schedule_id: int, scheduled_dttm: Union[str, datetime]) -> None:
    """
    Run one report schedule for the fire time ``scheduled_dttm``.

    ``scheduled_dttm`` may be a ``datetime`` or its ISO-8601 string form (see
    ``_to_datetime``). A ``CommandException`` is logged rather than raised so
    that the Celery task itself does not fail.
    """
    try:
        AsyncExecuteReportScheduleCommand(
            report_schedule_id, _to_datetime(scheduled_dttm)
        ).run()
    except CommandException as ex:
        logger.error("An exception occurred while executing the report: %s", ex)
+
+
@celery_app.task(name="reports.prune_log")
def prune_log() -> None:
    """
    Run the report-schedule log pruning command.

    A ``CommandException`` is logged instead of propagated.
    """
    try:
        pruner = AsyncPruneReportScheduleLogCommand()
        pruner.run()
    except CommandException as ex:
        logger.error("An exception occurred while pruning report schedule logs: %s", ex)
diff --git a/superset/utils/urls.py b/superset/utils/urls.py
index 9053769..fe9455d 100644
--- a/superset/utils/urls.py
+++ b/superset/utils/urls.py
@@ -20,11 +20,15 @@ from typing import Any
from flask import current_app, url_for
-def headless_url(path: str) -> str:
- base_url = current_app.config.get("WEBDRIVER_BASEURL", "")
+def headless_url(path: str, user_friendly: bool = False) -> str:
+ base_url = (
+ current_app.config["WEBDRIVER_BASEURL_USER_FRIENDLY"]
+ if user_friendly
+ else current_app.config["WEBDRIVER_BASEURL"]
+ )
return urllib.parse.urljoin(base_url, path)
-def get_url_path(view: str, **kwargs: Any) -> str:
+def get_url_path(view: str, user_friendly: bool = False, **kwargs: Any) -> str:
with current_app.test_request_context():
- return headless_url(url_for(view, **kwargs))
+ return headless_url(url_for(view, **kwargs), user_friendly=user_friendly)
diff --git a/tests/reports/api_tests.py b/tests/reports/api_tests.py
index eb5425d..26dbe16 100644
--- a/tests/reports/api_tests.py
+++ b/tests/reports/api_tests.py
@@ -40,6 +40,7 @@ from superset.models.reports import (
)
from tests.base_tests import SupersetTestCase
+from tests.reports.utils import insert_report_schedule
from superset.utils.core import get_example_database
@@ -47,48 +48,6 @@ REPORTS_COUNT = 10
class TestReportSchedulesApi(SupersetTestCase):
- def insert_report_schedule(
- self,
- type: str,
- name: str,
- crontab: str,
- sql: Optional[str] = None,
- description: Optional[str] = None,
- chart: Optional[Slice] = None,
- dashboard: Optional[Dashboard] = None,
- database: Optional[Database] = None,
- owners: Optional[List[User]] = None,
- validator_type: Optional[str] = None,
- validator_config_json: Optional[str] = None,
- log_retention: Optional[int] = None,
- grace_period: Optional[int] = None,
- recipients: Optional[List[ReportRecipients]] = None,
- logs: Optional[List[ReportExecutionLog]] = None,
- ) -> ReportSchedule:
- owners = owners or []
- recipients = recipients or []
- logs = logs or []
- report_schedule = ReportSchedule(
- type=type,
- name=name,
- crontab=crontab,
- sql=sql,
- description=description,
- chart=chart,
- dashboard=dashboard,
- database=database,
- owners=owners,
- validator_type=validator_type,
- validator_config_json=validator_config_json,
- log_retention=log_retention,
- grace_period=grace_period,
- recipients=recipients,
- logs=logs,
- )
- db.session.add(report_schedule)
- db.session.commit()
- return report_schedule
-
@pytest.fixture()
def create_report_schedules(self):
with self.create_app().app_context():
@@ -116,7 +75,7 @@ class TestReportSchedulesApi(SupersetTestCase):
)
)
report_schedules.append(
- self.insert_report_schedule(
+ insert_report_schedule(
type=ReportScheduleType.ALERT,
name=f"name{cx}",
crontab=f"*/{cx} * * * *",
@@ -169,10 +128,6 @@ class TestReportSchedulesApi(SupersetTestCase):
"last_value_row_json": report_schedule.last_value_row_json,
"log_retention": report_schedule.log_retention,
"name": report_schedule.name,
- "owners": [
- {"first_name": "admin", "id": 1, "last_name": "user"},
- {"first_name": "alpha", "id": 5, "last_name": "user"},
- ],
"recipients": [
{
"id": report_schedule.recipients[0].id,
@@ -184,7 +139,16 @@ class TestReportSchedulesApi(SupersetTestCase):
"validator_config_json": report_schedule.validator_config_json,
"validator_type": report_schedule.validator_type,
}
- assert data["result"] == expected_result
+ for key in expected_result:
+ assert data["result"][key] == expected_result[key]
+ # needed because order may vary
+ assert {"first_name": "admin", "id": 1, "last_name": "user"} in data["result"][
+ "owners"
+ ]
+ assert {"first_name": "alpha", "id": 5, "last_name": "user"} in data["result"][
+ "owners"
+ ]
+ assert len(data["result"]["owners"]) == 2
def test_info_report_schedule(self):
"""
diff --git a/tests/reports/commands_tests.py b/tests/reports/commands_tests.py
new file mode 100644
index 0000000..d556694
--- /dev/null
+++ b/tests/reports/commands_tests.py
@@ -0,0 +1,531 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import json
+from datetime import datetime
+from typing import List, Optional
+from unittest.mock import patch
+
+import pytest
+from contextlib2 import contextmanager
+from freezegun import freeze_time
+from sqlalchemy.sql import func
+
+from superset import db
+from superset.models.core import Database
+from superset.models.dashboard import Dashboard
+from superset.models.reports import (
+ ReportExecutionLog,
+ ReportLogState,
+ ReportRecipients,
+ ReportRecipientType,
+ ReportSchedule,
+ ReportScheduleType,
+ ReportScheduleValidatorType,
+)
+from superset.models.slice import Slice
+from superset.reports.commands.exceptions import (
+ AlertQueryMultipleColumnsError,
+ AlertQueryMultipleRowsError,
+ ReportScheduleNotFoundError,
+ ReportScheduleNotificationError,
+ ReportSchedulePreviousWorkingError,
+)
+from superset.reports.commands.execute import AsyncExecuteReportScheduleCommand
+from superset.utils.core import get_example_database
+from tests.reports.utils import insert_report_schedule
+from tests.test_app import app
+from tests.utils import read_fixture
+
+
def get_target_from_report_schedule(report_schedule) -> List[str]:
    """Collect the ``target`` of every recipient's JSON config, in order."""
    targets: List[str] = []
    for recipient in report_schedule.recipients:
        recipient_config = json.loads(recipient.recipient_config_json)
        targets.append(recipient_config["target"])
    return targets
+
+
def assert_log(state: str, error_message: Optional[str] = None):
    """
    Commit the session, then check that exactly one ReportExecutionLog exists
    and that its error message and state match the expected values.
    """
    db.session.commit()
    all_logs = db.session.query(ReportExecutionLog).all()
    assert len(all_logs) == 1
    only_log = all_logs[0]
    assert only_log.error_message == error_message
    assert only_log.state == state
+
+
def create_report_notification(
    email_target: Optional[str] = None,
    slack_channel: Optional[str] = None,
    chart: Optional[Slice] = None,
    dashboard: Optional[Dashboard] = None,
    database: Optional[Database] = None,
    sql: Optional[str] = None,
    report_type: Optional[str] = None,
    validator_type: Optional[str] = None,
    validator_config_json: Optional[str] = None,
) -> ReportSchedule:
    """
    Insert a daily ("0 9 * * *") report schedule with a single recipient.

    A Slack recipient is created when ``slack_channel`` is truthy, otherwise an
    email recipient. Its config target is ``email_target or slack_channel``.
    ``report_type`` defaults to ``ReportScheduleType.REPORT``.
    """
    report_type = report_type or ReportScheduleType.REPORT
    target = email_target or slack_channel
    recipient_type = (
        ReportRecipientType.SLACK if slack_channel else ReportRecipientType.EMAIL
    )
    recipient = ReportRecipients(
        type=recipient_type, recipient_config_json=json.dumps({"target": target}),
    )

    return insert_report_schedule(
        type=report_type,
        name="report",
        crontab="0 9 * * *",
        description="Daily report",
        sql=sql,
        chart=chart,
        dashboard=dashboard,
        database=database,
        recipients=[recipient],
        validator_type=validator_type,
        validator_config_json=validator_config_json,
    )
+
+
@pytest.fixture()
def create_report_email_chart():
    """Yield an email report on the first chart; delete it on teardown."""
    with app.app_context():
        chart = db.session.query(Slice).first()
        report_schedule = create_report_notification(
            email_target="target@email.com", chart=chart
        )
        yield report_schedule

        db.session.delete(report_schedule)
        db.session.commit()
+
+
@pytest.fixture()
def create_report_email_dashboard():
    """Yield an email report on the first dashboard; delete it on teardown."""
    with app.app_context():
        dashboard = db.session.query(Dashboard).first()
        report_schedule = create_report_notification(
            email_target="target@email.com", dashboard=dashboard
        )
        yield report_schedule

        db.session.delete(report_schedule)
        db.session.commit()
+
+
@pytest.fixture()
def create_report_slack_chart():
    """Yield a Slack report on the first chart; delete it on teardown."""
    with app.app_context():
        chart = db.session.query(Slice).first()
        report_schedule = create_report_notification(
            slack_channel="slack_channel", chart=chart
        )
        yield report_schedule

        db.session.delete(report_schedule)
        db.session.commit()
+
+
@pytest.fixture()
def create_report_slack_chart_working():
    """
    Yield a Slack report on the first chart whose ``last_state`` is already
    ``WORKING``; delete it on teardown.
    """
    with app.app_context():
        chart = db.session.query(Slice).first()
        report_schedule = create_report_notification(
            slack_channel="slack_channel", chart=chart
        )
        report_schedule.last_state = ReportLogState.WORKING
        db.session.commit()
        yield report_schedule

        db.session.delete(report_schedule)
        db.session.commit()
+
+
@pytest.fixture(
    params=["alert1", "alert2", "alert3", "alert4", "alert5", "alert6", "alert7"]
)
def create_alert_email_chart(request):
    """
    Yield an email ALERT on the first chart whose SQL/validator pair, selected
    by ``request.param``, is expected to trigger; delete it on teardown.
    """
    param_config = {
        "alert1": {
            "sql": "SELECT 10 as metric",
            "validator_type": ReportScheduleValidatorType.OPERATOR,
            "validator_config_json": '{"op": ">", "threshold": 9}',
        },
        "alert2": {
            "sql": "SELECT 10 as metric",
            "validator_type": ReportScheduleValidatorType.OPERATOR,
            "validator_config_json": '{"op": ">=", "threshold": 10}',
        },
        "alert3": {
            "sql": "SELECT 10 as metric",
            "validator_type": ReportScheduleValidatorType.OPERATOR,
            "validator_config_json": '{"op": "<", "threshold": 11}',
        },
        "alert4": {
            "sql": "SELECT 10 as metric",
            "validator_type": ReportScheduleValidatorType.OPERATOR,
            "validator_config_json": '{"op": "<=", "threshold": 10}',
        },
        "alert5": {
            "sql": "SELECT 10 as metric",
            "validator_type": ReportScheduleValidatorType.OPERATOR,
            "validator_config_json": '{"op": "!=", "threshold": 11}',
        },
        "alert6": {
            "sql": "SELECT 'something' as metric",
            "validator_type": ReportScheduleValidatorType.NOT_NULL,
            "validator_config_json": "{}",
        },
        # Jinja-style ``{{ }}`` SQL, presumably rendered before execution.
        "alert7": {
            "sql": "SELECT {{ 5 + 5 }} as metric",
            "validator_type": ReportScheduleValidatorType.OPERATOR,
            "validator_config_json": '{"op": "!=", "threshold": 11}',
        },
    }
    config = param_config[request.param]
    with app.app_context():
        chart = db.session.query(Slice).first()
        example_database = get_example_database()

        report_schedule = create_report_notification(
            email_target="target@email.com",
            chart=chart,
            report_type=ReportScheduleType.ALERT,
            database=example_database,
            sql=config["sql"],
            validator_type=config["validator_type"],
            validator_config_json=config["validator_config_json"],
        )
        yield report_schedule

        db.session.delete(report_schedule)
        db.session.commit()
+
+
@contextmanager
def create_test_table_context(database: Database):
    """
    Create ``test_table`` in ``database`` for the duration of the context.

    The table is created with one ``(1, 2)`` row and then ``(1, 2)`` and
    ``(3, 4)`` are inserted, so it holds three rows. Yields ``db.session``.
    The table is dropped on exit even when the body raises, so a failing
    fixture cannot leave it behind and break the next ``CREATE TABLE``.
    """
    setup_statements = (
        "CREATE TABLE test_table AS SELECT 1 as first, 2 as second",
        "INSERT INTO test_table (first, second) VALUES (1, 2)",
        "INSERT INTO test_table (first, second) VALUES (3, 4)",
    )
    for statement in setup_statements:
        database.get_sqla_engine().execute(statement)
    try:
        yield db.session
    finally:
        database.get_sqla_engine().execute("DROP TABLE test_table")
+
+
@pytest.fixture(params=["alert1", "alert2", "alert3", "alert4", "alert5", "alert6"])
def create_no_alert_email_chart(request):
    """
    Yield an email ALERT on the first chart whose SQL/validator pair, selected
    by ``request.param``, is expected NOT to trigger. ``test_table`` exists
    while the schedule is in use; both are removed on teardown.
    """
    param_config = {
        "alert1": {
            "sql": "SELECT 10 as metric",
            "validator_type": ReportScheduleValidatorType.OPERATOR,
            "validator_config_json": '{"op": "<", "threshold": 10}',
        },
        "alert2": {
            "sql": "SELECT 10 as metric",
            "validator_type": ReportScheduleValidatorType.OPERATOR,
            "validator_config_json": '{"op": ">=", "threshold": 11}',
        },
        # Previously an exact duplicate of alert1; now covers ``>`` (10 > 10).
        "alert3": {
            "sql": "SELECT 10 as metric",
            "validator_type": ReportScheduleValidatorType.OPERATOR,
            "validator_config_json": '{"op": ">", "threshold": 10}',
        },
        "alert4": {
            "sql": "SELECT 10 as metric",
            "validator_type": ReportScheduleValidatorType.OPERATOR,
            "validator_config_json": '{"op": "<=", "threshold": 9}',
        },
        "alert5": {
            "sql": "SELECT 10 as metric",
            "validator_type": ReportScheduleValidatorType.OPERATOR,
            "validator_config_json": '{"op": "!=", "threshold": 10}',
        },
        # Query returns no rows.
        "alert6": {
            "sql": "SELECT first from test_table where first=0",
            "validator_type": ReportScheduleValidatorType.NOT_NULL,
            "validator_config_json": "{}",
        },
    }
    config = param_config[request.param]
    with app.app_context():
        chart = db.session.query(Slice).first()
        example_database = get_example_database()
        with create_test_table_context(example_database):
            report_schedule = create_report_notification(
                email_target="target@email.com",
                chart=chart,
                report_type=ReportScheduleType.ALERT,
                database=example_database,
                sql=config["sql"],
                validator_type=config["validator_type"],
                validator_config_json=config["validator_config_json"],
            )
            yield report_schedule

            db.session.delete(report_schedule)
            db.session.commit()
+
+
@pytest.fixture(params=["alert1", "alert2"])
def create_mul_alert_email_chart(request):
    """
    Yield an email ALERT whose SQL returns several rows (alert1) or several
    columns (alert2) from ``test_table``. On teardown its execution logs and
    the schedule are deleted, then the table is dropped.
    """
    param_config = {
        "alert1": {
            "sql": "SELECT first from test_table",
            "validator_type": ReportScheduleValidatorType.OPERATOR,
            "validator_config_json": '{"op": "<", "threshold": 10}',
        },
        "alert2": {
            "sql": "SELECT first, second from test_table",
            "validator_type": ReportScheduleValidatorType.OPERATOR,
            "validator_config_json": '{"op": "<", "threshold": 10}',
        },
    }
    config = param_config[request.param]
    with app.app_context():
        chart = db.session.query(Slice).first()
        example_database = get_example_database()
        with create_test_table_context(example_database):
            report_schedule = create_report_notification(
                email_target="target@email.com",
                chart=chart,
                report_type=ReportScheduleType.ALERT,
                database=example_database,
                sql=config["sql"],
                validator_type=config["validator_type"],
                validator_config_json=config["validator_config_json"],
            )
            yield report_schedule

            # Logs are removed explicitly first (needed for MySQL).
            logs = (
                db.session.query(ReportExecutionLog)
                .filter(ReportExecutionLog.report_schedule == report_schedule)
                .all()
            )
            for log in logs:
                db.session.delete(log)
            db.session.commit()
            db.session.delete(report_schedule)
            db.session.commit()
+
+
@pytest.mark.usefixtures("create_report_email_chart")
@patch("superset.reports.notifications.email.send_email_smtp")
@patch("superset.utils.screenshots.ChartScreenshot.compute_and_cache")
def test_email_chart_report_schedule(
    screenshot_mock, email_mock, create_report_email_chart
):
    """
    ExecuteReport Command: a chart report is emailed to its recipient with the
    chart screenshot inlined, and one SUCCESS log is recorded
    """
    expected_png = read_fixture("sample.png")
    screenshot_mock.return_value = expected_png

    with freeze_time("2020-01-01T00:00:00Z"):
        report_command = AsyncExecuteReportScheduleCommand(
            create_report_email_chart.id, datetime.utcnow()
        )
        report_command.run()

        targets = get_target_from_report_schedule(create_report_email_chart)
        sent_args, sent_kwargs = email_mock.call_args
        # First positional argument to send_email_smtp is the recipient.
        assert sent_args[0] == targets[0]
        inline_images = sent_kwargs["images"]
        assert list(inline_images.values())[0] == expected_png
        assert_log(ReportLogState.SUCCESS)
+
+
@pytest.mark.usefixtures("create_report_email_dashboard")
@patch("superset.reports.notifications.email.send_email_smtp")
@patch("superset.utils.screenshots.DashboardScreenshot.compute_and_cache")
def test_email_dashboard_report_schedule(
    screenshot_mock, email_mock, create_report_email_dashboard
):
    """
    ExecuteReport Command: a dashboard report is emailed to its recipient with
    the dashboard screenshot inlined, and one SUCCESS log is recorded
    """
    expected_png = read_fixture("sample.png")
    screenshot_mock.return_value = expected_png

    with freeze_time("2020-01-01T00:00:00Z"):
        report_command = AsyncExecuteReportScheduleCommand(
            create_report_email_dashboard.id, datetime.utcnow()
        )
        report_command.run()

        targets = get_target_from_report_schedule(create_report_email_dashboard)
        sent_args, sent_kwargs = email_mock.call_args
        # First positional argument to send_email_smtp is the recipient.
        assert sent_args[0] == targets[0]
        inline_images = sent_kwargs["images"]
        assert list(inline_images.values())[0] == expected_png
        assert_log(ReportLogState.SUCCESS)
+
+
@pytest.mark.usefixtures("create_report_slack_chart")
@patch("superset.reports.notifications.slack.WebClient.files_upload")
@patch("superset.utils.screenshots.ChartScreenshot.compute_and_cache")
def test_slack_chart_report_schedule(
    screenshot_mock, file_upload_mock, create_report_slack_chart
):
    """
    ExecuteReport Command: a chart report is uploaded to the Slack channel with
    the chart screenshot as the file, and one SUCCESS log is recorded
    """
    expected_png = read_fixture("sample.png")
    screenshot_mock.return_value = expected_png

    with freeze_time("2020-01-01T00:00:00Z"):
        report_command = AsyncExecuteReportScheduleCommand(
            create_report_slack_chart.id, datetime.utcnow()
        )
        report_command.run()

        targets = get_target_from_report_schedule(create_report_slack_chart)
        _, upload_kwargs = file_upload_mock.call_args
        assert upload_kwargs["channels"] == targets[0]
        assert upload_kwargs["file"] == expected_png

        assert_log(ReportLogState.SUCCESS)
+
+
@pytest.mark.usefixtures("create_report_slack_chart")
def test_report_schedule_not_found(create_report_slack_chart):
    """
    ExecuteReport Command: an id one past the highest existing report schedule
    id raises ReportScheduleNotFoundError
    """
    highest_id = db.session.query(func.max(ReportSchedule.id)).scalar()
    missing_id = highest_id + 1
    with pytest.raises(ReportScheduleNotFoundError):
        AsyncExecuteReportScheduleCommand(missing_id, datetime.utcnow()).run()
+
+
@pytest.mark.usefixtures("create_report_slack_chart_working")
def test_report_schedule_working(create_report_slack_chart_working):
    """
    ExecuteReport Command: a schedule whose last_state is still WORKING raises
    ReportSchedulePreviousWorkingError, is logged as ERROR, and keeps its
    WORKING state
    """
    # No screenshot/notification mocks are needed: the command must stop
    # before attempting any delivery.
    with pytest.raises(ReportSchedulePreviousWorkingError):
        AsyncExecuteReportScheduleCommand(
            create_report_slack_chart_working.id, datetime.utcnow()
        ).run()

    assert_log(
        ReportLogState.ERROR, error_message=ReportSchedulePreviousWorkingError.message
    )
    assert create_report_slack_chart_working.last_state == ReportLogState.WORKING
+
+
@pytest.mark.usefixtures("create_report_email_dashboard")
@patch("superset.reports.notifications.email.send_email_smtp")
@patch("superset.utils.screenshots.DashboardScreenshot.compute_and_cache")
def test_email_dashboard_report_fails(
    screenshot_mock, email_mock, create_report_email_dashboard
):
    """
    ExecuteReport Command: when the SMTP send raises, the command raises
    ReportScheduleNotificationError and logs an ERROR carrying the SMTP message
    """
    from smtplib import SMTPException

    # Screenshot succeeds; the SMTP call is made to fail.
    screenshot_mock.return_value = read_fixture("sample.png")
    email_mock.side_effect = SMTPException("Could not connect to SMTP XPTO")

    with pytest.raises(ReportScheduleNotificationError):
        report_command = AsyncExecuteReportScheduleCommand(
            create_report_email_dashboard.id, datetime.utcnow()
        )
        report_command.run()

    assert_log(ReportLogState.ERROR, error_message="Could not connect to SMTP XPTO")
+
+
@pytest.mark.usefixtures("create_alert_email_chart")
@patch("superset.reports.notifications.email.send_email_smtp")
@patch("superset.utils.screenshots.ChartScreenshot.compute_and_cache")
def test_slack_chart_alert(screenshot_mock, email_mock, create_alert_email_chart):
    """
    ExecuteReport Command: Test chart EMAIL alert (despite the function name,
    delivery here goes through ``send_email_smtp``): a triggering alert emails
    the chart screenshot to its recipient and records one SUCCESS log
    """
    # setup screenshot mock
    screenshot = read_fixture("sample.png")
    screenshot_mock.return_value = screenshot

    with freeze_time("2020-01-01T00:00:00Z"):
        AsyncExecuteReportScheduleCommand(
            create_alert_email_chart.id, datetime.utcnow()
        ).run()

        notification_targets = get_target_from_report_schedule(create_alert_email_chart)
        # Assert the email smtp address
        assert email_mock.call_args[0][0] == notification_targets[0]
        # Assert the email inline screenshot
        smtp_images = email_mock.call_args[1]["images"]
        assert smtp_images[list(smtp_images.keys())[0]] == screenshot
        # Assert logs are correct
        assert_log(ReportLogState.SUCCESS)
+
+
@pytest.mark.usefixtures("create_no_alert_email_chart")
def test_email_chart_no_alert(create_no_alert_email_chart):
    """
    ExecuteReport Command: an alert whose condition is not met records a
    single NOOP log
    """
    with freeze_time("2020-01-01T00:00:00Z"):
        report_command = AsyncExecuteReportScheduleCommand(
            create_no_alert_email_chart.id, datetime.utcnow()
        )
        report_command.run()
    assert_log(ReportLogState.NOOP)
+
+
@pytest.mark.usefixtures("create_mul_alert_email_chart")
def test_email_mul_alert(create_mul_alert_email_chart):
    """
    ExecuteReport Command: alert SQL returning more than one row or more than
    one column raises the corresponding multiple-rows/columns error
    """
    expected_errors = (AlertQueryMultipleRowsError, AlertQueryMultipleColumnsError)
    with freeze_time("2020-01-01T00:00:00Z"):
        with pytest.raises(expected_errors):
            report_command = AsyncExecuteReportScheduleCommand(
                create_mul_alert_email_chart.id, datetime.utcnow()
            )
            report_command.run()
diff --git a/tests/reports/utils.py b/tests/reports/utils.py
new file mode 100644
index 0000000..841ae4d
--- /dev/null
+++ b/tests/reports/utils.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import List, Optional
+
+from flask_appbuilder.security.sqla.models import User
+
+from superset import db
+from superset.models.core import Database
+from superset.models.dashboard import Dashboard
+from superset.models.reports import ReportExecutionLog, ReportRecipients, ReportSchedule
+from superset.models.slice import Slice
+
+
def insert_report_schedule(
    type: str,
    name: str,
    crontab: str,
    sql: Optional[str] = None,
    description: Optional[str] = None,
    chart: Optional[Slice] = None,
    dashboard: Optional[Dashboard] = None,
    database: Optional[Database] = None,
    owners: Optional[List[User]] = None,
    validator_type: Optional[str] = None,
    validator_config_json: Optional[str] = None,
    log_retention: Optional[int] = None,
    grace_period: Optional[int] = None,
    recipients: Optional[List[ReportRecipients]] = None,
    logs: Optional[List[ReportExecutionLog]] = None,
) -> ReportSchedule:
    """
    Create a ReportSchedule from the given fields, add it to ``db.session``,
    commit, and return it. Missing owners/recipients/logs become empty lists.
    """
    new_schedule = ReportSchedule(
        type=type,
        name=name,
        crontab=crontab,
        sql=sql,
        description=description,
        chart=chart,
        dashboard=dashboard,
        database=database,
        owners=owners or [],
        validator_type=validator_type,
        validator_config_json=validator_config_json,
        log_retention=log_retention,
        grace_period=grace_period,
        recipients=recipients or [],
        logs=logs or [],
    )
    db.session.add(new_schedule)
    db.session.commit()
    return new_schedule