Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/29 11:25:18 UTC
[airflow] branch main updated: Move all "old" SQL operators to common.sql providers (#25350)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new acab8f52dd Move all "old" SQL operators to common.sql providers (#25350)
acab8f52dd is described below
commit acab8f52dd8d90fd6583779127895dd343780f79
Author: Jarek Potiuk <ja...@polidea.com>
AuthorDate: Fri Jul 29 13:25:11 2022 +0200
Move all "old" SQL operators to common.sql providers (#25350)
Previously, in #24836 we moved Hooks and added some new operators to the
common.sql package. Now we are also moving the operators
and sensors to common.sql.
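For DAG authors this is purely an import-path change; the old paths keep
working but emit a DeprecationWarning. A minimal sketch of the migration
(SQLCheckOperator is one of several classes moved by this commit):

    # Before (deprecated, still importable through the shim modules):
    # from airflow.operators.sql import SQLCheckOperator

    # After:
    from airflow.providers.common.sql.operators.sql import SQLCheckOperator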
---
airflow/operators/check_operator.py | 24 +-
airflow/operators/druid_check_operator.py | 4 +-
airflow/operators/presto_check_operator.py | 24 +-
airflow/operators/sql.py | 560 +--------------
airflow/operators/sql_branch_operator.py | 12 +-
.../apache/druid/operators/druid_check.py | 6 +-
airflow/providers/apache/druid/provider.yaml | 2 +-
airflow/providers/common/sql/CHANGELOG.rst | 5 +
airflow/providers/common/sql/operators/sql.py | 550 ++++++++++++++-
airflow/providers/common/sql/provider.yaml | 2 +-
.../providers/google/cloud/operators/bigquery.py | 6 +-
airflow/providers/google/provider.yaml | 2 +-
.../google/suite/transfers/sql_to_sheets.py | 2 +-
airflow/providers/qubole/operators/qubole_check.py | 2 +-
airflow/providers/qubole/provider.yaml | 2 +-
airflow/providers/snowflake/operators/snowflake.py | 6 +-
airflow/providers/snowflake/provider.yaml | 2 +-
airflow/sensors/sql_sensor.py | 2 +-
docs/apache-airflow/operators-and-hooks-ref.rst | 6 -
generated/provider_dependencies.json | 8 +-
tests/deprecated_classes.py | 18 +-
tests/operators/test_sql.py | 779 ---------------------
tests/providers/common/sql/operators/test_sql.py | 776 +++++++++++++++++++-
.../qubole/operators/test_qubole_check.py | 14 +-
tests/providers/slack/operators/test_slack.py | 8 +-
.../providers/slack/transfers/test_sql_to_slack.py | 6 +-
26 files changed, 1424 insertions(+), 1404 deletions(-)
diff --git a/airflow/operators/check_operator.py b/airflow/operators/check_operator.py
index 130eb6e577..0575211a01 100644
--- a/airflow/operators/check_operator.py
+++ b/airflow/operators/check_operator.py
@@ -16,11 +16,11 @@
# specific language governing permissions and limitations
# under the License.
-"""This module is deprecated. Please use :mod:`airflow.operators.sql`."""
+"""This module is deprecated. Please use :mod:`airflow.providers.common.sql.operators.sql`."""
import warnings
-from airflow.operators.sql import (
+from airflow.providers.common.sql.operators.sql import (
SQLCheckOperator,
SQLIntervalCheckOperator,
SQLThresholdCheckOperator,
@@ -28,20 +28,22 @@ from airflow.operators.sql import (
)
warnings.warn(
- "This module is deprecated. Please use `airflow.operators.sql`.", DeprecationWarning, stacklevel=2
+ "This module is deprecated. Please use `airflow.providers.common.sql.operators.sql`.",
+ DeprecationWarning,
+ stacklevel=2,
)
class CheckOperator(SQLCheckOperator):
"""
This class is deprecated.
- Please use `airflow.operators.sql.SQLCheckOperator`.
+ Please use `airflow.providers.common.sql.operators.sql.SQLCheckOperator`.
"""
def __init__(self, **kwargs):
warnings.warn(
"""This class is deprecated.
- Please use `airflow.operators.sql.SQLCheckOperator`.""",
+ Please use `airflow.providers.common.sql.operators.sql.SQLCheckOperator`.""",
DeprecationWarning,
stacklevel=2,
)
@@ -51,13 +53,13 @@ class CheckOperator(SQLCheckOperator):
class IntervalCheckOperator(SQLIntervalCheckOperator):
"""
This class is deprecated.
- Please use `airflow.operators.sql.SQLIntervalCheckOperator`.
+ Please use `airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator`.
"""
def __init__(self, **kwargs):
warnings.warn(
"""This class is deprecated.
- Please use `airflow.operators.sql.SQLIntervalCheckOperator`.""",
+ Please use `airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator`.""",
DeprecationWarning,
stacklevel=2,
)
@@ -67,13 +69,13 @@ class IntervalCheckOperator(SQLIntervalCheckOperator):
class ThresholdCheckOperator(SQLThresholdCheckOperator):
"""
This class is deprecated.
- Please use `airflow.operators.sql.SQLThresholdCheckOperator`.
+ Please use `airflow.providers.common.sql.operators.sql.SQLThresholdCheckOperator`.
"""
def __init__(self, **kwargs):
warnings.warn(
"""This class is deprecated.
- Please use `airflow.operators.sql.SQLThresholdCheckOperator`.""",
+ Please use `airflow.providers.common.sql.operators.sql.SQLThresholdCheckOperator`.""",
DeprecationWarning,
stacklevel=2,
)
@@ -83,13 +85,13 @@ class ThresholdCheckOperator(SQLThresholdCheckOperator):
class ValueCheckOperator(SQLValueCheckOperator):
"""
This class is deprecated.
- Please use `airflow.operators.sql.SQLValueCheckOperator`.
+ Please use `airflow.providers.common.sql.operators.sql.SQLValueCheckOperator`.
"""
def __init__(self, **kwargs):
warnings.warn(
"""This class is deprecated.
- Please use `airflow.operators.sql.SQLValueCheckOperator`.""",
+ Please use `airflow.providers.common.sql.operators.sql.SQLValueCheckOperator`.""",
DeprecationWarning,
stacklevel=2,
)
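The shim above keeps the old class names importable while pointing at the new
location. An illustrative way to observe the module-level warning (not part of
this commit; it only fires on the first import of the module):

    import warnings

    with warnings.catch_warnings(record=True) as captured:
        warnings.simplefilter("always")
        from airflow.operators.check_operator import CheckOperator  # noqa: F401

    assert any(issubclass(w.category, DeprecationWarning) for w in captured)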
diff --git a/airflow/operators/druid_check_operator.py b/airflow/operators/druid_check_operator.py
index 008a91750c..217c306f6d 100644
--- a/airflow/operators/druid_check_operator.py
+++ b/airflow/operators/druid_check_operator.py
@@ -15,14 +15,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.druid.operators.druid_check`."""
+"""This module is deprecated. Please use :mod:`airflow.providers.common.sql.operators.sql`."""
import warnings
from airflow.providers.apache.druid.operators.druid_check import DruidCheckOperator # noqa
warnings.warn(
- "This module is deprecated. Please use `airflow.operators.sql.SQLCheckOperator`.",
+ "This module is deprecated. Please use `airflow.providers.common.sql.operators.sql` module.",
DeprecationWarning,
stacklevel=2,
)
diff --git a/airflow/operators/presto_check_operator.py b/airflow/operators/presto_check_operator.py
index 693471f18c..810eef39a4 100644
--- a/airflow/operators/presto_check_operator.py
+++ b/airflow/operators/presto_check_operator.py
@@ -15,27 +15,33 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""This module is deprecated. Please use :mod:`airflow.operators.sql`."""
+"""This module is deprecated. Please use :mod:`airflow.providers.common.sql.operators.sql`."""
import warnings
-from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator
+from airflow.providers.common.sql.operators.sql import (
+ SQLCheckOperator,
+ SQLIntervalCheckOperator,
+ SQLValueCheckOperator,
+)
warnings.warn(
- "This module is deprecated. Please use `airflow.operators.sql`.", DeprecationWarning, stacklevel=2
+ "This module is deprecated. Please use `airflow.providers.common.sql.operators.sql`.",
+ DeprecationWarning,
+ stacklevel=2,
)
class PrestoCheckOperator(SQLCheckOperator):
"""
This class is deprecated.
- Please use `airflow.operators.sql.SQLCheckOperator`.
+ Please use `airflow.providers.common.sql.operators.sql.SQLCheckOperator`.
"""
def __init__(self, **kwargs):
warnings.warn(
"""This class is deprecated.
- Please use `airflow.operators.sql.SQLCheckOperator`.""",
+ Please use `airflow.providers.common.sql.operators.sql.SQLCheckOperator`.""",
DeprecationWarning,
stacklevel=2,
)
@@ -45,14 +51,14 @@ class PrestoCheckOperator(SQLCheckOperator):
class PrestoIntervalCheckOperator(SQLIntervalCheckOperator):
"""
This class is deprecated.
- Please use `airflow.operators.sql.SQLIntervalCheckOperator`.
+ Please use `airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator`.
"""
def __init__(self, **kwargs):
warnings.warn(
"""
This class is deprecated.
- Please use `airflow.operators.sql.SQLIntervalCheckOperator`.
+ Please use `airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator`.
""",
DeprecationWarning,
stacklevel=2,
@@ -63,14 +69,14 @@ class PrestoIntervalCheckOperator(SQLIntervalCheckOperator):
class PrestoValueCheckOperator(SQLValueCheckOperator):
"""
This class is deprecated.
- Please use `airflow.operators.sql.SQLValueCheckOperator`.
+ Please use `airflow.providers.common.sql.operators.sql.SQLValueCheckOperator`.
"""
def __init__(self, **kwargs):
warnings.warn(
"""
This class is deprecated.
- Please use `airflow.operators.sql.SQLValueCheckOperator`.
+ Please use `airflow.providers.common.sql.operators.sql.SQLValueCheckOperator`.
""",
DeprecationWarning,
stacklevel=2,
diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py
index fbce9b85a1..9bbe159c17 100644
--- a/airflow/operators/sql.py
+++ b/airflow/operators/sql.py
@@ -15,543 +15,23 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, SupportsAbs, Union
-
-from airflow.compat.functools import cached_property
-from airflow.exceptions import AirflowException
-from airflow.hooks.base import BaseHook
-from airflow.models import BaseOperator, SkipMixin
-from airflow.providers.common.sql.hooks.sql import DbApiHook
-from airflow.utils.context import Context
-
-
-def parse_boolean(val: str) -> Union[str, bool]:
- """Try to parse a string into boolean.
-
- Raises ValueError if the input is not a valid true- or false-like string value.
- """
- val = val.lower()
- if val in ('y', 'yes', 't', 'true', 'on', '1'):
- return True
- if val in ('n', 'no', 'f', 'false', 'off', '0'):
- return False
- raise ValueError(f"{val!r} is not a boolean-like string value")
-
-
-class BaseSQLOperator(BaseOperator):
- """
- This is a base class for generic SQL Operator to get a DB Hook
-
- The provided method is .get_db_hook(). The default behavior will try to
- retrieve the DB hook based on connection type.
- You can custom the behavior by overriding the .get_db_hook() method.
- """
-
- def __init__(
- self,
- *,
- conn_id: Optional[str] = None,
- database: Optional[str] = None,
- hook_params: Optional[Dict] = None,
- **kwargs,
- ):
- super().__init__(**kwargs)
- self.conn_id = conn_id
- self.database = database
- self.hook_params = {} if hook_params is None else hook_params
-
- @cached_property
- def _hook(self):
- """Get DB Hook based on connection type"""
- self.log.debug("Get connection for %s", self.conn_id)
- conn = BaseHook.get_connection(self.conn_id)
-
- hook = conn.get_hook(hook_params=self.hook_params)
- if not isinstance(hook, DbApiHook):
- raise AirflowException(
- f'The connection type is not supported by {self.__class__.__name__}. '
- f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}'
- )
-
- if self.database:
- hook.schema = self.database
-
- return hook
-
- def get_db_hook(self) -> DbApiHook:
- """
- Get the database hook for the connection.
-
- :return: the database hook object.
- :rtype: DbApiHook
- """
- return self._hook
-
-
-class SQLCheckOperator(BaseSQLOperator):
- """
- Performs checks against a db. The ``SQLCheckOperator`` expects
- a sql query that will return a single row. Each value on that
- first row is evaluated using python ``bool`` casting. If any of the
- values return ``False`` the check is failed and errors out.
-
- Note that Python bool casting evals the following as ``False``:
-
- * ``False``
- * ``0``
- * Empty string (``""``)
- * Empty list (``[]``)
- * Empty dictionary or set (``{}``)
-
- Given a query like ``SELECT COUNT(*) FROM foo``, it will fail only if
- the count ``== 0``. You can craft much more complex query that could,
- for instance, check that the table has the same number of rows as
- the source table upstream, or that the count of today's partition is
- greater than yesterday's partition, or that a set of metrics are less
- than 3 standard deviation for the 7 day average.
-
- This operator can be used as a data quality check in your pipeline, and
- depending on where you put it in your DAG, you have the choice to
- stop the critical path, preventing from
- publishing dubious data, or on the side and receive email alerts
- without stopping the progress of the DAG.
-
- :param sql: the sql to be executed. (templated)
- :param conn_id: the connection ID used to connect to the database.
- :param database: name of database which overwrite the defined one in connection
- """
-
- template_fields: Sequence[str] = ("sql",)
- template_ext: Sequence[str] = (
- ".hql",
- ".sql",
- )
- template_fields_renderers = {"sql": "sql"}
- ui_color = "#fff7e6"
-
- def __init__(
- self, *, sql: str, conn_id: Optional[str] = None, database: Optional[str] = None, **kwargs
- ) -> None:
- super().__init__(conn_id=conn_id, database=database, **kwargs)
- self.sql = sql
-
- def execute(self, context: Context):
- self.log.info("Executing SQL check: %s", self.sql)
- records = self.get_db_hook().get_first(self.sql)
-
- self.log.info("Record: %s", records)
- if not records:
- raise AirflowException("The query returned None")
- elif not all(bool(r) for r in records):
- raise AirflowException(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}")
-
- self.log.info("Success.")
-
-
-def _convert_to_float_if_possible(s):
- """
- A small helper function to convert a string to a numeric value
- if appropriate
-
- :param s: the string to be converted
- """
- try:
- ret = float(s)
- except (ValueError, TypeError):
- ret = s
- return ret
-
-
-class SQLValueCheckOperator(BaseSQLOperator):
- """
- Performs a simple value check using sql code.
-
- :param sql: the sql to be executed. (templated)
- :param conn_id: the connection ID used to connect to the database.
- :param database: name of database which overwrite the defined one in connection
- """
-
- __mapper_args__ = {"polymorphic_identity": "SQLValueCheckOperator"}
- template_fields: Sequence[str] = (
- "sql",
- "pass_value",
- )
- template_ext: Sequence[str] = (
- ".hql",
- ".sql",
- )
- template_fields_renderers = {"sql": "sql"}
- ui_color = "#fff7e6"
-
- def __init__(
- self,
- *,
- sql: str,
- pass_value: Any,
- tolerance: Any = None,
- conn_id: Optional[str] = None,
- database: Optional[str] = None,
- **kwargs,
- ):
- super().__init__(conn_id=conn_id, database=database, **kwargs)
- self.sql = sql
- self.pass_value = str(pass_value)
- tol = _convert_to_float_if_possible(tolerance)
- self.tol = tol if isinstance(tol, float) else None
- self.has_tolerance = self.tol is not None
-
- def execute(self, context=None):
- self.log.info("Executing SQL check: %s", self.sql)
- records = self.get_db_hook().get_first(self.sql)
-
- if not records:
- raise AirflowException("The query returned None")
-
- pass_value_conv = _convert_to_float_if_possible(self.pass_value)
- is_numeric_value_check = isinstance(pass_value_conv, float)
-
- tolerance_pct_str = str(self.tol * 100) + "%" if self.has_tolerance else None
- error_msg = (
- "Test failed.\nPass value:{pass_value_conv}\n"
- "Tolerance:{tolerance_pct_str}\n"
- "Query:\n{sql}\nResults:\n{records!s}"
- ).format(
- pass_value_conv=pass_value_conv,
- tolerance_pct_str=tolerance_pct_str,
- sql=self.sql,
- records=records,
- )
-
- if not is_numeric_value_check:
- tests = self._get_string_matches(records, pass_value_conv)
- elif is_numeric_value_check:
- try:
- numeric_records = self._to_float(records)
- except (ValueError, TypeError):
- raise AirflowException(f"Converting a result to float failed.\n{error_msg}")
- tests = self._get_numeric_matches(numeric_records, pass_value_conv)
- else:
- tests = []
-
- if not all(tests):
- raise AirflowException(error_msg)
-
- def _to_float(self, records):
- return [float(record) for record in records]
-
- def _get_string_matches(self, records, pass_value_conv):
- return [str(record) == pass_value_conv for record in records]
-
- def _get_numeric_matches(self, numeric_records, numeric_pass_value_conv):
- if self.has_tolerance:
- return [
- numeric_pass_value_conv * (1 - self.tol) <= record <= numeric_pass_value_conv * (1 + self.tol)
- for record in numeric_records
- ]
-
- return [record == numeric_pass_value_conv for record in numeric_records]
-
-
-class SQLIntervalCheckOperator(BaseSQLOperator):
- """
- Checks that the values of metrics given as SQL expressions are within
- a certain tolerance of the ones from days_back before.
-
- :param table: the table name
- :param conn_id: the connection ID used to connect to the database.
- :param database: name of database which will overwrite the defined one in connection
- :param days_back: number of days between ds and the ds we want to check
- against. Defaults to 7 days
- :param date_filter_column: The column name for the dates to filter on. Defaults to 'ds'
- :param ratio_formula: which formula to use to compute the ratio between
- the two metrics. Assuming cur is the metric of today and ref is
- the metric to today - days_back.
-
- max_over_min: computes max(cur, ref) / min(cur, ref)
- relative_diff: computes abs(cur-ref) / ref
-
- Default: 'max_over_min'
- :param ignore_zero: whether we should ignore zero metrics
- :param metrics_thresholds: a dictionary of ratios indexed by metrics
- """
-
- __mapper_args__ = {"polymorphic_identity": "SQLIntervalCheckOperator"}
- template_fields: Sequence[str] = ("sql1", "sql2")
- template_ext: Sequence[str] = (
- ".hql",
- ".sql",
- )
- template_fields_renderers = {"sql1": "sql", "sql2": "sql"}
- ui_color = "#fff7e6"
-
- ratio_formulas = {
- "max_over_min": lambda cur, ref: float(max(cur, ref)) / min(cur, ref),
- "relative_diff": lambda cur, ref: float(abs(cur - ref)) / ref,
- }
-
- def __init__(
- self,
- *,
- table: str,
- metrics_thresholds: Dict[str, int],
- date_filter_column: Optional[str] = "ds",
- days_back: SupportsAbs[int] = -7,
- ratio_formula: Optional[str] = "max_over_min",
- ignore_zero: bool = True,
- conn_id: Optional[str] = None,
- database: Optional[str] = None,
- **kwargs,
- ):
- super().__init__(conn_id=conn_id, database=database, **kwargs)
- if ratio_formula not in self.ratio_formulas:
- msg_template = "Invalid diff_method: {diff_method}. Supported diff methods are: {diff_methods}"
-
- raise AirflowException(
- msg_template.format(diff_method=ratio_formula, diff_methods=self.ratio_formulas)
- )
- self.ratio_formula = ratio_formula
- self.ignore_zero = ignore_zero
- self.table = table
- self.metrics_thresholds = metrics_thresholds
- self.metrics_sorted = sorted(metrics_thresholds.keys())
- self.date_filter_column = date_filter_column
- self.days_back = -abs(days_back)
- sqlexp = ", ".join(self.metrics_sorted)
- sqlt = f"SELECT {sqlexp} FROM {table} WHERE {date_filter_column}="
-
- self.sql1 = sqlt + "'{{ ds }}'"
- self.sql2 = sqlt + "'{{ macros.ds_add(ds, " + str(self.days_back) + ") }}'"
-
- def execute(self, context=None):
- hook = self.get_db_hook()
- self.log.info("Using ratio formula: %s", self.ratio_formula)
- self.log.info("Executing SQL check: %s", self.sql2)
- row2 = hook.get_first(self.sql2)
- self.log.info("Executing SQL check: %s", self.sql1)
- row1 = hook.get_first(self.sql1)
-
- if not row2:
- raise AirflowException(f"The query {self.sql2} returned None")
- if not row1:
- raise AirflowException(f"The query {self.sql1} returned None")
-
- current = dict(zip(self.metrics_sorted, row1))
- reference = dict(zip(self.metrics_sorted, row2))
-
- ratios = {}
- test_results = {}
-
- for metric in self.metrics_sorted:
- cur = current[metric]
- ref = reference[metric]
- threshold = self.metrics_thresholds[metric]
- if cur == 0 or ref == 0:
- ratios[metric] = None
- test_results[metric] = self.ignore_zero
- else:
- ratios[metric] = self.ratio_formulas[self.ratio_formula](current[metric], reference[metric])
- test_results[metric] = ratios[metric] < threshold
-
- self.log.info(
- (
- "Current metric for %s: %s\n"
- "Past metric for %s: %s\n"
- "Ratio for %s: %s\n"
- "Threshold: %s\n"
- ),
- metric,
- cur,
- metric,
- ref,
- metric,
- ratios[metric],
- threshold,
- )
-
- if not all(test_results.values()):
- failed_tests = [it[0] for it in test_results.items() if not it[1]]
- self.log.warning(
- "The following %s tests out of %s failed:",
- len(failed_tests),
- len(self.metrics_sorted),
- )
- for k in failed_tests:
- self.log.warning(
- "'%s' check failed. %s is above %s",
- k,
- ratios[k],
- self.metrics_thresholds[k],
- )
- raise AirflowException(f"The following tests have failed:\n {', '.join(sorted(failed_tests))}")
-
- self.log.info("All tests have passed")
-
-
-class SQLThresholdCheckOperator(BaseSQLOperator):
- """
- Performs a value check using sql code against a minimum threshold
- and a maximum threshold. Thresholds can be in the form of a numeric
- value OR a sql statement that results a numeric.
-
- :param sql: the sql to be executed. (templated)
- :param conn_id: the connection ID used to connect to the database.
- :param database: name of database which overwrite the defined one in connection
- :param min_threshold: numerical value or min threshold sql to be executed (templated)
- :param max_threshold: numerical value or max threshold sql to be executed (templated)
- """
-
- template_fields: Sequence[str] = ("sql", "min_threshold", "max_threshold")
- template_ext: Sequence[str] = (
- ".hql",
- ".sql",
- )
- template_fields_renderers = {"sql": "sql"}
-
- def __init__(
- self,
- *,
- sql: str,
- min_threshold: Any,
- max_threshold: Any,
- conn_id: Optional[str] = None,
- database: Optional[str] = None,
- **kwargs,
- ):
- super().__init__(conn_id=conn_id, database=database, **kwargs)
- self.sql = sql
- self.min_threshold = _convert_to_float_if_possible(min_threshold)
- self.max_threshold = _convert_to_float_if_possible(max_threshold)
-
- def execute(self, context=None):
- hook = self.get_db_hook()
- result = hook.get_first(self.sql)[0]
-
- if isinstance(self.min_threshold, float):
- lower_bound = self.min_threshold
- else:
- lower_bound = hook.get_first(self.min_threshold)[0]
-
- if isinstance(self.max_threshold, float):
- upper_bound = self.max_threshold
- else:
- upper_bound = hook.get_first(self.max_threshold)[0]
-
- meta_data = {
- "result": result,
- "task_id": self.task_id,
- "min_threshold": lower_bound,
- "max_threshold": upper_bound,
- "within_threshold": lower_bound <= result <= upper_bound,
- }
-
- self.push(meta_data)
- if not meta_data["within_threshold"]:
- error_msg = (
- f'Threshold Check: "{meta_data.get("task_id")}" failed.\n'
- f'DAG: {self.dag_id}\nTask_id: {meta_data.get("task_id")}\n'
- f'Check description: {meta_data.get("description")}\n'
- f"SQL: {self.sql}\n"
- f'Result: {round(meta_data.get("result"), 2)} is not within thresholds '
- f'{meta_data.get("min_threshold")} and {meta_data.get("max_threshold")}'
- )
- raise AirflowException(error_msg)
-
- self.log.info("Test %s Successful.", self.task_id)
-
- def push(self, meta_data):
- """
- Optional: Send data check info and metadata to an external database.
- Default functionality will log metadata.
- """
- info = "\n".join(f"""{key}: {item}""" for key, item in meta_data.items())
- self.log.info("Log from %s:\n%s", self.dag_id, info)
-
-
-class BranchSQLOperator(BaseSQLOperator, SkipMixin):
- """
- Allows a DAG to "branch" or follow a specified path based on the results of a SQL query.
-
- :param sql: The SQL code to be executed, should return true or false (templated)
- Template reference are recognized by str ending in '.sql'.
- Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1)
- or string (true/y/yes/1/on/false/n/no/0/off).
- :param follow_task_ids_if_true: task id or task ids to follow if query returns true
- :param follow_task_ids_if_false: task id or task ids to follow if query returns false
- :param conn_id: the connection ID used to connect to the database.
- :param database: name of database which overwrite the defined one in connection
- :param parameters: (optional) the parameters to render the SQL query with.
- """
-
- template_fields: Sequence[str] = ("sql",)
- template_ext: Sequence[str] = (".sql",)
- template_fields_renderers = {"sql": "sql"}
- ui_color = "#a22034"
- ui_fgcolor = "#F7F7F7"
-
- def __init__(
- self,
- *,
- sql: str,
- follow_task_ids_if_true: List[str],
- follow_task_ids_if_false: List[str],
- conn_id: str = "default_conn_id",
- database: Optional[str] = None,
- parameters: Optional[Union[Iterable, Mapping]] = None,
- **kwargs,
- ) -> None:
- super().__init__(conn_id=conn_id, database=database, **kwargs)
- self.sql = sql
- self.parameters = parameters
- self.follow_task_ids_if_true = follow_task_ids_if_true
- self.follow_task_ids_if_false = follow_task_ids_if_false
-
- def execute(self, context: Context):
- self.log.info(
- "Executing: %s (with parameters %s) with connection: %s",
- self.sql,
- self.parameters,
- self.conn_id,
- )
- record = self.get_db_hook().get_first(self.sql, self.parameters)
- if not record:
- raise AirflowException(
- "No rows returned from sql query. Operator expected True or False return value."
- )
-
- if isinstance(record, list):
- if isinstance(record[0], list):
- query_result = record[0][0]
- else:
- query_result = record[0]
- elif isinstance(record, tuple):
- query_result = record[0]
- else:
- query_result = record
-
- self.log.info("Query returns %s, type '%s'", query_result, type(query_result))
-
- follow_branch = None
- try:
- if isinstance(query_result, bool):
- if query_result:
- follow_branch = self.follow_task_ids_if_true
- elif isinstance(query_result, str):
- # return result is not Boolean, try to convert from String to Boolean
- if parse_boolean(query_result):
- follow_branch = self.follow_task_ids_if_true
- elif isinstance(query_result, int):
- if bool(query_result):
- follow_branch = self.follow_task_ids_if_true
- else:
- raise AirflowException(
- f"Unexpected query return result '{query_result}' type '{type(query_result)}'"
- )
-
- if follow_branch is None:
- follow_branch = self.follow_task_ids_if_false
- except ValueError:
- raise AirflowException(
- f"Unexpected query return result '{query_result}' type '{type(query_result)}'"
- )
-
- self.skip_all_except(context["ti"], follow_branch)
+import warnings
+
+from airflow.providers.common.sql.operators.sql import ( # noqa
+ BaseSQLOperator,
+ BranchSQLOperator,
+ SQLCheckOperator,
+ SQLColumnCheckOperator,
+ SQLIntervalCheckOperator,
+ SQLTableCheckOperator,
+ SQLThresholdCheckOperator,
+ SQLValueCheckOperator,
+ _convert_to_float_if_possible,
+ parse_boolean,
+)
+
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.common.sql.operators.sql`.",
+ DeprecationWarning,
+ stacklevel=2,
+)
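Since airflow/operators/sql.py is now a pure re-export, the legacy path and the
provider path resolve to the same class objects. An illustrative sanity check
(not part of this commit):

    from airflow.operators import sql as legacy_sql
    from airflow.providers.common.sql.operators import sql as provider_sql

    # The legacy module merely re-imports the provider classes.
    assert legacy_sql.SQLCheckOperator is provider_sql.SQLCheckOperator
    assert legacy_sql.BranchSQLOperator is provider_sql.BranchSQLOperator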
diff --git a/airflow/operators/sql_branch_operator.py b/airflow/operators/sql_branch_operator.py
index 5987bce610..90d5fcbc1b 100644
--- a/airflow/operators/sql_branch_operator.py
+++ b/airflow/operators/sql_branch_operator.py
@@ -14,26 +14,28 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""This module is deprecated. Please use :mod:`airflow.operators.sql`."""
+"""This module is deprecated. Please use :mod:`airflow.providers.common.sql.operators.sql`."""
import warnings
-from airflow.operators.sql import BranchSQLOperator
+from airflow.providers.common.sql.operators.sql import BranchSQLOperator
warnings.warn(
- "This module is deprecated. Please use :mod:`airflow.operators.sql`.", DeprecationWarning, stacklevel=2
+ "This module is deprecated. Please use :mod:`airflow.providers.common.sql.operators.sql`.",
+ DeprecationWarning,
+ stacklevel=2,
)
class BranchSqlOperator(BranchSQLOperator):
"""
This class is deprecated.
- Please use `airflow.operators.sql.BranchSQLOperator`.
+ Please use `airflow.providers.common.sql.operators.sql.BranchSQLOperator`.
"""
def __init__(self, **kwargs):
warnings.warn(
"""This class is deprecated.
- Please use `airflow.operators.sql.BranchSQLOperator`.""",
+ Please use `airflow.providers.common.sql.operators.sql.BranchSQLOperator`.""",
DeprecationWarning,
stacklevel=2,
)
diff --git a/airflow/providers/apache/druid/operators/druid_check.py b/airflow/providers/apache/druid/operators/druid_check.py
index 33a4151350..75f2b17d0c 100644
--- a/airflow/providers/apache/druid/operators/druid_check.py
+++ b/airflow/providers/apache/druid/operators/druid_check.py
@@ -17,19 +17,19 @@
# under the License.
import warnings
-from airflow.operators.sql import SQLCheckOperator
+from airflow.providers.common.sql.operators.sql import SQLCheckOperator
class DruidCheckOperator(SQLCheckOperator):
"""
This class is deprecated.
- Please use `airflow.operators.sql.SQLCheckOperator`.
+ Please use `airflow.providers.common.sql.operators.sql.SQLCheckOperator`.
"""
def __init__(self, druid_broker_conn_id: str = 'druid_broker_default', **kwargs):
warnings.warn(
"""This class is deprecated.
- Please use `airflow.operators.sql.SQLCheckOperator`.""",
+ Please use `airflow.providers.common.sql.operators.sql.SQLCheckOperator`.""",
DeprecationWarning,
stacklevel=2,
)
diff --git a/airflow/providers/apache/druid/provider.yaml b/airflow/providers/apache/druid/provider.yaml
index 038ae7fc23..1ccaf98153 100644
--- a/airflow/providers/apache/druid/provider.yaml
+++ b/airflow/providers/apache/druid/provider.yaml
@@ -39,7 +39,7 @@ versions:
dependencies:
- apache-airflow>=2.2.0
- - apache-airflow-providers-common-sql
+ - apache-airflow-providers-common-sql>=1.1.0
- pydruid>=0.4.1
integrations:
diff --git a/airflow/providers/common/sql/CHANGELOG.rst b/airflow/providers/common/sql/CHANGELOG.rst
index d48dafc25d..c575d7a027 100644
--- a/airflow/providers/common/sql/CHANGELOG.rst
+++ b/airflow/providers/common/sql/CHANGELOG.rst
@@ -24,6 +24,11 @@
Changelog
---------
+1.1.0
+.....
+
+
+
1.0.0
.....
diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
index a6883d9f08..f1c872cc54 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -15,10 +15,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Sequence, SupportsAbs, Union
+from packaging.version import Version
+
+from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
-from airflow.operators.sql import BaseSQLOperator
+from airflow.hooks.base import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.providers.common.sql.hooks.sql import DbApiHook, _backported_get_hook
+from airflow.version import version
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
def parse_boolean(val: str) -> Union[str, bool]:
@@ -48,6 +57,61 @@ def _get_failed_checks(checks, col=None):
]
+class BaseSQLOperator(BaseOperator):
+ """
+ This is a base class for generic SQL operators that need a DB hook.
+
+ The provided method is .get_db_hook(). The default behavior will try to
+ retrieve the DB hook based on connection type.
+ You can customize the behavior by overriding the .get_db_hook() method.
+ """
+
+ def __init__(
+ self,
+ *,
+ conn_id: Optional[str] = None,
+ database: Optional[str] = None,
+ hook_params: Optional[Dict] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.conn_id = conn_id
+ self.database = database
+ self.hook_params = {} if hook_params is None else hook_params
+
+ @cached_property
+ def _hook(self):
+ """Get DB Hook based on connection type"""
+ self.log.debug("Get connection for %s", self.conn_id)
+ conn = BaseHook.get_connection(self.conn_id)
+ if Version(version) >= Version('2.3'):
+ # "hook_params" were introduced to into "get_hook()" only in Airflow 2.3.
+ hook = conn.get_hook(hook_params=self.hook_params) # ignore airflow compat check
+ else:
+ # To support Airflow versions < 2.3, we backport the "get_hook()" method. This should be removed
+ # when "apache-airflow-providers-common-sql" depends on Airflow >= 2.3.
+ hook = _backported_get_hook(conn, hook_params=self.hook_params)
+ if not isinstance(hook, DbApiHook):
+ raise AirflowException(
+ f'The connection type is not supported by {self.__class__.__name__}. '
+ f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}'
+ )
+
+ if self.database:
+ hook.schema = self.database
+
+ return hook
+
+ def get_db_hook(self) -> DbApiHook:
+ """
+ Get the database hook for the connection.
+
+ :return: the database hook object.
+ :rtype: DbApiHook
+ """
+ return self._hook
+
+
class SQLColumnCheckOperator(BaseSQLOperator):
"""
Performs one or more of the templated checks in the column_checks dictionary.
@@ -125,7 +189,7 @@ class SQLColumnCheckOperator(BaseSQLOperator):
# OpenLineage needs a valid SQL query with the input/output table(s) to parse
self.sql = f"SELECT * FROM {self.table};"
- def execute(self, context=None):
+ def execute(self, context: 'Context'):
hook = self.get_db_hook()
failed_tests = []
for column in self.column_mapping:
@@ -307,7 +371,7 @@ class SQLTableCheckOperator(BaseSQLOperator):
# OpenLineage needs a valid SQL query with the input/output table(s) to parse
self.sql = f"SELECT * FROM {self.table};"
- def execute(self, context=None):
+ def execute(self, context: 'Context'):
hook = self.get_db_hook()
checks_sql = " UNION ALL ".join(
[
@@ -343,3 +407,481 @@ class SQLTableCheckOperator(BaseSQLOperator):
)
self.log.info("All tests have passed")
+
+
+class SQLCheckOperator(BaseSQLOperator):
+ """
+ Performs checks against a db. The ``SQLCheckOperator`` expects
+ a sql query that will return a single row. Each value on that
+ first row is evaluated using Python ``bool`` casting. If any of the
+ values evaluates to ``False``, the check fails and the task errors out.
+
+ Note that Python bool casting evals the following as ``False``:
+
+ * ``False``
+ * ``0``
+ * Empty string (``""``)
+ * Empty list (``[]``)
+ * Empty dictionary or set (``{}``)
+
+ Given a query like ``SELECT COUNT(*) FROM foo``, it will fail only if
+ the count ``== 0``. You can craft a much more complex query that could,
+ for instance, check that the table has the same number of rows as
+ the source table upstream, or that the count of today's partition is
+ greater than yesterday's partition, or that a set of metrics are less
+ than 3 standard deviations from the 7-day average.
+
+ This operator can be used as a data quality check in your pipeline, and
+ depending on where you put it in your DAG, you can either place it on the
+ critical path, preventing dubious data from being published, or run it on
+ the side and receive email alerts without stopping the progress of the DAG.
+
+ :param sql: the sql to be executed. (templated)
+ :param conn_id: the connection ID used to connect to the database.
+ :param database: name of database which overwrite the defined one in connection
+ """
+
+ template_fields: Sequence[str] = ("sql",)
+ template_ext: Sequence[str] = (
+ ".hql",
+ ".sql",
+ )
+ template_fields_renderers = {"sql": "sql"}
+ ui_color = "#fff7e6"
+
+ def __init__(
+ self, *, sql: str, conn_id: Optional[str] = None, database: Optional[str] = None, **kwargs
+ ) -> None:
+ super().__init__(conn_id=conn_id, database=database, **kwargs)
+ self.sql = sql
+
+ def execute(self, context: 'Context'):
+ self.log.info("Executing SQL check: %s", self.sql)
+ records = self.get_db_hook().get_first(self.sql)
+
+ self.log.info("Record: %s", records)
+ if not records:
+ raise AirflowException("The query returned None")
+ elif not all(bool(r) for r in records):
+ raise AirflowException(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}")
+
+ self.log.info("Success.")
+
+
+class SQLValueCheckOperator(BaseSQLOperator):
+ """
+ Performs a simple value check using sql code.
+
+ :param sql: the sql to be executed. (templated)
+ :param conn_id: the connection ID used to connect to the database.
+ :param database: name of database which overwrite the defined one in connection
+ """
+
+ __mapper_args__ = {"polymorphic_identity": "SQLValueCheckOperator"}
+ template_fields: Sequence[str] = (
+ "sql",
+ "pass_value",
+ )
+ template_ext: Sequence[str] = (
+ ".hql",
+ ".sql",
+ )
+ template_fields_renderers = {"sql": "sql"}
+ ui_color = "#fff7e6"
+
+ def __init__(
+ self,
+ *,
+ sql: str,
+ pass_value: Any,
+ tolerance: Any = None,
+ conn_id: Optional[str] = None,
+ database: Optional[str] = None,
+ **kwargs,
+ ):
+ super().__init__(conn_id=conn_id, database=database, **kwargs)
+ self.sql = sql
+ self.pass_value = str(pass_value)
+ tol = _convert_to_float_if_possible(tolerance)
+ self.tol = tol if isinstance(tol, float) else None
+ self.has_tolerance = self.tol is not None
+
+ def execute(self, context: 'Context'):
+ self.log.info("Executing SQL check: %s", self.sql)
+ records = self.get_db_hook().get_first(self.sql)
+
+ if not records:
+ raise AirflowException("The query returned None")
+
+ pass_value_conv = _convert_to_float_if_possible(self.pass_value)
+ is_numeric_value_check = isinstance(pass_value_conv, float)
+
+ tolerance_pct_str = str(self.tol * 100) + "%" if self.tol is not None else None
+ error_msg = (
+ "Test failed.\nPass value:{pass_value_conv}\n"
+ "Tolerance:{tolerance_pct_str}\n"
+ "Query:\n{sql}\nResults:\n{records!s}"
+ ).format(
+ pass_value_conv=pass_value_conv,
+ tolerance_pct_str=tolerance_pct_str,
+ sql=self.sql,
+ records=records,
+ )
+
+ if not is_numeric_value_check:
+ tests = self._get_string_matches(records, pass_value_conv)
+ elif is_numeric_value_check:
+ try:
+ numeric_records = self._to_float(records)
+ except (ValueError, TypeError):
+ raise AirflowException(f"Converting a result to float failed.\n{error_msg}")
+ tests = self._get_numeric_matches(numeric_records, pass_value_conv)
+ else:
+ tests = []
+
+ if not all(tests):
+ raise AirflowException(error_msg)
+
+ def _to_float(self, records):
+ return [float(record) for record in records]
+
+ def _get_string_matches(self, records, pass_value_conv):
+ return [str(record) == pass_value_conv for record in records]
+
+ def _get_numeric_matches(self, numeric_records, numeric_pass_value_conv):
+ if self.has_tolerance:
+ return [
+ numeric_pass_value_conv * (1 - self.tol) <= record <= numeric_pass_value_conv * (1 + self.tol)
+ for record in numeric_records
+ ]
+
+ return [record == numeric_pass_value_conv for record in numeric_records]
+
+
+class SQLIntervalCheckOperator(BaseSQLOperator):
+ """
+ Checks that the values of metrics given as SQL expressions are within
+ a certain tolerance of the ones from days_back before.
+
+ :param table: the table name
+ :param conn_id: the connection ID used to connect to the database.
+ :param database: name of database which will overwrite the defined one in connection
+ :param days_back: number of days between ds and the ds we want to check
+ against. Defaults to 7 days
+ :param date_filter_column: The column name for the dates to filter on. Defaults to 'ds'
+ :param ratio_formula: which formula to use to compute the ratio between
+ the two metrics. Assuming cur is the metric of today and ref is
+ the metric of today - days_back.
+
+ max_over_min: computes max(cur, ref) / min(cur, ref)
+ relative_diff: computes abs(cur-ref) / ref
+
+ Default: 'max_over_min'
+ :param ignore_zero: whether we should ignore zero metrics
+ :param metrics_thresholds: a dictionary of ratios indexed by metrics
+ """
+
+ __mapper_args__ = {"polymorphic_identity": "SQLIntervalCheckOperator"}
+ template_fields: Sequence[str] = ("sql1", "sql2")
+ template_ext: Sequence[str] = (
+ ".hql",
+ ".sql",
+ )
+ template_fields_renderers = {"sql1": "sql", "sql2": "sql"}
+ ui_color = "#fff7e6"
+
+ ratio_formulas = {
+ "max_over_min": lambda cur, ref: float(max(cur, ref)) / min(cur, ref),
+ "relative_diff": lambda cur, ref: float(abs(cur - ref)) / ref,
+ }
+
+ def __init__(
+ self,
+ *,
+ table: str,
+ metrics_thresholds: Dict[str, int],
+ date_filter_column: Optional[str] = "ds",
+ days_back: SupportsAbs[int] = -7,
+ ratio_formula: Optional[str] = "max_over_min",
+ ignore_zero: bool = True,
+ conn_id: Optional[str] = None,
+ database: Optional[str] = None,
+ **kwargs,
+ ):
+ super().__init__(conn_id=conn_id, database=database, **kwargs)
+ if ratio_formula not in self.ratio_formulas:
+ msg_template = "Invalid diff_method: {diff_method}. Supported diff methods are: {diff_methods}"
+
+ raise AirflowException(
+ msg_template.format(diff_method=ratio_formula, diff_methods=self.ratio_formulas)
+ )
+ self.ratio_formula = ratio_formula
+ self.ignore_zero = ignore_zero
+ self.table = table
+ self.metrics_thresholds = metrics_thresholds
+ self.metrics_sorted = sorted(metrics_thresholds.keys())
+ self.date_filter_column = date_filter_column
+ self.days_back = -abs(days_back)
+ sqlexp = ", ".join(self.metrics_sorted)
+ sqlt = f"SELECT {sqlexp} FROM {table} WHERE {date_filter_column}="
+
+ self.sql1 = sqlt + "'{{ ds }}'"
+ self.sql2 = sqlt + "'{{ macros.ds_add(ds, " + str(self.days_back) + ") }}'"
+
+ def execute(self, context: 'Context'):
+ hook = self.get_db_hook()
+ self.log.info("Using ratio formula: %s", self.ratio_formula)
+ self.log.info("Executing SQL check: %s", self.sql2)
+ row2 = hook.get_first(self.sql2)
+ self.log.info("Executing SQL check: %s", self.sql1)
+ row1 = hook.get_first(self.sql1)
+
+ if not row2:
+ raise AirflowException(f"The query {self.sql2} returned None")
+ if not row1:
+ raise AirflowException(f"The query {self.sql1} returned None")
+
+ current = dict(zip(self.metrics_sorted, row1))
+ reference = dict(zip(self.metrics_sorted, row2))
+
+ ratios: Dict[str, Optional[float]] = {}
+ test_results = {}
+
+ for metric in self.metrics_sorted:
+ cur = current[metric]
+ ref = reference[metric]
+ threshold = self.metrics_thresholds[metric]
+ if cur == 0 or ref == 0:
+ ratios[metric] = None
+ test_results[metric] = self.ignore_zero
+ else:
+ ratio_metric = self.ratio_formulas[self.ratio_formula](current[metric], reference[metric])
+ ratios[metric] = ratio_metric
+ if ratio_metric is not None:
+ test_results[metric] = ratio_metric < threshold
+ else:
+ test_results[metric] = self.ignore_zero
+
+ self.log.info(
+ (
+ "Current metric for %s: %s\n"
+ "Past metric for %s: %s\n"
+ "Ratio for %s: %s\n"
+ "Threshold: %s\n"
+ ),
+ metric,
+ cur,
+ metric,
+ ref,
+ metric,
+ ratios[metric],
+ threshold,
+ )
+
+ if not all(test_results.values()):
+ failed_tests = [it[0] for it in test_results.items() if not it[1]]
+ self.log.warning(
+ "The following %s tests out of %s failed:",
+ len(failed_tests),
+ len(self.metrics_sorted),
+ )
+ for k in failed_tests:
+ self.log.warning(
+ "'%s' check failed. %s is above %s",
+ k,
+ ratios[k],
+ self.metrics_thresholds[k],
+ )
+ raise AirflowException(f"The following tests have failed:\n {', '.join(sorted(failed_tests))}")
+
+ self.log.info("All tests have passed")
+
+
+class SQLThresholdCheckOperator(BaseSQLOperator):
+ """
+ Performs a value check using sql code against a minimum threshold
+ and a maximum threshold. Thresholds can be in the form of a numeric
+ value OR a sql statement that returns a numeric value.
+
+ :param sql: the sql to be executed. (templated)
+ :param conn_id: the connection ID used to connect to the database.
+ :param database: name of database which overwrite the defined one in connection
+ :param min_threshold: numerical value or min threshold sql to be executed (templated)
+ :param max_threshold: numerical value or max threshold sql to be executed (templated)
+ """
+
+ template_fields: Sequence[str] = ("sql", "min_threshold", "max_threshold")
+ template_ext: Sequence[str] = (
+ ".hql",
+ ".sql",
+ )
+ template_fields_renderers = {"sql": "sql"}
+
+ def __init__(
+ self,
+ *,
+ sql: str,
+ min_threshold: Any,
+ max_threshold: Any,
+ conn_id: Optional[str] = None,
+ database: Optional[str] = None,
+ **kwargs,
+ ):
+ super().__init__(conn_id=conn_id, database=database, **kwargs)
+ self.sql = sql
+ self.min_threshold = _convert_to_float_if_possible(min_threshold)
+ self.max_threshold = _convert_to_float_if_possible(max_threshold)
+
+ def execute(self, context: 'Context'):
+ hook = self.get_db_hook()
+ result = hook.get_first(self.sql)[0]
+
+ if isinstance(self.min_threshold, float):
+ lower_bound = self.min_threshold
+ else:
+ lower_bound = hook.get_first(self.min_threshold)[0]
+
+ if isinstance(self.max_threshold, float):
+ upper_bound = self.max_threshold
+ else:
+ upper_bound = hook.get_first(self.max_threshold)[0]
+
+ meta_data = {
+ "result": result,
+ "task_id": self.task_id,
+ "min_threshold": lower_bound,
+ "max_threshold": upper_bound,
+ "within_threshold": lower_bound <= result <= upper_bound,
+ }
+
+ self.push(meta_data)
+ if not meta_data["within_threshold"]:
+ result = (
+ round(meta_data.get("result"), 2) # type: ignore[arg-type]
+ if meta_data.get("result") is not None
+ else "<None>"
+ )
+ error_msg = (
+ f'Threshold Check: "{meta_data.get("task_id")}" failed.\n'
+ f'DAG: {self.dag_id}\nTask_id: {meta_data.get("task_id")}\n'
+ f'Check description: {meta_data.get("description")}\n'
+ f"SQL: {self.sql}\n"
+ f'Result: {result} is not within thresholds '
+ f'{meta_data.get("min_threshold")} and {meta_data.get("max_threshold")}'
+ )
+ raise AirflowException(error_msg)
+
+ self.log.info("Test %s Successful.", self.task_id)
+
+ def push(self, meta_data):
+ """
+ Optional: Send data check info and metadata to an external database.
+ Default functionality will log metadata.
+ """
+ info = "\n".join(f"""{key}: {item}""" for key, item in meta_data.items())
+ self.log.info("Log from %s:\n%s", self.dag_id, info)
+
+
+class BranchSQLOperator(BaseSQLOperator, SkipMixin):
+ """
+ Allows a DAG to "branch" or follow a specified path based on the results of a SQL query.
+
+ :param sql: The SQL code to be executed; it should return true or false (templated).
+ Template references are recognized by a string ending in '.sql'.
+ The query is expected to return a Boolean (True/False), an integer (0 = False, non-zero = True),
+ or a string (true/y/yes/1/on/false/n/no/0/off).
+ :param follow_task_ids_if_true: task id or task ids to follow if query returns true
+ :param follow_task_ids_if_false: task id or task ids to follow if query returns false
+ :param conn_id: the connection ID used to connect to the database.
+ :param database: name of database which overwrite the defined one in connection
+ :param parameters: (optional) the parameters to render the SQL query with.
+ """
+
+ template_fields: Sequence[str] = ("sql",)
+ template_ext: Sequence[str] = (".sql",)
+ template_fields_renderers = {"sql": "sql"}
+ ui_color = "#a22034"
+ ui_fgcolor = "#F7F7F7"
+
+ def __init__(
+ self,
+ *,
+ sql: str,
+ follow_task_ids_if_true: List[str],
+ follow_task_ids_if_false: List[str],
+ conn_id: str = "default_conn_id",
+ database: Optional[str] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(conn_id=conn_id, database=database, **kwargs)
+ self.sql = sql
+ self.parameters = parameters
+ self.follow_task_ids_if_true = follow_task_ids_if_true
+ self.follow_task_ids_if_false = follow_task_ids_if_false
+
+ def execute(self, context: 'Context'):
+ self.log.info(
+ "Executing: %s (with parameters %s) with connection: %s",
+ self.sql,
+ self.parameters,
+ self.conn_id,
+ )
+ record = self.get_db_hook().get_first(self.sql, self.parameters)
+ if not record:
+ raise AirflowException(
+ "No rows returned from sql query. Operator expected True or False return value."
+ )
+
+ if isinstance(record, list):
+ if isinstance(record[0], list):
+ query_result = record[0][0]
+ else:
+ query_result = record[0]
+ elif isinstance(record, tuple):
+ query_result = record[0]
+ else:
+ query_result = record
+
+ self.log.info("Query returns %s, type '%s'", query_result, type(query_result))
+
+ follow_branch = None
+ try:
+ if isinstance(query_result, bool):
+ if query_result:
+ follow_branch = self.follow_task_ids_if_true
+ elif isinstance(query_result, str):
+ # return result is not Boolean, try to convert from String to Boolean
+ if parse_boolean(query_result):
+ follow_branch = self.follow_task_ids_if_true
+ elif isinstance(query_result, int):
+ if bool(query_result):
+ follow_branch = self.follow_task_ids_if_true
+ else:
+ raise AirflowException(
+ f"Unexpected query return result '{query_result}' type '{type(query_result)}'"
+ )
+
+ if follow_branch is None:
+ follow_branch = self.follow_task_ids_if_false
+ except ValueError:
+ raise AirflowException(
+ f"Unexpected query return result '{query_result}' type '{type(query_result)}'"
+ )
+
+ self.skip_all_except(context["ti"], follow_branch)
+
+
+def _convert_to_float_if_possible(s):
+ """
+ A small helper function to convert a string to a numeric value
+ if appropriate
+
+ :param s: the string to be converted
+ """
+ try:
+ ret = float(s)
+ except (ValueError, TypeError):
+ ret = s
+ return ret
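With the operators consolidated here, a typical data-quality DAG uses the
provider import path. A minimal illustrative sketch (the connection id, table
names and task ids below are made up; it assumes a configured SQL connection):

    import pendulum

    from airflow import DAG
    from airflow.operators.empty import EmptyOperator
    from airflow.providers.common.sql.operators.sql import (
        BranchSQLOperator,
        SQLCheckOperator,
    )

    with DAG(
        dag_id="common_sql_example",
        start_date=pendulum.datetime(2022, 7, 1, tz="UTC"),
        schedule_interval=None,
    ) as dag:
        # Fails if the query returns no row or any falsy value in the first row.
        check_rows = SQLCheckOperator(
            task_id="check_rows",
            conn_id="my_db",
            sql="SELECT COUNT(*) FROM foo",
        )

        # Follows one of two task ids based on the boolean-like query result.
        branch = BranchSQLOperator(
            task_id="branch_on_flag",
            conn_id="my_db",
            sql="SELECT is_enabled FROM feature_flags WHERE name = 'nightly'",
            follow_task_ids_if_true=["load"],
            follow_task_ids_if_false=["skip"],
        )

        load = EmptyOperator(task_id="load")
        skip = EmptyOperator(task_id="skip")

        check_rows >> branch >> [load, skip]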
diff --git a/airflow/providers/common/sql/provider.yaml b/airflow/providers/common/sql/provider.yaml
index 39c8d483e4..30bc1258b3 100644
--- a/airflow/providers/common/sql/provider.yaml
+++ b/airflow/providers/common/sql/provider.yaml
@@ -22,7 +22,7 @@ description: |
`Common SQL Provider <https://en.wikipedia.org/wiki/SQL>`__
versions:
- - 1.0.0
+ - 1.1.0
dependencies:
- sqlparse>=0.4.2
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 550c317406..827f82f5f6 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -32,7 +32,11 @@ from google.cloud.bigquery import DEFAULT_RETRY, CopyJob, ExtractJob, LoadJob, Q
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, BaseOperatorLink
from airflow.models.xcom import XCom
-from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator
+from airflow.providers.common.sql.operators.sql import (
+ SQLCheckOperator,
+ SQLIntervalCheckOperator,
+ SQLValueCheckOperator,
+)
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob
from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url
from airflow.providers.google.cloud.links.bigquery import BigQueryDatasetLink, BigQueryTableLink
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 1533cf4612..c3589140aa 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -53,7 +53,7 @@ versions:
dependencies:
- apache-airflow>=2.2.0
- - apache-airflow-providers-common-sql
+ - apache-airflow-providers-common-sql>=1.1.0
# Google has very clear rules on what dependencies should be used. All the limits below
# follow strict guidelines of Google Libraries as quoted here:
# While this issue is open, dependents of google-api-core, google-cloud-core. and google-auth
diff --git a/airflow/providers/google/suite/transfers/sql_to_sheets.py b/airflow/providers/google/suite/transfers/sql_to_sheets.py
index 8626fb7227..b448aafa17 100644
--- a/airflow/providers/google/suite/transfers/sql_to_sheets.py
+++ b/airflow/providers/google/suite/transfers/sql_to_sheets.py
@@ -22,7 +22,7 @@ import numbers
from contextlib import closing
from typing import Any, Iterable, Mapping, Optional, Sequence, Union
-from airflow.operators.sql import BaseSQLOperator
+from airflow.providers.common.sql.operators.sql import BaseSQLOperator
from airflow.providers.google.suite.hooks.sheets import GSheetsHook
diff --git a/airflow/providers/qubole/operators/qubole_check.py b/airflow/providers/qubole/operators/qubole_check.py
index e63ff308b3..950c0d62fc 100644
--- a/airflow/providers/qubole/operators/qubole_check.py
+++ b/airflow/providers/qubole/operators/qubole_check.py
@@ -19,7 +19,7 @@
from typing import Callable, Optional, Sequence, Union
from airflow.exceptions import AirflowException
-from airflow.operators.sql import SQLCheckOperator, SQLValueCheckOperator
+from airflow.providers.common.sql.operators.sql import SQLCheckOperator, SQLValueCheckOperator
from airflow.providers.qubole.hooks.qubole_check import QuboleCheckHook
from airflow.providers.qubole.operators.qubole import QuboleOperator
diff --git a/airflow/providers/qubole/provider.yaml b/airflow/providers/qubole/provider.yaml
index 9b854da308..b45b2e1d3a 100644
--- a/airflow/providers/qubole/provider.yaml
+++ b/airflow/providers/qubole/provider.yaml
@@ -36,7 +36,7 @@ versions:
dependencies:
- apache-airflow>=2.2.0
- - apache-airflow-providers-common-sql
+ - apache-airflow-providers-common-sql>=1.1.0
- qds-sdk>=1.10.4
integrations:
diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py
index dd996cc526..697cc9a344 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -18,8 +18,12 @@
from typing import Any, Iterable, List, Mapping, Optional, Sequence, SupportsAbs, Union
from airflow.models import BaseOperator
-from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
+from airflow.providers.common.sql.operators.sql import (
+ SQLCheckOperator,
+ SQLIntervalCheckOperator,
+ SQLValueCheckOperator,
+)
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
diff --git a/airflow/providers/snowflake/provider.yaml b/airflow/providers/snowflake/provider.yaml
index 6ce66b1c66..cbf8fed30c 100644
--- a/airflow/providers/snowflake/provider.yaml
+++ b/airflow/providers/snowflake/provider.yaml
@@ -44,7 +44,7 @@ versions:
dependencies:
- apache-airflow>=2.2.0
- - apache-airflow-providers-common-sql
+ - apache-airflow-providers-common-sql>=1.1.0
- snowflake-connector-python>=2.4.1
- snowflake-sqlalchemy>=1.1.0
diff --git a/airflow/sensors/sql_sensor.py b/airflow/sensors/sql_sensor.py
index fafc5335a2..6f7b1e46c0 100644
--- a/airflow/sensors/sql_sensor.py
+++ b/airflow/sensors/sql_sensor.py
@@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""This module is deprecated. Please use :mod:`airflow.sensors.sql`."""
+"""This module is deprecated. Please use :mod:`airflow.providers.common.sql.sensors.sql`."""
import warnings
diff --git a/docs/apache-airflow/operators-and-hooks-ref.rst b/docs/apache-airflow/operators-and-hooks-ref.rst
index c08a5deb00..7b5e148c5f 100644
--- a/docs/apache-airflow/operators-and-hooks-ref.rst
+++ b/docs/apache-airflow/operators-and-hooks-ref.rst
@@ -77,9 +77,6 @@ For details see: :doc:`apache-airflow-providers:operators-and-hooks-ref/index`.
* - :mod:`airflow.operators.python`
- :doc:`How to use <howto/operator/python>`
- * - :mod:`airflow.operators.sql`
- -
-
* - :mod:`airflow.operators.subdag`
-
@@ -112,9 +109,6 @@ For details see: :doc:`apache-airflow-providers:operators-and-hooks-ref/index`.
* - :mod:`airflow.sensors.smart_sensor`
- :doc:`concepts/smart-sensors`
- * - :mod:`airflow.sensors.sql`
- -
-
* - :mod:`airflow.sensors.time_delta`
-
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 51e3ea26f8..bf1b18a598 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -70,7 +70,7 @@
},
"apache.druid": {
"deps": [
- "apache-airflow-providers-common-sql",
+ "apache-airflow-providers-common-sql>=1.1.0",
"apache-airflow>=2.2.0",
"pydruid>=0.4.1"
],
@@ -292,7 +292,7 @@
"google": {
"deps": [
"PyOpenSSL",
- "apache-airflow-providers-common-sql",
+ "apache-airflow-providers-common-sql>=1.1.0",
"apache-airflow>=2.2.0",
"google-ads>=15.1.1",
"google-api-core>=2.7.0,<3.0.0",
@@ -576,7 +576,7 @@
},
"qubole": {
"deps": [
- "apache-airflow-providers-common-sql",
+ "apache-airflow-providers-common-sql>=1.1.0",
"apache-airflow>=2.2.0",
"qds-sdk>=1.10.4"
],
@@ -650,7 +650,7 @@
},
"snowflake": {
"deps": [
- "apache-airflow-providers-common-sql",
+ "apache-airflow-providers-common-sql>=1.1.0",
"apache-airflow>=2.2.0",
"snowflake-connector-python>=2.4.1",
"snowflake-sqlalchemy>=1.1.0"
diff --git a/tests/deprecated_classes.py b/tests/deprecated_classes.py
index 5b76af905d..36ac69a2b9 100644
--- a/tests/deprecated_classes.py
+++ b/tests/deprecated_classes.py
@@ -1028,7 +1028,7 @@ OPERATORS = [
'airflow.contrib.operators.sqoop_operator.SqoopOperator',
),
(
- 'airflow.operators.sql.SQLCheckOperator',
+ 'airflow.providers.common.sql.operators.sql.SQLCheckOperator',
'airflow.operators.druid_check_operator.DruidCheckOperator',
),
(
@@ -1232,35 +1232,35 @@ OPERATORS = [
'airflow.operators.papermill_operator.PapermillOperator',
),
(
- 'airflow.operators.sql.SQLCheckOperator',
+ 'airflow.providers.common.sql.operators.sql.SQLCheckOperator',
'airflow.operators.presto_check_operator.PrestoCheckOperator',
),
(
- 'airflow.operators.sql.SQLIntervalCheckOperator',
+ 'airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator',
'airflow.operators.presto_check_operator.PrestoIntervalCheckOperator',
),
(
- 'airflow.operators.sql.SQLValueCheckOperator',
+ 'airflow.providers.common.sql.operators.sql.SQLValueCheckOperator',
'airflow.operators.presto_check_operator.PrestoValueCheckOperator',
),
(
- 'airflow.operators.sql.SQLCheckOperator',
+ 'airflow.providers.common.sql.operators.sql.SQLCheckOperator',
'airflow.operators.check_operator.CheckOperator',
),
(
- 'airflow.operators.sql.SQLIntervalCheckOperator',
+ 'airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator',
'airflow.operators.check_operator.IntervalCheckOperator',
),
(
- 'airflow.operators.sql.SQLValueCheckOperator',
+ 'airflow.providers.common.sql.operators.sql.SQLValueCheckOperator',
'airflow.operators.check_operator.ValueCheckOperator',
),
(
- 'airflow.operators.sql.SQLThresholdCheckOperator',
+ 'airflow.providers.common.sql.operators.sql.SQLThresholdCheckOperator',
'airflow.operators.check_operator.ThresholdCheckOperator',
),
(
- 'airflow.operators.sql.BranchSQLOperator',
+ 'airflow.providers.common.sql.operators.sql.BranchSQLOperator',
'airflow.operators.sql_branch_operator.BranchSqlOperator',
),
(
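Each tuple in this list pairs a canonical class path with its deprecated alias, so the hunks above retarget only the first element of every pair. A sketch of the invariant such a mapping is meant to guarantee (an assumed illustration, not the repository's actual test code):

    import importlib

    def resolve(path: str):
        module, name = path.rsplit(".", 1)
        return getattr(importlib.import_module(module), name)

    new_cls = resolve("airflow.providers.common.sql.operators.sql.SQLCheckOperator")
    old_cls = resolve("airflow.operators.check_operator.CheckOperator")
    # The deprecated alias must remain usable wherever its replacement is expected.
    assert issubclass(old_cls, new_cls)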
diff --git a/tests/operators/test_sql.py b/tests/operators/test_sql.py
deleted file mode 100644
index 43d202819a..0000000000
--- a/tests/operators/test_sql.py
+++ /dev/null
@@ -1,779 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import datetime
-import unittest
-from unittest import mock
-
-import pytest
-
-from airflow.exceptions import AirflowException
-from airflow.models import DAG, Connection, DagRun, TaskInstance as TI, XCom
-from airflow.operators.empty import EmptyOperator
-from airflow.operators.sql import (
- BranchSQLOperator,
- SQLCheckOperator,
- SQLIntervalCheckOperator,
- SQLThresholdCheckOperator,
- SQLValueCheckOperator,
-)
-from airflow.providers.postgres.hooks.postgres import PostgresHook
-from airflow.utils import timezone
-from airflow.utils.session import create_session
-from airflow.utils.state import State
-from tests.providers.apache.hive import TestHiveEnvironment
-
-DEFAULT_DATE = timezone.datetime(2016, 1, 1)
-INTERVAL = datetime.timedelta(hours=12)
-
-SUPPORTED_TRUE_VALUES = [
- ["True"],
- ["true"],
- ["1"],
- ["on"],
- [1],
- True,
- "true",
- "1",
- "on",
- 1,
-]
-SUPPORTED_FALSE_VALUES = [
- ["False"],
- ["false"],
- ["0"],
- ["off"],
- [0],
- False,
- "false",
- "0",
- "off",
- 0,
-]
-
-
-@mock.patch(
- 'airflow.operators.sql.BaseHook.get_connection',
- return_value=Connection(conn_id='sql_default', conn_type='postgres'),
-)
-class TestSQLCheckOperatorDbHook:
- def setup_method(self):
- self.task_id = "test_task"
- self.conn_id = "sql_default"
- self._operator = SQLCheckOperator(task_id=self.task_id, conn_id=self.conn_id, sql="sql")
-
- @pytest.mark.parametrize('database', [None, 'test-db'])
- def test_get_hook(self, mock_get_conn, database):
- if database:
- self._operator.database = database
- assert isinstance(self._operator._hook, PostgresHook)
- assert self._operator._hook.schema == database
- mock_get_conn.assert_called_once_with(self.conn_id)
-
- def test_not_allowed_conn_type(self, mock_get_conn):
- mock_get_conn.return_value = Connection(conn_id='sql_default', conn_type='s3')
- with pytest.raises(AirflowException, match=r"The connection type is not supported"):
- self._operator._hook
-
- def test_sql_operator_hook_params_snowflake(self, mock_get_conn):
- mock_get_conn.return_value = Connection(conn_id='snowflake_default', conn_type='snowflake')
- self._operator.hook_params = {
- 'warehouse': 'warehouse',
- 'database': 'database',
- 'role': 'role',
- 'schema': 'schema',
- 'log_sql': False,
- }
- assert self._operator._hook.conn_type == 'snowflake'
- assert self._operator._hook.warehouse == 'warehouse'
- assert self._operator._hook.database == 'database'
- assert self._operator._hook.role == 'role'
- assert self._operator._hook.schema == 'schema'
- assert not self._operator._hook.log_sql
-
- def test_sql_operator_hook_params_biguery(self, mock_get_conn):
- mock_get_conn.return_value = Connection(
- conn_id='google_cloud_bigquery_default', conn_type='gcpbigquery'
- )
- self._operator.hook_params = {'use_legacy_sql': True, 'location': 'us-east1'}
- assert self._operator._hook.conn_type == 'gcpbigquery'
- assert self._operator._hook.use_legacy_sql
- assert self._operator._hook.location == 'us-east1'
-
-
-class TestCheckOperator(unittest.TestCase):
- def setUp(self):
- self._operator = SQLCheckOperator(task_id="test_task", sql="sql")
-
- @mock.patch.object(SQLCheckOperator, "get_db_hook")
- def test_execute_no_records(self, mock_get_db_hook):
- mock_get_db_hook.return_value.get_first.return_value = []
-
- with pytest.raises(AirflowException, match=r"The query returned None"):
- self._operator.execute({})
-
- @mock.patch.object(SQLCheckOperator, "get_db_hook")
- def test_execute_not_all_records_are_true(self, mock_get_db_hook):
- mock_get_db_hook.return_value.get_first.return_value = ["data", ""]
-
- with pytest.raises(AirflowException, match=r"Test failed."):
- self._operator.execute({})
-
-
-class TestValueCheckOperator(unittest.TestCase):
- def setUp(self):
- self.task_id = "test_task"
- self.conn_id = "default_conn"
-
- def _construct_operator(self, sql, pass_value, tolerance=None):
- dag = DAG("test_dag", start_date=datetime.datetime(2017, 1, 1))
-
- return SQLValueCheckOperator(
- dag=dag,
- task_id=self.task_id,
- conn_id=self.conn_id,
- sql=sql,
- pass_value=pass_value,
- tolerance=tolerance,
- )
-
- def test_pass_value_template_string(self):
- pass_value_str = "2018-03-22"
- operator = self._construct_operator("select date from tab1;", "{{ ds }}")
-
- operator.render_template_fields({"ds": pass_value_str})
-
- assert operator.task_id == self.task_id
- assert operator.pass_value == pass_value_str
-
- def test_pass_value_template_string_float(self):
- pass_value_float = 4.0
- operator = self._construct_operator("select date from tab1;", pass_value_float)
-
- operator.render_template_fields({})
-
- assert operator.task_id == self.task_id
- assert operator.pass_value == str(pass_value_float)
-
- @mock.patch.object(SQLValueCheckOperator, "get_db_hook")
- def test_execute_pass(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.return_value = [10]
- mock_get_db_hook.return_value = mock_hook
- sql = "select value from tab1 limit 1;"
- operator = self._construct_operator(sql, 5, 1)
-
- operator.execute(None)
-
- mock_hook.get_first.assert_called_once_with(sql)
-
- @mock.patch.object(SQLValueCheckOperator, "get_db_hook")
- def test_execute_fail(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.return_value = [11]
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator("select value from tab1 limit 1;", 5, 1)
-
- with pytest.raises(AirflowException, match="Tolerance:100.0%"):
- operator.execute()
-
-
-class TestIntervalCheckOperator(unittest.TestCase):
- def _construct_operator(self, table, metric_thresholds, ratio_formula, ignore_zero):
- return SQLIntervalCheckOperator(
- task_id="test_task",
- table=table,
- metrics_thresholds=metric_thresholds,
- ratio_formula=ratio_formula,
- ignore_zero=ignore_zero,
- )
-
- def test_invalid_ratio_formula(self):
- with pytest.raises(AirflowException, match="Invalid diff_method"):
- self._construct_operator(
- table="test_table",
- metric_thresholds={
- "f1": 1,
- },
- ratio_formula="abs",
- ignore_zero=False,
- )
-
- @mock.patch.object(SQLIntervalCheckOperator, "get_db_hook")
- def test_execute_not_ignore_zero(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.return_value = [0]
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator(
- table="test_table",
- metric_thresholds={
- "f1": 1,
- },
- ratio_formula="max_over_min",
- ignore_zero=False,
- )
-
- with pytest.raises(AirflowException):
- operator.execute()
-
- @mock.patch.object(SQLIntervalCheckOperator, "get_db_hook")
- def test_execute_ignore_zero(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.return_value = [0]
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator(
- table="test_table",
- metric_thresholds={
- "f1": 1,
- },
- ratio_formula="max_over_min",
- ignore_zero=True,
- )
-
- operator.execute()
-
- @mock.patch.object(SQLIntervalCheckOperator, "get_db_hook")
- def test_execute_min_max(self, mock_get_db_hook):
- mock_hook = mock.Mock()
-
- def returned_row():
- rows = [
- [2, 2, 2, 2], # reference
- [1, 1, 1, 1], # current
- ]
-
- yield from rows
-
- mock_hook.get_first.side_effect = returned_row()
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator(
- table="test_table",
- metric_thresholds={
- "f0": 1.0,
- "f1": 1.5,
- "f2": 2.0,
- "f3": 2.5,
- },
- ratio_formula="max_over_min",
- ignore_zero=True,
- )
-
- with pytest.raises(AirflowException, match="f0, f1, f2"):
- operator.execute()
-
- @mock.patch.object(SQLIntervalCheckOperator, "get_db_hook")
- def test_execute_diff(self, mock_get_db_hook):
- mock_hook = mock.Mock()
-
- def returned_row():
- rows = [
- [3, 3, 3, 3], # reference
- [1, 1, 1, 1], # current
- ]
-
- yield from rows
-
- mock_hook.get_first.side_effect = returned_row()
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator(
- table="test_table",
- metric_thresholds={
- "f0": 0.5,
- "f1": 0.6,
- "f2": 0.7,
- "f3": 0.8,
- },
- ratio_formula="relative_diff",
- ignore_zero=True,
- )
-
- with pytest.raises(AirflowException, match="f0, f1"):
- operator.execute()
-
-
-class TestThresholdCheckOperator(unittest.TestCase):
- def _construct_operator(self, sql, min_threshold, max_threshold):
- dag = DAG("test_dag", start_date=datetime.datetime(2017, 1, 1))
-
- return SQLThresholdCheckOperator(
- task_id="test_task",
- sql=sql,
- min_threshold=min_threshold,
- max_threshold=max_threshold,
- dag=dag,
- )
-
- @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
- def test_pass_min_value_max_value(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.return_value = (10,)
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator("Select avg(val) from table1 limit 1", 1, 100)
-
- operator.execute()
-
- @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
- def test_fail_min_value_max_value(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.return_value = (10,)
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator("Select avg(val) from table1 limit 1", 20, 100)
-
- with pytest.raises(AirflowException, match="10.*20.0.*100.0"):
- operator.execute()
-
- @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
- def test_pass_min_sql_max_sql(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),)
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator("Select 10", "Select 1", "Select 100")
-
- operator.execute()
-
- @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
- def test_fail_min_sql_max_sql(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),)
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator("Select 10", "Select 20", "Select 100")
-
- with pytest.raises(AirflowException, match="10.*20.*100"):
- operator.execute()
-
- @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
- def test_pass_min_value_max_sql(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),)
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator("Select 75", 45, "Select 100")
-
- operator.execute()
-
- @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
- def test_fail_min_sql_max_value(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),)
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator("Select 155", "Select 45", 100)
-
- with pytest.raises(AirflowException, match="155.*45.*100.0"):
- operator.execute()
-
-
-class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
- """
- Test for SQL Branch Operator
- """
-
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
-
- with create_session() as session:
- session.query(DagRun).delete()
- session.query(TI).delete()
- session.query(XCom).delete()
-
- def setUp(self):
- super().setUp()
- self.dag = DAG(
- "sql_branch_operator_test",
- default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
- schedule_interval=INTERVAL,
- )
- self.branch_1 = EmptyOperator(task_id="branch_1", dag=self.dag)
- self.branch_2 = EmptyOperator(task_id="branch_2", dag=self.dag)
- self.branch_3 = None
-
- def tearDown(self):
- super().tearDown()
-
- with create_session() as session:
- session.query(DagRun).delete()
- session.query(TI).delete()
- session.query(XCom).delete()
-
- def test_unsupported_conn_type(self):
- """Check if BranchSQLOperator throws an exception for unsupported connection type"""
- op = BranchSQLOperator(
- task_id="make_choice",
- conn_id="redis_default",
- sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
- follow_task_ids_if_true="branch_1",
- follow_task_ids_if_false="branch_2",
- dag=self.dag,
- )
-
- with pytest.raises(AirflowException):
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
-
- def test_invalid_conn(self):
- """Check if BranchSQLOperator throws an exception for invalid connection"""
- op = BranchSQLOperator(
- task_id="make_choice",
- conn_id="invalid_connection",
- sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
- follow_task_ids_if_true="branch_1",
- follow_task_ids_if_false="branch_2",
- dag=self.dag,
- )
-
- with pytest.raises(AirflowException):
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
-
- def test_invalid_follow_task_true(self):
- """Check if BranchSQLOperator throws an exception for invalid connection"""
- op = BranchSQLOperator(
- task_id="make_choice",
- conn_id="invalid_connection",
- sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
- follow_task_ids_if_true=None,
- follow_task_ids_if_false="branch_2",
- dag=self.dag,
- )
-
- with pytest.raises(AirflowException):
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
-
- def test_invalid_follow_task_false(self):
- """Check if BranchSQLOperator throws an exception for invalid connection"""
- op = BranchSQLOperator(
- task_id="make_choice",
- conn_id="invalid_connection",
- sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
- follow_task_ids_if_true="branch_1",
- follow_task_ids_if_false=None,
- dag=self.dag,
- )
-
- with pytest.raises(AirflowException):
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
-
- @pytest.mark.backend("mysql")
- def test_sql_branch_operator_mysql(self):
- """Check if BranchSQLOperator works with backend"""
- branch_op = BranchSQLOperator(
- task_id="make_choice",
- conn_id="mysql_default",
- sql="SELECT 1",
- follow_task_ids_if_true="branch_1",
- follow_task_ids_if_false="branch_2",
- dag=self.dag,
- )
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
-
- @pytest.mark.backend("postgres")
- def test_sql_branch_operator_postgres(self):
- """Check if BranchSQLOperator works with backend"""
- branch_op = BranchSQLOperator(
- task_id="make_choice",
- conn_id="postgres_default",
- sql="SELECT 1",
- follow_task_ids_if_true="branch_1",
- follow_task_ids_if_false="branch_2",
- dag=self.dag,
- )
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
-
- @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
- def test_branch_single_value_with_dag_run(self, mock_get_db_hook):
- """Check BranchSQLOperator branch operation"""
- branch_op = BranchSQLOperator(
- task_id="make_choice",
- conn_id="mysql_default",
- sql="SELECT 1",
- follow_task_ids_if_true="branch_1",
- follow_task_ids_if_false="branch_2",
- dag=self.dag,
- )
-
- self.branch_1.set_upstream(branch_op)
- self.branch_2.set_upstream(branch_op)
- self.dag.clear()
-
- dr = self.dag.create_dagrun(
- run_id="manual__",
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
-
- mock_get_records = mock_get_db_hook.return_value.get_first
-
- mock_get_records.return_value = 1
-
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- tis = dr.get_task_instances()
- for ti in tis:
- if ti.task_id == "make_choice":
- assert ti.state == State.SUCCESS
- elif ti.task_id == "branch_1":
- assert ti.state == State.NONE
- elif ti.task_id == "branch_2":
- assert ti.state == State.SKIPPED
- else:
- raise ValueError(f"Invalid task id {ti.task_id} found!")
-
- @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
- def test_branch_true_with_dag_run(self, mock_get_db_hook):
- """Check BranchSQLOperator branch operation"""
- branch_op = BranchSQLOperator(
- task_id="make_choice",
- conn_id="mysql_default",
- sql="SELECT 1",
- follow_task_ids_if_true="branch_1",
- follow_task_ids_if_false="branch_2",
- dag=self.dag,
- )
-
- self.branch_1.set_upstream(branch_op)
- self.branch_2.set_upstream(branch_op)
- self.dag.clear()
-
- dr = self.dag.create_dagrun(
- run_id="manual__",
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
-
- mock_get_records = mock_get_db_hook.return_value.get_first
-
- for true_value in SUPPORTED_TRUE_VALUES:
- mock_get_records.return_value = true_value
-
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- tis = dr.get_task_instances()
- for ti in tis:
- if ti.task_id == "make_choice":
- assert ti.state == State.SUCCESS
- elif ti.task_id == "branch_1":
- assert ti.state == State.NONE
- elif ti.task_id == "branch_2":
- assert ti.state == State.SKIPPED
- else:
- raise ValueError(f"Invalid task id {ti.task_id} found!")
-
- @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
- def test_branch_false_with_dag_run(self, mock_get_db_hook):
- """Check BranchSQLOperator branch operation"""
- branch_op = BranchSQLOperator(
- task_id="make_choice",
- conn_id="mysql_default",
- sql="SELECT 1",
- follow_task_ids_if_true="branch_1",
- follow_task_ids_if_false="branch_2",
- dag=self.dag,
- )
-
- self.branch_1.set_upstream(branch_op)
- self.branch_2.set_upstream(branch_op)
- self.dag.clear()
-
- dr = self.dag.create_dagrun(
- run_id="manual__",
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
-
- mock_get_records = mock_get_db_hook.return_value.get_first
-
- for false_value in SUPPORTED_FALSE_VALUES:
- mock_get_records.return_value = false_value
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- tis = dr.get_task_instances()
- for ti in tis:
- if ti.task_id == "make_choice":
- assert ti.state == State.SUCCESS
- elif ti.task_id == "branch_1":
- assert ti.state == State.SKIPPED
- elif ti.task_id == "branch_2":
- assert ti.state == State.NONE
- else:
- raise ValueError(f"Invalid task id {ti.task_id} found!")
-
- @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
- def test_branch_list_with_dag_run(self, mock_get_db_hook):
- """Checks if the BranchSQLOperator supports branching off to a list of tasks."""
- branch_op = BranchSQLOperator(
- task_id="make_choice",
- conn_id="mysql_default",
- sql="SELECT 1",
- follow_task_ids_if_true=["branch_1", "branch_2"],
- follow_task_ids_if_false="branch_3",
- dag=self.dag,
- )
-
- self.branch_1.set_upstream(branch_op)
- self.branch_2.set_upstream(branch_op)
- self.branch_3 = EmptyOperator(task_id="branch_3", dag=self.dag)
- self.branch_3.set_upstream(branch_op)
- self.dag.clear()
-
- dr = self.dag.create_dagrun(
- run_id="manual__",
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
-
- mock_get_records = mock_get_db_hook.return_value.get_first
- mock_get_records.return_value = [["1"]]
-
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- tis = dr.get_task_instances()
- for ti in tis:
- if ti.task_id == "make_choice":
- assert ti.state == State.SUCCESS
- elif ti.task_id == "branch_1":
- assert ti.state == State.NONE
- elif ti.task_id == "branch_2":
- assert ti.state == State.NONE
- elif ti.task_id == "branch_3":
- assert ti.state == State.SKIPPED
- else:
- raise ValueError(f"Invalid task id {ti.task_id} found!")
-
- @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
- def test_invalid_query_result_with_dag_run(self, mock_get_db_hook):
- """Check BranchSQLOperator branch operation"""
- branch_op = BranchSQLOperator(
- task_id="make_choice",
- conn_id="mysql_default",
- sql="SELECT 1",
- follow_task_ids_if_true="branch_1",
- follow_task_ids_if_false="branch_2",
- dag=self.dag,
- )
-
- self.branch_1.set_upstream(branch_op)
- self.branch_2.set_upstream(branch_op)
- self.dag.clear()
-
- self.dag.create_dagrun(
- run_id="manual__",
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
-
- mock_get_records = mock_get_db_hook.return_value.get_first
-
- mock_get_records.return_value = ["Invalid Value"]
-
- with pytest.raises(AirflowException):
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
- def test_with_skip_in_branch_downstream_dependencies(self, mock_get_db_hook):
- """Test SQL Branch with skipping all downstream dependencies"""
- branch_op = BranchSQLOperator(
- task_id="make_choice",
- conn_id="mysql_default",
- sql="SELECT 1",
- follow_task_ids_if_true="branch_1",
- follow_task_ids_if_false="branch_2",
- dag=self.dag,
- )
-
- branch_op >> self.branch_1 >> self.branch_2
- branch_op >> self.branch_2
- self.dag.clear()
-
- dr = self.dag.create_dagrun(
- run_id="manual__",
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
-
- mock_get_records = mock_get_db_hook.return_value.get_first
-
- for true_value in SUPPORTED_TRUE_VALUES:
- mock_get_records.return_value = [true_value]
-
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- tis = dr.get_task_instances()
- for ti in tis:
- if ti.task_id == "make_choice":
- assert ti.state == State.SUCCESS
- elif ti.task_id == "branch_1":
- assert ti.state == State.NONE
- elif ti.task_id == "branch_2":
- assert ti.state == State.NONE
- else:
- raise ValueError(f"Invalid task id {ti.task_id} found!")
-
- @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
- def test_with_skip_in_branch_downstream_dependencies2(self, mock_get_db_hook):
- """Test skipping downstream dependency for false condition"""
- branch_op = BranchSQLOperator(
- task_id="make_choice",
- conn_id="mysql_default",
- sql="SELECT 1",
- follow_task_ids_if_true="branch_1",
- follow_task_ids_if_false="branch_2",
- dag=self.dag,
- )
-
- branch_op >> self.branch_1 >> self.branch_2
- branch_op >> self.branch_2
- self.dag.clear()
-
- dr = self.dag.create_dagrun(
- run_id="manual__",
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
-
- mock_get_records = mock_get_db_hook.return_value.get_first
-
- for false_value in SUPPORTED_FALSE_VALUES:
- mock_get_records.return_value = [false_value]
-
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- tis = dr.get_task_instances()
- for ti in tis:
- if ti.task_id == "make_choice":
- assert ti.state == State.SUCCESS
- elif ti.task_id == "branch_1":
- assert ti.state == State.SKIPPED
- elif ti.task_id == "branch_2":
- assert ti.state == State.NONE
- else:
- raise ValueError(f"Invalid task id {ti.task_id} found!")
diff --git a/tests/providers/common/sql/operators/test_sql.py b/tests/providers/common/sql/operators/test_sql.py
index da53e6134f..c6dc3d0757 100644
--- a/tests/providers/common/sql/operators/test_sql.py
+++ b/tests/providers/common/sql/operators/test_sql.py
@@ -15,12 +15,32 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import datetime
+import unittest
+from unittest import mock
+from unittest.mock import MagicMock
import pandas as pd
import pytest
+from airflow import DAG
from airflow.exceptions import AirflowException
-from airflow.providers.common.sql.operators.sql import SQLColumnCheckOperator, SQLTableCheckOperator
+from airflow.models import Connection, DagRun, TaskInstance as TI, XCom
+from airflow.operators.empty import EmptyOperator
+from airflow.providers.common.sql.operators.sql import (
+ BranchSQLOperator,
+ SQLCheckOperator,
+ SQLColumnCheckOperator,
+ SQLIntervalCheckOperator,
+ SQLTableCheckOperator,
+ SQLThresholdCheckOperator,
+ SQLValueCheckOperator,
+)
+from airflow.providers.postgres.hooks.postgres import PostgresHook
+from airflow.utils import timezone
+from airflow.utils.session import create_session
+from airflow.utils.state import State
+from tests.providers.apache.hive import TestHiveEnvironment
class MockHook:
@@ -66,30 +86,30 @@ class TestColumnCheckOperator:
def test_pass_all_checks_exact_check(self, monkeypatch):
operator = self._construct_operator(monkeypatch, self.valid_column_mapping, (0, 10, 10, 1, 19))
- operator.execute()
+ operator.execute(context=MagicMock())
def test_max_less_than_fails_check(self, monkeypatch):
with pytest.raises(AirflowException):
operator = self._construct_operator(monkeypatch, self.valid_column_mapping, (0, 10, 10, 1, 21))
- operator.execute()
+ operator.execute(context=MagicMock())
assert operator.column_mapping["X"]["max"]["success"] is False
def test_max_greater_than_fails_check(self, monkeypatch):
with pytest.raises(AirflowException):
operator = self._construct_operator(monkeypatch, self.valid_column_mapping, (0, 10, 10, 1, 9))
- operator.execute()
+ operator.execute(context=MagicMock())
assert operator.column_mapping["X"]["max"]["success"] is False
def test_pass_all_checks_inexact_check(self, monkeypatch):
operator = self._construct_operator(monkeypatch, self.valid_column_mapping, (0, 9, 12, 0, 15))
- operator.execute()
+ operator.execute(context=MagicMock())
def test_fail_all_checks_check(self, monkeypatch):
operator = self._construct_operator(
monkeypatch, self.valid_column_mapping, (1, 12, 11, -1, 20)
)
with pytest.raises(AirflowException):
- operator.execute()
+ operator.execute(context=MagicMock())
class TestTableCheckOperator:
@@ -119,7 +139,7 @@ class TestTableCheckOperator:
}
)
operator = self._construct_operator(monkeypatch, self.checks, df)
- operator.execute()
+ operator.execute(context=MagicMock())
def test_fail_all_checks_check(self, monkeypatch):
df = pd.DataFrame(
@@ -127,4 +147,744 @@ class TestTableCheckOperator:
)
operator = self._construct_operator(monkeypatch, self.checks, df)
with pytest.raises(AirflowException):
- operator.execute()
+ operator.execute(context=MagicMock())
+
+
+DEFAULT_DATE = timezone.datetime(2016, 1, 1)
+INTERVAL = datetime.timedelta(hours=12)
+SUPPORTED_TRUE_VALUES = [
+ ["True"],
+ ["true"],
+ ["1"],
+ ["on"],
+ [1],
+ True,
+ "true",
+ "1",
+ "on",
+ 1,
+]
+SUPPORTED_FALSE_VALUES = [
+ ["False"],
+ ["false"],
+ ["0"],
+ ["off"],
+ [0],
+ False,
+ "false",
+ "0",
+ "off",
+ 0,
+]
+
+
+@mock.patch(
+ 'airflow.providers.common.sql.operators.sql.BaseHook.get_connection',
+ return_value=Connection(conn_id='sql_default', conn_type='postgres'),
+)
+class TestSQLCheckOperatorDbHook:
+ def setup_method(self):
+ self.task_id = "test_task"
+ self.conn_id = "sql_default"
+ self._operator = SQLCheckOperator(task_id=self.task_id, conn_id=self.conn_id, sql="sql")
+
+ @pytest.mark.parametrize('database', [None, 'test-db'])
+ def test_get_hook(self, mock_get_conn, database):
+ if database:
+ self._operator.database = database
+ assert isinstance(self._operator._hook, PostgresHook)
+ assert self._operator._hook.schema == database
+ mock_get_conn.assert_called_once_with(self.conn_id)
+
+ def test_not_allowed_conn_type(self, mock_get_conn):
+ mock_get_conn.return_value = Connection(conn_id='sql_default', conn_type='s3')
+ with pytest.raises(AirflowException, match=r"The connection type is not supported"):
+ self._operator._hook
+
+ def test_sql_operator_hook_params_snowflake(self, mock_get_conn):
+ mock_get_conn.return_value = Connection(conn_id='snowflake_default', conn_type='snowflake')
+ self._operator.hook_params = {
+ 'warehouse': 'warehouse',
+ 'database': 'database',
+ 'role': 'role',
+ 'schema': 'schema',
+ 'log_sql': False,
+ }
+ assert self._operator._hook.conn_type == 'snowflake'
+ assert self._operator._hook.warehouse == 'warehouse'
+ assert self._operator._hook.database == 'database'
+ assert self._operator._hook.role == 'role'
+ assert self._operator._hook.schema == 'schema'
+ assert not self._operator._hook.log_sql
+
+ def test_sql_operator_hook_params_bigquery(self, mock_get_conn):
+ mock_get_conn.return_value = Connection(
+ conn_id='google_cloud_bigquery_default', conn_type='gcpbigquery'
+ )
+ self._operator.hook_params = {'use_legacy_sql': True, 'location': 'us-east1'}
+ assert self._operator._hook.conn_type == 'gcpbigquery'
+ assert self._operator._hook.use_legacy_sql
+ assert self._operator._hook.location == 'us-east1'
+
+
+class TestCheckOperator(unittest.TestCase):
+ def setUp(self):
+ self._operator = SQLCheckOperator(task_id="test_task", sql="sql")
+
+ @mock.patch.object(SQLCheckOperator, "get_db_hook")
+ def test_execute_no_records(self, mock_get_db_hook):
+ mock_get_db_hook.return_value.get_first.return_value = []
+
+ with pytest.raises(AirflowException, match=r"The query returned None"):
+ self._operator.execute({})
+
+ @mock.patch.object(SQLCheckOperator, "get_db_hook")
+ def test_execute_not_all_records_are_true(self, mock_get_db_hook):
+ mock_get_db_hook.return_value.get_first.return_value = ["data", ""]
+
+ with pytest.raises(AirflowException, match=r"Test failed."):
+ self._operator.execute({})
+
+
+class TestValueCheckOperator(unittest.TestCase):
+ def setUp(self):
+ self.task_id = "test_task"
+ self.conn_id = "default_conn"
+
+ def _construct_operator(self, sql, pass_value, tolerance=None):
+ dag = DAG("test_dag", start_date=datetime.datetime(2017, 1, 1))
+
+ return SQLValueCheckOperator(
+ dag=dag,
+ task_id=self.task_id,
+ conn_id=self.conn_id,
+ sql=sql,
+ pass_value=pass_value,
+ tolerance=tolerance,
+ )
+
+ def test_pass_value_template_string(self):
+ pass_value_str = "2018-03-22"
+ operator = self._construct_operator("select date from tab1;", "{{ ds }}")
+
+ operator.render_template_fields({"ds": pass_value_str})
+
+ assert operator.task_id == self.task_id
+ assert operator.pass_value == pass_value_str
+
+ def test_pass_value_template_string_float(self):
+ pass_value_float = 4.0
+ operator = self._construct_operator("select date from tab1;", pass_value_float)
+
+ operator.render_template_fields({})
+
+ assert operator.task_id == self.task_id
+ assert operator.pass_value == str(pass_value_float)
+
+ @mock.patch.object(SQLValueCheckOperator, "get_db_hook")
+ def test_execute_pass(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.return_value = [10]
+ mock_get_db_hook.return_value = mock_hook
+ sql = "select value from tab1 limit 1;"
+ operator = self._construct_operator(sql, 5, 1)
+
+ operator.execute(None)
+
+ mock_hook.get_first.assert_called_once_with(sql)
+
+ @mock.patch.object(SQLValueCheckOperator, "get_db_hook")
+ def test_execute_fail(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.return_value = [11]
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator("select value from tab1 limit 1;", 5, 1)
+
+ with pytest.raises(AirflowException, match="Tolerance:100.0%"):
+ operator.execute(context=MagicMock())
+
+
+class TestIntervalCheckOperator(unittest.TestCase):
+ def _construct_operator(self, table, metric_thresholds, ratio_formula, ignore_zero):
+ return SQLIntervalCheckOperator(
+ task_id="test_task",
+ table=table,
+ metrics_thresholds=metric_thresholds,
+ ratio_formula=ratio_formula,
+ ignore_zero=ignore_zero,
+ )
+
+ def test_invalid_ratio_formula(self):
+ with pytest.raises(AirflowException, match="Invalid diff_method"):
+ self._construct_operator(
+ table="test_table",
+ metric_thresholds={
+ "f1": 1,
+ },
+ ratio_formula="abs",
+ ignore_zero=False,
+ )
+
+ @mock.patch.object(SQLIntervalCheckOperator, "get_db_hook")
+ def test_execute_not_ignore_zero(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.return_value = [0]
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator(
+ table="test_table",
+ metric_thresholds={
+ "f1": 1,
+ },
+ ratio_formula="max_over_min",
+ ignore_zero=False,
+ )
+
+ with pytest.raises(AirflowException):
+ operator.execute(context=MagicMock())
+
+ @mock.patch.object(SQLIntervalCheckOperator, "get_db_hook")
+ def test_execute_ignore_zero(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.return_value = [0]
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator(
+ table="test_table",
+ metric_thresholds={
+ "f1": 1,
+ },
+ ratio_formula="max_over_min",
+ ignore_zero=True,
+ )
+
+ operator.execute(context=MagicMock())
+
+ @mock.patch.object(SQLIntervalCheckOperator, "get_db_hook")
+ def test_execute_min_max(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+
+ def returned_row():
+ rows = [
+ [2, 2, 2, 2], # reference
+ [1, 1, 1, 1], # current
+ ]
+
+ yield from rows
+
+ mock_hook.get_first.side_effect = returned_row()
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator(
+ table="test_table",
+ metric_thresholds={
+ "f0": 1.0,
+ "f1": 1.5,
+ "f2": 2.0,
+ "f3": 2.5,
+ },
+ ratio_formula="max_over_min",
+ ignore_zero=True,
+ )
+
+ with pytest.raises(AirflowException, match="f0, f1, f2"):
+ operator.execute(context=MagicMock())
+
+ @mock.patch.object(SQLIntervalCheckOperator, "get_db_hook")
+ def test_execute_diff(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+
+ def returned_row():
+ rows = [
+ [3, 3, 3, 3], # reference
+ [1, 1, 1, 1], # current
+ ]
+
+ yield from rows
+
+ mock_hook.get_first.side_effect = returned_row()
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator(
+ table="test_table",
+ metric_thresholds={
+ "f0": 0.5,
+ "f1": 0.6,
+ "f2": 0.7,
+ "f3": 0.8,
+ },
+ ratio_formula="relative_diff",
+ ignore_zero=True,
+ )
+
+ with pytest.raises(AirflowException, match="f0, f1"):
+ operator.execute(context=MagicMock())
+
+
+class TestThresholdCheckOperator(unittest.TestCase):
+ def _construct_operator(self, sql, min_threshold, max_threshold):
+ dag = DAG("test_dag", start_date=datetime.datetime(2017, 1, 1))
+
+ return SQLThresholdCheckOperator(
+ task_id="test_task",
+ sql=sql,
+ min_threshold=min_threshold,
+ max_threshold=max_threshold,
+ dag=dag,
+ )
+
+ @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
+ def test_pass_min_value_max_value(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.return_value = (10,)
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator("Select avg(val) from table1 limit 1", 1, 100)
+
+ operator.execute(context=MagicMock())
+
+ @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
+ def test_fail_min_value_max_value(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.return_value = (10,)
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator("Select avg(val) from table1 limit 1", 20, 100)
+
+ with pytest.raises(AirflowException, match="10.*20.0.*100.0"):
+ operator.execute(context=MagicMock())
+
+ @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
+ def test_pass_min_sql_max_sql(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),)
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator("Select 10", "Select 1", "Select 100")
+
+ operator.execute(context=MagicMock())
+
+ @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
+ def test_fail_min_sql_max_sql(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),)
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator("Select 10", "Select 20", "Select 100")
+
+ with pytest.raises(AirflowException, match="10.*20.*100"):
+ operator.execute(context=MagicMock())
+
+ @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
+ def test_pass_min_value_max_sql(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),)
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator("Select 75", 45, "Select 100")
+
+ operator.execute(context=MagicMock())
+
+ @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
+ def test_fail_min_sql_max_value(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),)
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator("Select 155", "Select 45", 100)
+
+ with pytest.raises(AirflowException, match="155.*45.*100.0"):
+ operator.execute(context=MagicMock())
+
+
+class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
+ """
+ Test for SQL Branch Operator
+ """
+
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+
+ with create_session() as session:
+ session.query(DagRun).delete()
+ session.query(TI).delete()
+ session.query(XCom).delete()
+
+ def setUp(self):
+ super().setUp()
+ self.dag = DAG(
+ "sql_branch_operator_test",
+ default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+ schedule_interval=INTERVAL,
+ )
+ self.branch_1 = EmptyOperator(task_id="branch_1", dag=self.dag)
+ self.branch_2 = EmptyOperator(task_id="branch_2", dag=self.dag)
+ self.branch_3 = None
+
+ def tearDown(self):
+ super().tearDown()
+
+ with create_session() as session:
+ session.query(DagRun).delete()
+ session.query(TI).delete()
+ session.query(XCom).delete()
+
+ def test_unsupported_conn_type(self):
+ """Check if BranchSQLOperator throws an exception for unsupported connection type"""
+ op = BranchSQLOperator(
+ task_id="make_choice",
+ conn_id="redis_default",
+ sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+
+ with pytest.raises(AirflowException):
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+ def test_invalid_conn(self):
+ """Check if BranchSQLOperator throws an exception for invalid connection"""
+ op = BranchSQLOperator(
+ task_id="make_choice",
+ conn_id="invalid_connection",
+ sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+
+ with pytest.raises(AirflowException):
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+ def test_invalid_follow_task_true(self):
+ """Check if BranchSQLOperator throws an exception for invalid connection"""
+ op = BranchSQLOperator(
+ task_id="make_choice",
+ conn_id="invalid_connection",
+ sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+ follow_task_ids_if_true=None,
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+
+ with pytest.raises(AirflowException):
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+ def test_invalid_follow_task_false(self):
+ """Check if BranchSQLOperator throws an exception for invalid connection"""
+ op = BranchSQLOperator(
+ task_id="make_choice",
+ conn_id="invalid_connection",
+ sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false=None,
+ dag=self.dag,
+ )
+
+ with pytest.raises(AirflowException):
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+ @pytest.mark.backend("mysql")
+ def test_sql_branch_operator_mysql(self):
+ """Check if BranchSQLOperator works with backend"""
+ branch_op = BranchSQLOperator(
+ task_id="make_choice",
+ conn_id="mysql_default",
+ sql="SELECT 1",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+ @pytest.mark.backend("postgres")
+ def test_sql_branch_operator_postgres(self):
+ """Check if BranchSQLOperator works with backend"""
+ branch_op = BranchSQLOperator(
+ task_id="make_choice",
+ conn_id="postgres_default",
+ sql="SELECT 1",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+ @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
+ def test_branch_single_value_with_dag_run(self, mock_get_db_hook):
+ """Check BranchSQLOperator branch operation"""
+ branch_op = BranchSQLOperator(
+ task_id="make_choice",
+ conn_id="mysql_default",
+ sql="SELECT 1",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+
+ self.branch_1.set_upstream(branch_op)
+ self.branch_2.set_upstream(branch_op)
+ self.dag.clear()
+
+ dr = self.dag.create_dagrun(
+ run_id="manual__",
+ start_date=timezone.utcnow(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+
+ mock_get_records = mock_get_db_hook.return_value.get_first
+
+ mock_get_records.return_value = 1
+
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ tis = dr.get_task_instances()
+ for ti in tis:
+ if ti.task_id == "make_choice":
+ assert ti.state == State.SUCCESS
+ elif ti.task_id == "branch_1":
+ assert ti.state == State.NONE
+ elif ti.task_id == "branch_2":
+ assert ti.state == State.SKIPPED
+ else:
+ raise ValueError(f"Invalid task id {ti.task_id} found!")
+
+ @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
+ def test_branch_true_with_dag_run(self, mock_get_db_hook):
+ """Check BranchSQLOperator branch operation"""
+ branch_op = BranchSQLOperator(
+ task_id="make_choice",
+ conn_id="mysql_default",
+ sql="SELECT 1",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+
+ self.branch_1.set_upstream(branch_op)
+ self.branch_2.set_upstream(branch_op)
+ self.dag.clear()
+
+ dr = self.dag.create_dagrun(
+ run_id="manual__",
+ start_date=timezone.utcnow(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+
+ mock_get_records = mock_get_db_hook.return_value.get_first
+
+ for true_value in SUPPORTED_TRUE_VALUES:
+ mock_get_records.return_value = true_value
+
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ tis = dr.get_task_instances()
+ for ti in tis:
+ if ti.task_id == "make_choice":
+ assert ti.state == State.SUCCESS
+ elif ti.task_id == "branch_1":
+ assert ti.state == State.NONE
+ elif ti.task_id == "branch_2":
+ assert ti.state == State.SKIPPED
+ else:
+ raise ValueError(f"Invalid task id {ti.task_id} found!")
+
+ @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
+ def test_branch_false_with_dag_run(self, mock_get_db_hook):
+ """Check BranchSQLOperator branch operation"""
+ branch_op = BranchSQLOperator(
+ task_id="make_choice",
+ conn_id="mysql_default",
+ sql="SELECT 1",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+
+ self.branch_1.set_upstream(branch_op)
+ self.branch_2.set_upstream(branch_op)
+ self.dag.clear()
+
+ dr = self.dag.create_dagrun(
+ run_id="manual__",
+ start_date=timezone.utcnow(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+
+ mock_get_records = mock_get_db_hook.return_value.get_first
+
+ for false_value in SUPPORTED_FALSE_VALUES:
+ mock_get_records.return_value = false_value
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ tis = dr.get_task_instances()
+ for ti in tis:
+ if ti.task_id == "make_choice":
+ assert ti.state == State.SUCCESS
+ elif ti.task_id == "branch_1":
+ assert ti.state == State.SKIPPED
+ elif ti.task_id == "branch_2":
+ assert ti.state == State.NONE
+ else:
+ raise ValueError(f"Invalid task id {ti.task_id} found!")
+
+ @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
+ def test_branch_list_with_dag_run(self, mock_get_db_hook):
+ """Checks if the BranchSQLOperator supports branching off to a list of tasks."""
+ branch_op = BranchSQLOperator(
+ task_id="make_choice",
+ conn_id="mysql_default",
+ sql="SELECT 1",
+ follow_task_ids_if_true=["branch_1", "branch_2"],
+ follow_task_ids_if_false="branch_3",
+ dag=self.dag,
+ )
+
+ self.branch_1.set_upstream(branch_op)
+ self.branch_2.set_upstream(branch_op)
+ self.branch_3 = EmptyOperator(task_id="branch_3", dag=self.dag)
+ self.branch_3.set_upstream(branch_op)
+ self.dag.clear()
+
+ dr = self.dag.create_dagrun(
+ run_id="manual__",
+ start_date=timezone.utcnow(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+
+ mock_get_records = mock_get_db_hook.return_value.get_first
+ mock_get_records.return_value = [["1"]]
+
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ tis = dr.get_task_instances()
+ for ti in tis:
+ if ti.task_id == "make_choice":
+ assert ti.state == State.SUCCESS
+ elif ti.task_id == "branch_1":
+ assert ti.state == State.NONE
+ elif ti.task_id == "branch_2":
+ assert ti.state == State.NONE
+ elif ti.task_id == "branch_3":
+ assert ti.state == State.SKIPPED
+ else:
+ raise ValueError(f"Invalid task id {ti.task_id} found!")
+
+ @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
+ def test_invalid_query_result_with_dag_run(self, mock_get_db_hook):
+ """Check BranchSQLOperator branch operation"""
+ branch_op = BranchSQLOperator(
+ task_id="make_choice",
+ conn_id="mysql_default",
+ sql="SELECT 1",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+
+ self.branch_1.set_upstream(branch_op)
+ self.branch_2.set_upstream(branch_op)
+ self.dag.clear()
+
+ self.dag.create_dagrun(
+ run_id="manual__",
+ start_date=timezone.utcnow(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+
+ mock_get_records = mock_get_db_hook.return_value.get_first
+
+ mock_get_records.return_value = ["Invalid Value"]
+
+ with pytest.raises(AirflowException):
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
+ def test_with_skip_in_branch_downstream_dependencies(self, mock_get_db_hook):
+ """Test SQL Branch with skipping all downstream dependencies"""
+ branch_op = BranchSQLOperator(
+ task_id="make_choice",
+ conn_id="mysql_default",
+ sql="SELECT 1",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+
+ branch_op >> self.branch_1 >> self.branch_2
+ branch_op >> self.branch_2
+ self.dag.clear()
+
+ dr = self.dag.create_dagrun(
+ run_id="manual__",
+ start_date=timezone.utcnow(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+
+ mock_get_records = mock_get_db_hook.return_value.get_first
+
+ for true_value in SUPPORTED_TRUE_VALUES:
+ mock_get_records.return_value = [true_value]
+
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ tis = dr.get_task_instances()
+ for ti in tis:
+ if ti.task_id == "make_choice":
+ assert ti.state == State.SUCCESS
+ elif ti.task_id == "branch_1":
+ assert ti.state == State.NONE
+ elif ti.task_id == "branch_2":
+ assert ti.state == State.NONE
+ else:
+ raise ValueError(f"Invalid task id {ti.task_id} found!")
+
+ @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
+ def test_with_skip_in_branch_downstream_dependencies2(self, mock_get_db_hook):
+ """Test skipping downstream dependency for false condition"""
+ branch_op = BranchSQLOperator(
+ task_id="make_choice",
+ conn_id="mysql_default",
+ sql="SELECT 1",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+
+ branch_op >> self.branch_1 >> self.branch_2
+ branch_op >> self.branch_2
+ self.dag.clear()
+
+ dr = self.dag.create_dagrun(
+ run_id="manual__",
+ start_date=timezone.utcnow(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+
+ mock_get_records = mock_get_db_hook.return_value.get_first
+
+ for false_value in SUPPORTED_FALSE_VALUES:
+ mock_get_records.return_value = [false_value]
+
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ tis = dr.get_task_instances()
+ for ti in tis:
+ if ti.task_id == "make_choice":
+ assert ti.state == State.SUCCESS
+ elif ti.task_id == "branch_1":
+ assert ti.state == State.SKIPPED
+ elif ti.task_id == "branch_2":
+ assert ti.state == State.NONE
+ else:
+ raise ValueError(f"Invalid task id {ti.task_id} found!")
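Besides the new import block, the one behavioural change threaded through this moved file is that execute() is always called with an explicit context, since BaseOperator.execute(context) takes it as a required argument; a MagicMock stands in for the real runtime context. A condensed sketch of the pattern (an assumed, self-contained version of the tests above):

    from unittest import mock
    from unittest.mock import MagicMock

    from airflow.providers.common.sql.operators.sql import SQLCheckOperator

    @mock.patch.object(SQLCheckOperator, "get_db_hook")
    def test_check_passes(mock_get_db_hook):
        # A non-empty, all-truthy first row means the check succeeds.
        mock_get_db_hook.return_value.get_first.return_value = [1]
        op = SQLCheckOperator(task_id="check", sql="SELECT 1")
        op.execute(context=MagicMock())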
diff --git a/tests/providers/qubole/operators/test_qubole_check.py b/tests/providers/qubole/operators/test_qubole_check.py
index ac8f1edb18..3d2c9cbe1d 100644
--- a/tests/providers/qubole/operators/test_qubole_check.py
+++ b/tests/providers/qubole/operators/test_qubole_check.py
@@ -19,19 +19,19 @@
import unittest
from datetime import datetime
from unittest import mock
+from unittest.mock import MagicMock
import pytest
from qds_sdk.commands import HiveCommand
from airflow.exceptions import AirflowException
from airflow.models import DAG
+from airflow.providers.common.sql.operators.sql import SQLCheckOperator, SQLValueCheckOperator
from airflow.providers.qubole.hooks.qubole import QuboleHook
from airflow.providers.qubole.hooks.qubole_check import QuboleCheckHook
from airflow.providers.qubole.operators.qubole_check import (
QuboleCheckOperator,
QuboleValueCheckOperator,
- SQLCheckOperator,
- SQLValueCheckOperator,
_QuboleCheckOperatorMixin,
)
@@ -80,7 +80,7 @@ class TestQuboleCheckMixin:
operator = self.__construct_operator(operator_class=operator_class, **kwargs)
with mock.patch.object(parent_check_operator, 'execute') as mock_execute:
- operator.execute()
+ operator.execute(context=MagicMock())
mock_execute.assert_called_once()
@mock.patch('airflow.providers.qubole.operators.qubole_check.handle_airflow_exception')
@@ -89,7 +89,7 @@ class TestQuboleCheckMixin:
with mock.patch.object(parent_check_operator, 'execute') as mock_execute:
mock_execute.side_effect = AirflowException()
- operator.execute()
+ operator.execute(context=MagicMock())
mock_execute.assert_called_once()
mock_handle_airflow_exception.assert_called_once()
@@ -153,7 +153,7 @@ class TestQuboleValueCheckOperator(unittest.TestCase):
operator = self.__construct_operator('select value from tab1 limit 1;', 5, 1)
with pytest.raises(AirflowException, match='Qubole Command Id: ' + str(mock_cmd.id)):
- operator.execute()
+ operator.execute(context=MagicMock())
mock_cmd.is_success.assert_called_once_with(mock_cmd.status)
@@ -173,7 +173,7 @@ class TestQuboleValueCheckOperator(unittest.TestCase):
operator = self.__construct_operator('select value from tab1 limit 1;', 5, 1)
with pytest.raises(AirflowException) as ctx:
- operator.execute()
+ operator.execute(context=MagicMock())
assert 'Qubole Command Id: ' not in str(ctx.value)
mock_cmd.is_success.assert_called_once_with(mock_cmd.status)
@@ -193,5 +193,5 @@ class TestQuboleValueCheckOperator(unittest.TestCase):
operator = self.__construct_operator(
'select value from tab1 limit 1;', pass_value, None, results_parser_callable
)
- operator.execute()
+ operator.execute(context=MagicMock())
results_parser_callable.assert_called_once_with([pass_value])
diff --git a/tests/providers/slack/operators/test_slack.py b/tests/providers/slack/operators/test_slack.py
index 8b8ac66441..495dc7f6ba 100644
--- a/tests/providers/slack/operators/test_slack.py
+++ b/tests/providers/slack/operators/test_slack.py
@@ -15,10 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
import json
import unittest
from unittest import mock
+from unittest.mock import MagicMock
from airflow.providers.slack.operators.slack import SlackAPIFileOperator, SlackAPIPostOperator
@@ -119,7 +119,7 @@ class TestSlackAPIPostOperator(unittest.TestCase):
slack_conn_id=test_slack_conn_id,
)
- slack_api_post_operator.execute()
+ slack_api_post_operator.execute(context=MagicMock())
expected_api_params = {
'channel': "#general",
@@ -195,7 +195,7 @@ class TestSlackAPIFileOperator(unittest.TestCase):
task_id='slack', slack_conn_id=test_slack_conn_id, content='test-content'
)
- slack_api_post_operator.execute()
+ slack_api_post_operator.execute(context=MagicMock())
expected_api_params = {
'channels': '#general',
@@ -221,7 +221,7 @@ class TestSlackAPIFileOperator(unittest.TestCase):
task_id='slack', slack_conn_id=test_slack_conn_id, filename=file_path, filetype='csv'
)
- slack_api_post_operator.execute()
+ slack_api_post_operator.execute(context=MagicMock())
expected_api_params = {
'channels': '#general',
diff --git a/tests/providers/slack/transfers/test_sql_to_slack.py b/tests/providers/slack/transfers/test_sql_to_slack.py
index 0390a56b18..23b791fb7b 100644
--- a/tests/providers/slack/transfers/test_sql_to_slack.py
+++ b/tests/providers/slack/transfers/test_sql_to_slack.py
@@ -149,7 +149,7 @@ class TestSqlToSlackOperator:
# Test that the Slack hook's execute method gets run once
slack_webhook_hook.execute.assert_called_once()
- @mock.patch('airflow.operators.sql.BaseHook.get_connection')
+ @mock.patch('airflow.providers.common.sql.operators.sql.BaseHook.get_connection')
def test_hook_params_building(self, mock_get_conn):
mock_get_conn.return_value = Connection(conn_id='snowflake_connection', conn_type='snowflake')
hook_params = {
@@ -172,7 +172,7 @@ class TestSqlToSlackOperator:
assert sql_to_slack_operator.sql_hook_params == hook_params
- @mock.patch('airflow.operators.sql.BaseHook.get_connection')
+ @mock.patch('airflow.providers.common.sql.operators.sql.BaseHook.get_connection')
def test_hook_params(self, mock_get_conn):
mock_get_conn.return_value = Connection(conn_id='postgres_test', conn_type='postgres')
op = SqlToSlackOperator(
@@ -188,7 +188,7 @@ class TestSqlToSlackOperator:
hook = op._get_hook()
assert hook.schema == 'public'
- @mock.patch('airflow.operators.sql.BaseHook.get_connection')
+ @mock.patch('airflow.providers.common.sql.operators.sql.BaseHook.get_connection')
def test_hook_params_snowflake(self, mock_get_conn):
mock_get_conn.return_value = Connection(conn_id='snowflake_default', conn_type='snowflake')
op = SqlToSlackOperator(