Posted to commits@airflow.apache.org by po...@apache.org on 2020/06/20 10:15:09 UTC
[airflow] 07/07: Merging multiple sql operators (#9124)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 0ac4185ca564b0e1a741069bfd670c551efefe31
Author: samuelkhtu <46...@users.noreply.github.com>
AuthorDate: Wed Jun 17 14:32:46 2020 -0400
Merging multiple sql operators (#9124)
* Merge various SQL Operators into sql.py
* Fix unit test code format
* Merge multiple SQL operators into one
1. Merge check_operator.py into airflow.operators.sql
2. Merge sql_branch_operator.py into airflow.operators.sql
3. Merge unit test for both into test_sql.py
* Rename test_core_to_contrib Interval/ValueCheckOperator to SQLInterval/ValueCheckOperator
* Fixed deprecated class and added check to test_core_to_contrib
(cherry picked from commit 0b9bf4a285a074bbde270839a90fb53c257340be)
---
...eea_add_precision_to_execution_date_in_mysql.py | 2 +-
airflow/operators/check_operator.py | 425 +++------------------
airflow/operators/{check_operator.py => sql.py} | 420 +++++++++++++++-----
airflow/operators/sql_branch_operator.py | 162 +-------
docs/operators-and-hooks-ref.rst | 30 +-
tests/api/common/experimental/test_pool.py | 4 +-
tests/contrib/hooks/test_gcp_api_base_hook.py | 2 +-
tests/contrib/hooks/test_gcp_cloud_build_hook.py | 4 +-
tests/contrib/hooks/test_gcp_transfer_hook.py | 4 +-
.../operators/test_gcp_cloud_build_operator.py | 8 +-
tests/contrib/operators/test_gcs_to_gdrive.py | 2 +-
tests/contrib/operators/test_sftp_operator.py | 6 +-
tests/contrib/operators/test_ssh_operator.py | 6 +-
tests/contrib/secrets/test_hashicorp_vault.py | 4 +-
.../contrib/utils/test_gcp_credentials_provider.py | 2 +-
tests/operators/test_check_operator.py | 327 ----------------
.../{test_sql_branch_operator.py => test_sql.py} | 339 ++++++++++++++--
tests/www_rbac/test_validators.py | 4 +-
18 files changed, 707 insertions(+), 1044 deletions(-)
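For DAG authors, the practical effect of this refactor is an import-path change: the old modules keep working but now emit a DeprecationWarning (see the shims in the patch below). A minimal migration sketch, not part of the commit itself; the connection id and table are placeholders:

# Deprecated path, still importable after this commit:
from airflow.operators.check_operator import CheckOperator

# New canonical path introduced by this commit:
from airflow.operators.sql import SQLCheckOperator

check = SQLCheckOperator(task_id="row_count_check", conn_id="my_db",
                         sql="SELECT COUNT(*) FROM foo")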
diff --git a/airflow/migrations/versions/a66efa278eea_add_precision_to_execution_date_in_mysql.py b/airflow/migrations/versions/a66efa278eea_add_precision_to_execution_date_in_mysql.py
index ecb589d..59098a8 100644
--- a/airflow/migrations/versions/a66efa278eea_add_precision_to_execution_date_in_mysql.py
+++ b/airflow/migrations/versions/a66efa278eea_add_precision_to_execution_date_in_mysql.py
@@ -29,7 +29,7 @@ from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision = 'a66efa278eea'
-down_revision = '8f966b9c467a'
+down_revision = '952da73b5eff'
branch_labels = None
depends_on = None
diff --git a/airflow/operators/check_operator.py b/airflow/operators/check_operator.py
index b6d3a18..12ac472 100644
--- a/airflow/operators/check_operator.py
+++ b/airflow/operators/check_operator.py
@@ -17,409 +17,70 @@
# specific language governing permissions and limitations
# under the License.
-from builtins import str, zip
-from typing import Optional, Any, Iterable, Dict, SupportsAbs
+"""This module is deprecated. Please use `airflow.operators.sql`."""
-from airflow.exceptions import AirflowException
-from airflow.hooks.base_hook import BaseHook
-from airflow.models import BaseOperator
-from airflow.utils.decorators import apply_defaults
+import warnings
+from airflow.operators.sql import (
+ SQLCheckOperator, SQLIntervalCheckOperator, SQLThresholdCheckOperator, SQLValueCheckOperator,
+)
-class CheckOperator(BaseOperator):
- """
- Performs checks against a db. The ``CheckOperator`` expects
- a sql query that will return a single row. Each value on that
- first row is evaluated using python ``bool`` casting. If any of the
- values return ``False`` the check is failed and errors out.
-
- Note that Python bool casting evals the following as ``False``:
-
- * ``False``
- * ``0``
- * Empty string (``""``)
- * Empty list (``[]``)
- * Empty dictionary or set (``{}``)
-
- Given a query like ``SELECT COUNT(*) FROM foo``, it will fail only if
- the count ``== 0``. You can craft much more complex query that could,
- for instance, check that the table has the same number of rows as
- the source table upstream, or that the count of today's partition is
- greater than yesterday's partition, or that a set of metrics are less
- than 3 standard deviation for the 7 day average.
-
- This operator can be used as a data quality check in your pipeline, and
- depending on where you put it in your DAG, you have the choice to
- stop the critical path, preventing from
- publishing dubious data, or on the side and receive email alerts
- without stopping the progress of the DAG.
- Note that this is an abstract class and get_db_hook
- needs to be defined. Whereas a get_db_hook is hook that gets a
- single record from an external source.
-
- :param sql: the sql to be executed. (templated)
- :type sql: str
+class CheckOperator(SQLCheckOperator):
"""
-
- template_fields = ('sql',) # type: Iterable[str]
- template_ext = ('.hql', '.sql',) # type: Iterable[str]
- ui_color = '#fff7e6'
-
- @apply_defaults
- def __init__(
- self,
- sql, # type: str
- conn_id=None, # type: Optional[str]
- *args,
- **kwargs
- ):
- super(CheckOperator, self).__init__(*args, **kwargs)
- self.conn_id = conn_id
- self.sql = sql
-
- def execute(self, context=None):
- self.log.info('Executing SQL check: %s', self.sql)
- records = self.get_db_hook().get_first(self.sql)
-
- self.log.info('Record: %s', records)
- if not records:
- raise AirflowException("The query returned None")
- elif not all([bool(r) for r in records]):
- raise AirflowException("Test failed.\nQuery:\n{query}\nResults:\n{records!s}".format(
- query=self.sql, records=records))
-
- self.log.info("Success.")
-
- def get_db_hook(self):
- return BaseHook.get_hook(conn_id=self.conn_id)
-
-
-def _convert_to_float_if_possible(s):
+ This class is deprecated.
+ Please use `airflow.operators.sql.SQLCheckOperator`.
"""
- A small helper function to convert a string to a numeric value
- if appropriate
- :param s: the string to be converted
- :type s: str
- """
- try:
- ret = float(s)
- except (ValueError, TypeError):
- ret = s
- return ret
+ def __init__(self, *args, **kwargs):
+ warnings.warn(
+ """This class is deprecated.
+ Please use `airflow.operators.sql.SQLCheckOperator`.""",
+ DeprecationWarning, stacklevel=2
+ )
+ super(CheckOperator, self).__init__(*args, **kwargs)
-class ValueCheckOperator(BaseOperator):
+class IntervalCheckOperator(SQLIntervalCheckOperator):
"""
- Performs a simple value check using sql code.
-
- Note that this is an abstract class and get_db_hook
- needs to be defined. Whereas a get_db_hook is hook that gets a
- single record from an external source.
-
- :param sql: the sql to be executed. (templated)
- :type sql: str
+ This class is deprecated.
+ Please use `airflow.operators.sql.SQLIntervalCheckOperator`.
"""
- __mapper_args__ = {
- 'polymorphic_identity': 'ValueCheckOperator'
- }
- template_fields = ('sql', 'pass_value',) # type: Iterable[str]
- template_ext = ('.hql', '.sql',) # type: Iterable[str]
- ui_color = '#fff7e6'
-
- @apply_defaults
- def __init__(
- self,
- sql, # type: str
- pass_value, # type: Any
- tolerance=None, # type: Any
- conn_id=None, # type: Optional[str]
- *args,
- **kwargs
- ):
- super(ValueCheckOperator, self).__init__(*args, **kwargs)
- self.sql = sql
- self.conn_id = conn_id
- self.pass_value = str(pass_value)
- tol = _convert_to_float_if_possible(tolerance)
- self.tol = tol if isinstance(tol, float) else None
- self.has_tolerance = self.tol is not None
-
- def execute(self, context=None):
- self.log.info('Executing SQL check: %s', self.sql)
- records = self.get_db_hook().get_first(self.sql)
-
- if not records:
- raise AirflowException("The query returned None")
-
- pass_value_conv = _convert_to_float_if_possible(self.pass_value)
- is_numeric_value_check = isinstance(pass_value_conv, float)
-
- tolerance_pct_str = str(self.tol * 100) + '%' if self.has_tolerance else None
- error_msg = ("Test failed.\nPass value:{pass_value_conv}\n"
- "Tolerance:{tolerance_pct_str}\n"
- "Query:\n{sql}\nResults:\n{records!s}").format(
- pass_value_conv=pass_value_conv,
- tolerance_pct_str=tolerance_pct_str,
- sql=self.sql,
- records=records
+ def __init__(self, *args, **kwargs):
+ warnings.warn(
+ """This class is deprecated.
+ Please use `airflow.operators.sql.SQLIntervalCheckOperator`.""",
+ DeprecationWarning, stacklevel=2
)
-
- if not is_numeric_value_check:
- tests = self._get_string_matches(records, pass_value_conv)
- elif is_numeric_value_check:
- try:
- numeric_records = self._to_float(records)
- except (ValueError, TypeError):
- raise AirflowException("Converting a result to float failed.\n{}".format(error_msg))
- tests = self._get_numeric_matches(numeric_records, pass_value_conv)
- else:
- tests = []
-
- if not all(tests):
- raise AirflowException(error_msg)
-
- def _to_float(self, records):
- return [float(record) for record in records]
-
- def _get_string_matches(self, records, pass_value_conv):
- return [str(record) == pass_value_conv for record in records]
-
- def _get_numeric_matches(self, numeric_records, numeric_pass_value_conv):
- if self.has_tolerance:
- return [
- numeric_pass_value_conv * (1 - self.tol) <= record <= numeric_pass_value_conv * (1 + self.tol)
- for record in numeric_records
- ]
-
- return [record == numeric_pass_value_conv for record in numeric_records]
-
- def get_db_hook(self):
- return BaseHook.get_hook(conn_id=self.conn_id)
+ super(IntervalCheckOperator, self).__init__(*args, **kwargs)
-class IntervalCheckOperator(BaseOperator):
+class ThresholdCheckOperator(SQLThresholdCheckOperator):
"""
- Checks that the values of metrics given as SQL expressions are within
- a certain tolerance of the ones from days_back before.
-
- Note that this is an abstract class and get_db_hook
- needs to be defined. Whereas a get_db_hook is hook that gets a
- single record from an external source.
-
- :param table: the table name
- :type table: str
- :param days_back: number of days between ds and the ds we want to check
- against. Defaults to 7 days
- :type days_back: int
- :param ratio_formula: which formula to use to compute the ratio between
- the two metrics. Assuming cur is the metric of today and ref is
- the metric to today - days_back.
-
- max_over_min: computes max(cur, ref) / min(cur, ref)
- relative_diff: computes abs(cur-ref) / ref
-
- Default: 'max_over_min'
- :type ratio_formula: str
- :param ignore_zero: whether we should ignore zero metrics
- :type ignore_zero: bool
- :param metrics_threshold: a dictionary of ratios indexed by metrics
- :type metrics_threshold: dict
+ This class is deprecated.
+ Please use `airflow.operators.sql.SQLThresholdCheckOperator`.
"""
- __mapper_args__ = {
- 'polymorphic_identity': 'IntervalCheckOperator'
- }
- template_fields = ('sql1', 'sql2') # type: Iterable[str]
- template_ext = ('.hql', '.sql',) # type: Iterable[str]
- ui_color = '#fff7e6'
-
- ratio_formulas = {
- 'max_over_min': lambda cur, ref: float(max(cur, ref)) / min(cur, ref),
- 'relative_diff': lambda cur, ref: float(abs(cur - ref)) / ref,
- }
-
- @apply_defaults
- def __init__(
- self,
- table, # type: str
- metrics_thresholds, # type: Dict[str, int]
- date_filter_column='ds', # type: Optional[str]
- days_back=-7, # type: SupportsAbs[int]
- ratio_formula='max_over_min', # type: Optional[str]
- ignore_zero=True, # type: Optional[bool]
- conn_id=None, # type: Optional[str]
- *args, **kwargs
- ):
- super(IntervalCheckOperator, self).__init__(*args, **kwargs)
- if ratio_formula not in self.ratio_formulas:
- msg_template = "Invalid diff_method: {diff_method}. " \
- "Supported diff methods are: {diff_methods}"
-
- raise AirflowException(
- msg_template.format(diff_method=ratio_formula,
- diff_methods=self.ratio_formulas)
- )
- self.ratio_formula = ratio_formula
- self.ignore_zero = ignore_zero
- self.table = table
- self.metrics_thresholds = metrics_thresholds
- self.metrics_sorted = sorted(metrics_thresholds.keys())
- self.date_filter_column = date_filter_column
- self.days_back = -abs(days_back)
- self.conn_id = conn_id
- sqlexp = ', '.join(self.metrics_sorted)
- sqlt = "SELECT {sqlexp} FROM {table} WHERE {date_filter_column}=".format(
- sqlexp=sqlexp, table=table, date_filter_column=date_filter_column
+ def __init__(self, *args, **kwargs):
+ warnings.warn(
+ """This class is deprecated.
+ Please use `airflow.operators.sql.SQLThresholdCheckOperator`.""",
+ DeprecationWarning, stacklevel=2
)
-
- self.sql1 = sqlt + "'{{ ds }}'"
- self.sql2 = sqlt + "'{{ macros.ds_add(ds, " + str(self.days_back) + ") }}'"
-
- def execute(self, context=None):
- hook = self.get_db_hook()
- self.log.info('Using ratio formula: %s', self.ratio_formula)
- self.log.info('Executing SQL check: %s', self.sql2)
- row2 = hook.get_first(self.sql2)
- self.log.info('Executing SQL check: %s', self.sql1)
- row1 = hook.get_first(self.sql1)
-
- if not row2:
- raise AirflowException("The query {} returned None".format(self.sql2))
- if not row1:
- raise AirflowException("The query {} returned None".format(self.sql1))
-
- current = dict(zip(self.metrics_sorted, row1))
- reference = dict(zip(self.metrics_sorted, row2))
-
- ratios = {}
- test_results = {}
-
- for m in self.metrics_sorted:
- cur = current[m]
- ref = reference[m]
- threshold = self.metrics_thresholds[m]
- if cur == 0 or ref == 0:
- ratios[m] = None
- test_results[m] = self.ignore_zero
- else:
- ratios[m] = self.ratio_formulas[self.ratio_formula](current[m], reference[m])
- test_results[m] = ratios[m] < threshold
-
- self.log.info(
- (
- "Current metric for %s: %s\n"
- "Past metric for %s: %s\n"
- "Ratio for %s: %s\n"
- "Threshold: %s\n"
- ), m, cur, m, ref, m, ratios[m], threshold)
-
- if not all(test_results.values()):
- failed_tests = [it[0] for it in test_results.items() if not it[1]]
- j = len(failed_tests)
- n = len(self.metrics_sorted)
- self.log.warning("The following %s tests out of %s failed:", j, n)
- for k in failed_tests:
- self.log.warning(
- "'%s' check failed. %s is above %s", k, ratios[k], self.metrics_thresholds[k]
- )
- raise AirflowException("The following tests have failed:\n {0}".format(", ".join(
- sorted(failed_tests))))
-
- self.log.info("All tests have passed")
-
- def get_db_hook(self):
- return BaseHook.get_hook(conn_id=self.conn_id)
+ super(ThresholdCheckOperator, self).__init__(*args, **kwargs)
-class ThresholdCheckOperator(BaseOperator):
+class ValueCheckOperator(SQLValueCheckOperator):
"""
- Performs a value check using sql code against a mininmum threshold
- and a maximum threshold. Thresholds can be in the form of a numeric
- value OR a sql statement that results a numeric.
-
- Note that this is an abstract class and get_db_hook
- needs to be defined. Whereas a get_db_hook is hook that gets a
- single record from an external source.
-
- :param sql: the sql to be executed. (templated)
- :type sql: str
- :param min_threshold: numerical value or min threshold sql to be executed (templated)
- :type min_threshold: numeric or str
- :param max_threshold: numerical value or max threshold sql to be executed (templated)
- :type max_threshold: numeric or str
+ This class is deprecated.
+ Please use `airflow.operators.sql.SQLValueCheckOperator`.
"""
- template_fields = ('sql', 'min_threshold', 'max_threshold') # type: Iterable[str]
- template_ext = ('.hql', '.sql',) # type: Iterable[str]
-
- @apply_defaults
- def __init__(
- self,
- sql, # type: str
- min_threshold, # type: Any
- max_threshold, # type: Any
- conn_id=None, # type: Optional[str]
- *args, **kwargs
- ):
- super(ThresholdCheckOperator, self).__init__(*args, **kwargs)
- self.sql = sql
- self.conn_id = conn_id
- self.min_threshold = _convert_to_float_if_possible(min_threshold)
- self.max_threshold = _convert_to_float_if_possible(max_threshold)
-
- def execute(self, context=None):
- hook = self.get_db_hook()
- result = hook.get_first(self.sql)[0][0]
-
- if isinstance(self.min_threshold, float):
- lower_bound = self.min_threshold
- else:
- lower_bound = hook.get_first(self.min_threshold)[0][0]
-
- if isinstance(self.max_threshold, float):
- upper_bound = self.max_threshold
- else:
- upper_bound = hook.get_first(self.max_threshold)[0][0]
-
- meta_data = {
- "result": result,
- "task_id": self.task_id,
- "min_threshold": lower_bound,
- "max_threshold": upper_bound,
- "within_threshold": lower_bound <= result <= upper_bound
- }
-
- self.push(meta_data)
- if not meta_data["within_threshold"]:
- error_msg = (
- 'Threshold Check: "{task_id}" failed.\n'
- 'DAG: {dag_id}\nTask_id: {task_id}\n'
- 'Check description: {description}\n'
- 'SQL: {sql}\n'
- 'Result: {result} is not within thresholds '
- '{min_threshold} and {max_threshold}'
- ).format(
- task_id=self.task_id, dag_id=self.dag_id,
- description=meta_data.get("description"), sql=self.sql,
- result=round(meta_data.get("result"), 2),
- min_threshold=meta_data.get("min_threshold"),
- max_threshold=meta_data.get("max_threshold")
- )
- raise AirflowException(error_msg)
-
- self.log.info("Test %s Successful.", self.task_id)
-
- def push(self, meta_data):
- """
- Optional: Send data check info and metadata to an external database.
- Default functionality will log metadata.
- """
-
- info = "\n".join(["""{}: {}""".format(key, item) for key, item in meta_data.items()])
- self.log.info("Log from %s:\n%s", self.dag_id, info)
-
- def get_db_hook(self):
- return BaseHook.get_hook(conn_id=self.conn_id)
+ def __init__(self, *args, **kwargs):
+ warnings.warn(
+ """This class is deprecated.
+ Please use `airflow.operators.sql.SQLValueCheckOperator`.""",
+ DeprecationWarning, stacklevel=2
+ )
+ super(ValueCheckOperator, self).__init__(*args, **kwargs)
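The shims above preserve backwards compatibility: constructing any of the old classes now delegates to its SQL-prefixed replacement and records a DeprecationWarning. A small illustrative check of that behaviour (not part of the commit; the connection id is a placeholder):

import warnings

from airflow.operators.check_operator import CheckOperator  # deprecated module

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Instantiation succeeds and behaves like SQLCheckOperator.
    CheckOperator(task_id="legacy_check", conn_id="my_db", sql="SELECT 1")

assert any(issubclass(w.category, DeprecationWarning) for w in caught)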
diff --git a/airflow/operators/check_operator.py b/airflow/operators/sql.py
similarity index 50%
copy from airflow/operators/check_operator.py
copy to airflow/operators/sql.py
index b6d3a18..91ddc1a 100644
--- a/airflow/operators/check_operator.py
+++ b/airflow/operators/sql.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
@@ -16,19 +15,31 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-from builtins import str, zip
-from typing import Optional, Any, Iterable, Dict, SupportsAbs
+from distutils.util import strtobool
+from typing import Iterable
from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
-from airflow.models import BaseOperator
+from airflow.models import BaseOperator, SkipMixin
from airflow.utils.decorators import apply_defaults
-
-class CheckOperator(BaseOperator):
+ALLOWED_CONN_TYPE = {
+ "google_cloud_platform",
+ "jdbc",
+ "mssql",
+ "mysql",
+ "odbc",
+ "oracle",
+ "postgres",
+ "presto",
+ "sqlite",
+ "vertica",
+}
+
+
+class SQLCheckOperator(BaseOperator):
"""
- Performs checks against a db. The ``CheckOperator`` expects
+ Performs checks against a db. The ``SQLCheckOperator`` expects
a sql query that will return a single row. Each value on that
first row is evaluated using python ``bool`` casting. If any of the
values return ``False`` the check is failed and errors out.
@@ -62,36 +73,44 @@ class CheckOperator(BaseOperator):
:type sql: str
"""
- template_fields = ('sql',) # type: Iterable[str]
- template_ext = ('.hql', '.sql',) # type: Iterable[str]
- ui_color = '#fff7e6'
+ template_fields = ("sql",) # type: Iterable[str]
+ template_ext = (
+ ".hql",
+ ".sql",
+ ) # type: Iterable[str]
+ ui_color = "#fff7e6"
@apply_defaults
def __init__(
- self,
- sql, # type: str
- conn_id=None, # type: Optional[str]
- *args,
- **kwargs
+ self, sql, conn_id=None, *args, **kwargs
):
- super(CheckOperator, self).__init__(*args, **kwargs)
+ super(SQLCheckOperator, self).__init__(*args, **kwargs)
self.conn_id = conn_id
self.sql = sql
def execute(self, context=None):
- self.log.info('Executing SQL check: %s', self.sql)
+ self.log.info("Executing SQL check: %s", self.sql)
records = self.get_db_hook().get_first(self.sql)
- self.log.info('Record: %s', records)
+ self.log.info("Record: %s", records)
if not records:
raise AirflowException("The query returned None")
elif not all([bool(r) for r in records]):
- raise AirflowException("Test failed.\nQuery:\n{query}\nResults:\n{records!s}".format(
- query=self.sql, records=records))
+ raise AirflowException(
+ "Test failed.\nQuery:\n{query}\nResults:\n{records!s}".format(
+ query=self.sql, records=records
+ )
+ )
self.log.info("Success.")
def get_db_hook(self):
+ """
+ Get the database hook for the connection.
+
+ :return: the database hook object.
+ :rtype: DbApiHook
+ """
return BaseHook.get_hook(conn_id=self.conn_id)
@@ -110,7 +129,7 @@ def _convert_to_float_if_possible(s):
return ret
-class ValueCheckOperator(BaseOperator):
+class SQLValueCheckOperator(BaseOperator):
"""
Performs a simple value check using sql code.
@@ -122,24 +141,28 @@ class ValueCheckOperator(BaseOperator):
:type sql: str
"""
- __mapper_args__ = {
- 'polymorphic_identity': 'ValueCheckOperator'
- }
- template_fields = ('sql', 'pass_value',) # type: Iterable[str]
- template_ext = ('.hql', '.sql',) # type: Iterable[str]
- ui_color = '#fff7e6'
+ __mapper_args__ = {"polymorphic_identity": "SQLValueCheckOperator"}
+ template_fields = (
+ "sql",
+ "pass_value",
+ ) # type: Iterable[str]
+ template_ext = (
+ ".hql",
+ ".sql",
+ ) # type: Iterable[str]
+ ui_color = "#fff7e6"
@apply_defaults
def __init__(
- self,
- sql, # type: str
- pass_value, # type: Any
- tolerance=None, # type: Any
- conn_id=None, # type: Optional[str]
- *args,
- **kwargs
- ):
- super(ValueCheckOperator, self).__init__(*args, **kwargs)
+ self,
+ sql,
+ pass_value,
+ tolerance=None,
+ conn_id=None,
+ *args,
+ **kwargs
+ ):
+ super(SQLValueCheckOperator, self).__init__(*args, **kwargs)
self.sql = sql
self.conn_id = conn_id
self.pass_value = str(pass_value)
@@ -148,7 +171,7 @@ class ValueCheckOperator(BaseOperator):
self.has_tolerance = self.tol is not None
def execute(self, context=None):
- self.log.info('Executing SQL check: %s', self.sql)
+ self.log.info("Executing SQL check: %s", self.sql)
records = self.get_db_hook().get_first(self.sql)
if not records:
@@ -157,14 +180,16 @@ class ValueCheckOperator(BaseOperator):
pass_value_conv = _convert_to_float_if_possible(self.pass_value)
is_numeric_value_check = isinstance(pass_value_conv, float)
- tolerance_pct_str = str(self.tol * 100) + '%' if self.has_tolerance else None
- error_msg = ("Test failed.\nPass value:{pass_value_conv}\n"
- "Tolerance:{tolerance_pct_str}\n"
- "Query:\n{sql}\nResults:\n{records!s}").format(
+ tolerance_pct_str = str(self.tol * 100) + "%" if self.has_tolerance else None
+ error_msg = (
+ "Test failed.\nPass value:{pass_value_conv}\n"
+ "Tolerance:{tolerance_pct_str}\n"
+ "Query:\n{sql}\nResults:\n{records!s}"
+ ).format(
pass_value_conv=pass_value_conv,
tolerance_pct_str=tolerance_pct_str,
sql=self.sql,
- records=records
+ records=records,
)
if not is_numeric_value_check:
@@ -173,7 +198,9 @@ class ValueCheckOperator(BaseOperator):
try:
numeric_records = self._to_float(records)
except (ValueError, TypeError):
- raise AirflowException("Converting a result to float failed.\n{}".format(error_msg))
+ raise AirflowException(
+ "Converting a result to float failed.\n{}".format(error_msg)
+ )
tests = self._get_numeric_matches(numeric_records, pass_value_conv)
else:
tests = []
@@ -197,10 +224,16 @@ class ValueCheckOperator(BaseOperator):
return [record == numeric_pass_value_conv for record in numeric_records]
def get_db_hook(self):
+ """
+ Get the database hook for the connection.
+
+ :return: the database hook object.
+ :rtype: DbApiHook
+ """
return BaseHook.get_hook(conn_id=self.conn_id)
-class IntervalCheckOperator(BaseOperator):
+class SQLIntervalCheckOperator(BaseOperator):
"""
Checks that the values of metrics given as SQL expressions are within
a certain tolerance of the ones from days_back before.
@@ -229,38 +262,43 @@ class IntervalCheckOperator(BaseOperator):
:type metrics_threshold: dict
"""
- __mapper_args__ = {
- 'polymorphic_identity': 'IntervalCheckOperator'
- }
- template_fields = ('sql1', 'sql2') # type: Iterable[str]
- template_ext = ('.hql', '.sql',) # type: Iterable[str]
- ui_color = '#fff7e6'
+ __mapper_args__ = {"polymorphic_identity": "SQLIntervalCheckOperator"}
+ template_fields = ("sql1", "sql2") # type: Iterable[str]
+ template_ext = (
+ ".hql",
+ ".sql",
+ ) # type: Iterable[str]
+ ui_color = "#fff7e6"
ratio_formulas = {
- 'max_over_min': lambda cur, ref: float(max(cur, ref)) / min(cur, ref),
- 'relative_diff': lambda cur, ref: float(abs(cur - ref)) / ref,
+ "max_over_min": lambda cur, ref: float(max(cur, ref)) / min(cur, ref),
+ "relative_diff": lambda cur, ref: float(abs(cur - ref)) / ref,
}
@apply_defaults
def __init__(
self,
- table, # type: str
- metrics_thresholds, # type: Dict[str, int]
- date_filter_column='ds', # type: Optional[str]
- days_back=-7, # type: SupportsAbs[int]
- ratio_formula='max_over_min', # type: Optional[str]
- ignore_zero=True, # type: Optional[bool]
- conn_id=None, # type: Optional[str]
- *args, **kwargs
+ table,
+ metrics_thresholds,
+ date_filter_column="ds",
+ days_back=-7,
+ ratio_formula="max_over_min",
+ ignore_zero=True,
+ conn_id=None,
+ *args,
+ **kwargs
):
- super(IntervalCheckOperator, self).__init__(*args, **kwargs)
+ super(SQLIntervalCheckOperator, self).__init__(*args, **kwargs)
if ratio_formula not in self.ratio_formulas:
- msg_template = "Invalid diff_method: {diff_method}. " \
- "Supported diff methods are: {diff_methods}"
+ msg_template = (
+ "Invalid diff_method: {diff_method}. "
+ "Supported diff methods are: {diff_methods}"
+ )
raise AirflowException(
- msg_template.format(diff_method=ratio_formula,
- diff_methods=self.ratio_formulas)
+ msg_template.format(
+ diff_method=ratio_formula, diff_methods=self.ratio_formulas
+ )
)
self.ratio_formula = ratio_formula
self.ignore_zero = ignore_zero
@@ -270,7 +308,7 @@ class IntervalCheckOperator(BaseOperator):
self.date_filter_column = date_filter_column
self.days_back = -abs(days_back)
self.conn_id = conn_id
- sqlexp = ', '.join(self.metrics_sorted)
+ sqlexp = ", ".join(self.metrics_sorted)
sqlt = "SELECT {sqlexp} FROM {table} WHERE {date_filter_column}=".format(
sqlexp=sqlexp, table=table, date_filter_column=date_filter_column
)
@@ -280,10 +318,10 @@ class IntervalCheckOperator(BaseOperator):
def execute(self, context=None):
hook = self.get_db_hook()
- self.log.info('Using ratio formula: %s', self.ratio_formula)
- self.log.info('Executing SQL check: %s', self.sql2)
+ self.log.info("Using ratio formula: %s", self.ratio_formula)
+ self.log.info("Executing SQL check: %s", self.sql2)
row2 = hook.get_first(self.sql2)
- self.log.info('Executing SQL check: %s', self.sql1)
+ self.log.info("Executing SQL check: %s", self.sql1)
row1 = hook.get_first(self.sql1)
if not row2:
@@ -297,16 +335,18 @@ class IntervalCheckOperator(BaseOperator):
ratios = {}
test_results = {}
- for m in self.metrics_sorted:
- cur = current[m]
- ref = reference[m]
- threshold = self.metrics_thresholds[m]
+ for metric in self.metrics_sorted:
+ cur = current[metric]
+ ref = reference[metric]
+ threshold = self.metrics_thresholds[metric]
if cur == 0 or ref == 0:
- ratios[m] = None
- test_results[m] = self.ignore_zero
+ ratios[metric] = None
+ test_results[metric] = self.ignore_zero
else:
- ratios[m] = self.ratio_formulas[self.ratio_formula](current[m], reference[m])
- test_results[m] = ratios[m] < threshold
+ ratios[metric] = self.ratio_formulas[self.ratio_formula](
+ current[metric], reference[metric]
+ )
+ test_results[metric] = ratios[metric] < threshold
self.log.info(
(
@@ -314,27 +354,49 @@ class IntervalCheckOperator(BaseOperator):
"Past metric for %s: %s\n"
"Ratio for %s: %s\n"
"Threshold: %s\n"
- ), m, cur, m, ref, m, ratios[m], threshold)
+ ),
+ metric,
+ cur,
+ metric,
+ ref,
+ metric,
+ ratios[metric],
+ threshold,
+ )
if not all(test_results.values()):
failed_tests = [it[0] for it in test_results.items() if not it[1]]
- j = len(failed_tests)
- n = len(self.metrics_sorted)
- self.log.warning("The following %s tests out of %s failed:", j, n)
+ self.log.warning(
+ "The following %s tests out of %s failed:",
+ len(failed_tests),
+ len(self.metrics_sorted),
+ )
for k in failed_tests:
self.log.warning(
- "'%s' check failed. %s is above %s", k, ratios[k], self.metrics_thresholds[k]
+ "'%s' check failed. %s is above %s",
+ k,
+ ratios[k],
+ self.metrics_thresholds[k],
+ )
+ raise AirflowException(
+ "The following tests have failed:\n {0}".format(
+ ", ".join(sorted(failed_tests))
)
- raise AirflowException("The following tests have failed:\n {0}".format(", ".join(
- sorted(failed_tests))))
+ )
self.log.info("All tests have passed")
def get_db_hook(self):
+ """
+ Get the database hook for the connection.
+
+ :return: the database hook object.
+ :rtype: DbApiHook
+ """
return BaseHook.get_hook(conn_id=self.conn_id)
-class ThresholdCheckOperator(BaseOperator):
+class SQLThresholdCheckOperator(BaseOperator):
"""
Performs a value check using sql code against a mininmum threshold
and a maximum threshold. Thresholds can be in the form of a numeric
@@ -352,19 +414,23 @@ class ThresholdCheckOperator(BaseOperator):
:type max_threshold: numeric or str
"""
- template_fields = ('sql', 'min_threshold', 'max_threshold') # type: Iterable[str]
- template_ext = ('.hql', '.sql',) # type: Iterable[str]
+ template_fields = ("sql", "min_threshold", "max_threshold") # type: Iterable[str]
+ template_ext = (
+ ".hql",
+ ".sql",
+ ) # type: Iterable[str]
@apply_defaults
def __init__(
self,
- sql, # type: str
- min_threshold, # type: Any
- max_threshold, # type: Any
- conn_id=None, # type: Optional[str]
- *args, **kwargs
+ sql,
+ min_threshold,
+ max_threshold,
+ conn_id=None,
+ *args,
+ **kwargs
):
- super(ThresholdCheckOperator, self).__init__(*args, **kwargs)
+ super(SQLThresholdCheckOperator, self).__init__(*args, **kwargs)
self.sql = sql
self.conn_id = conn_id
self.min_threshold = _convert_to_float_if_possible(min_threshold)
@@ -389,7 +455,7 @@ class ThresholdCheckOperator(BaseOperator):
"task_id": self.task_id,
"min_threshold": lower_bound,
"max_threshold": upper_bound,
- "within_threshold": lower_bound <= result <= upper_bound
+ "within_threshold": lower_bound <= result <= upper_bound,
}
self.push(meta_data)
@@ -398,16 +464,17 @@ class ThresholdCheckOperator(BaseOperator):
'Threshold Check: "{task_id}" failed.\n'
'DAG: {dag_id}\nTask_id: {task_id}\n'
'Check description: {description}\n'
- 'SQL: {sql}\n'
- 'Result: {result} is not within thresholds '
- '{min_threshold} and {max_threshold}'
- ).format(
- task_id=self.task_id, dag_id=self.dag_id,
- description=meta_data.get("description"), sql=self.sql,
- result=round(meta_data.get("result"), 2),
- min_threshold=meta_data.get("min_threshold"),
- max_threshold=meta_data.get("max_threshold")
- )
+ "SQL: {sql}\n"
+ 'Result: {round} is not within thresholds '
+ '{min} and {max}'
+ .format(task_id=meta_data.get("task_id"),
+ dag_id=self.dag_id,
+ description=meta_data.get("description"),
+ sql=self.sql,
+ round=round(meta_data.get("result"), 2),
+ min=meta_data.get("min_threshold"),
+ max=meta_data.get("max_threshold"),
+ ))
raise AirflowException(error_msg)
self.log.info("Test %s Successful.", self.task_id)
@@ -418,8 +485,149 @@ class ThresholdCheckOperator(BaseOperator):
Default functionality will log metadata.
"""
- info = "\n".join(["""{}: {}""".format(key, item) for key, item in meta_data.items()])
+ info = "\n".join(["{key}: {item}".format(key=key, item=item) for key, item in meta_data.items()])
self.log.info("Log from %s:\n%s", self.dag_id, info)
def get_db_hook(self):
+ """
+ Returns DB hook
+ """
return BaseHook.get_hook(conn_id=self.conn_id)
+
+
+class BranchSQLOperator(BaseOperator, SkipMixin):
+ """
+ Executes sql code in a specific database
+
+ :param sql: the sql code to be executed. (templated)
+ :type sql: Can receive a str representing a sql statement or reference to a template file.
+ Template reference are recognized by str ending in '.sql'.
+ Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1)
+ or string (true/y/yes/1/on/false/n/no/0/off).
+ :param follow_task_ids_if_true: task id or task ids to follow if query return true
+ :type follow_task_ids_if_true: str or list
+ :param follow_task_ids_if_false: task id or task ids to follow if query return true
+ :type follow_task_ids_if_false: str or list
+ :param conn_id: reference to a specific database
+ :type conn_id: str
+ :param database: name of database which overwrite defined one in connection
+ :param parameters: (optional) the parameters to render the SQL query with.
+ :type parameters: mapping or iterable
+ """
+
+ template_fields = ("sql",)
+ template_ext = (".sql",)
+ ui_color = "#a22034"
+ ui_fgcolor = "#F7F7F7"
+
+ @apply_defaults
+ def __init__(
+ self,
+ sql,
+ follow_task_ids_if_true,
+ follow_task_ids_if_false,
+ conn_id="default_conn_id",
+ database=None,
+ parameters=None,
+ *args,
+ **kwargs
+ ):
+ super(BranchSQLOperator, self).__init__(*args, **kwargs)
+ self.conn_id = conn_id
+ self.sql = sql
+ self.parameters = parameters
+ self.follow_task_ids_if_true = follow_task_ids_if_true
+ self.follow_task_ids_if_false = follow_task_ids_if_false
+ self.database = database
+ self._hook = None
+
+ def _get_hook(self):
+ self.log.debug("Get connection for %s", self.conn_id)
+ conn = BaseHook.get_connection(self.conn_id)
+
+ if conn.conn_type not in ALLOWED_CONN_TYPE:
+ raise AirflowException(
+ "The connection type is not supported by BranchSQLOperator.\
+ Supported connection types: {}".format(list(ALLOWED_CONN_TYPE))
+ )
+
+ if not self._hook:
+ self._hook = conn.get_hook()
+ if self.database:
+ self._hook.schema = self.database
+
+ return self._hook
+
+ def execute(self, context):
+ # get supported hook
+ self._hook = self._get_hook()
+
+ if self._hook is None:
+ raise AirflowException(
+ "Failed to establish connection to '%s'" % self.conn_id
+ )
+
+ if self.sql is None:
+ raise AirflowException("Expected 'sql' parameter is missing.")
+
+ if self.follow_task_ids_if_true is None:
+ raise AirflowException(
+ "Expected 'follow_task_ids_if_true' paramter is missing."
+ )
+
+ if self.follow_task_ids_if_false is None:
+ raise AirflowException(
+ "Expected 'follow_task_ids_if_false' parameter is missing."
+ )
+
+ self.log.info(
+ "Executing: %s (with parameters %s) with connection: %s",
+ self.sql,
+ self.parameters,
+ self._hook,
+ )
+ record = self._hook.get_first(self.sql, self.parameters)
+ if not record:
+ raise AirflowException(
+ "No rows returned from sql query. Operator expected True or False return value."
+ )
+
+ if isinstance(record, list):
+ if isinstance(record[0], list):
+ query_result = record[0][0]
+ else:
+ query_result = record[0]
+ elif isinstance(record, tuple):
+ query_result = record[0]
+ else:
+ query_result = record
+
+ self.log.info("Query returns %s, type '%s'", query_result, type(query_result))
+
+ follow_branch = None
+ try:
+ if isinstance(query_result, bool):
+ if query_result:
+ follow_branch = self.follow_task_ids_if_true
+ elif isinstance(query_result, str):
+ # return result is not Boolean, try to convert from String to Boolean
+ if bool(strtobool(query_result)):
+ follow_branch = self.follow_task_ids_if_true
+ elif isinstance(query_result, int):
+ if bool(query_result):
+ follow_branch = self.follow_task_ids_if_true
+ else:
+ raise AirflowException(
+ "Unexpected query return result '%s' type '%s'"
+ % (query_result, type(query_result))
+ )
+
+ if follow_branch is None:
+ follow_branch = self.follow_task_ids_if_false
+ except ValueError:
+ raise AirflowException(
+ "Unexpected query return result '%s' type '%s'"
+ % (query_result, type(query_result))
+ )
+
+ self.skip_all_except(context["ti"], follow_branch)
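Taken together, airflow.operators.sql now carries both the check operators and the branching operator. A minimal usage sketch under assumed names (the connection id, table names, and downstream task ids are placeholders, not part of the commit):

from datetime import datetime

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.sql import BranchSQLOperator, SQLCheckOperator

with DAG("sql_operators_example", start_date=datetime(2020, 6, 1),
         schedule_interval=None) as dag:
    # Fails the task if any value in the first returned row is falsy.
    check = SQLCheckOperator(task_id="check_rows", conn_id="my_db",
                             sql="SELECT COUNT(*) FROM foo")

    # Follows one of two task ids depending on a boolean-like query result.
    branch = BranchSQLOperator(
        task_id="branch_on_flag",
        conn_id="my_db",
        sql="SELECT is_enabled FROM feature_flags WHERE name = 'daily_load'",
        follow_task_ids_if_true="load_data",
        follow_task_ids_if_false="skip_load",
    )

    load_data = DummyOperator(task_id="load_data")
    skip_load = DummyOperator(task_id="skip_load")

    check >> branch >> [load_data, skip_load]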
diff --git a/airflow/operators/sql_branch_operator.py b/airflow/operators/sql_branch_operator.py
index 072c40c..b911e34 100644
--- a/airflow/operators/sql_branch_operator.py
+++ b/airflow/operators/sql_branch_operator.py
@@ -14,160 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""This module is deprecated. Please use `airflow.operators.sql`."""
+import warnings
-from distutils.util import strtobool
+from airflow.operators.sql import BranchSQLOperator
-from airflow.exceptions import AirflowException
-from airflow.hooks.base_hook import BaseHook
-from airflow.models import BaseOperator, SkipMixin
-from airflow.utils.decorators import apply_defaults
-ALLOWED_CONN_TYPE = {
- "google_cloud_platform",
- "jdbc",
- "mssql",
- "mysql",
- "odbc",
- "oracle",
- "postgres",
- "presto",
- "sqlite",
- "vertica",
-}
-
-
-class BranchSqlOperator(BaseOperator, SkipMixin):
+class BranchSqlOperator(BranchSQLOperator):
"""
- Executes sql code in a specific database
-
- :param sql: the sql code to be executed. (templated)
- :type sql: Can receive a str representing a sql statement or reference to a template file.
- Template reference are recognized by str ending in '.sql'.
- Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1)
- or string (true/y/yes/1/on/false/n/no/0/off).
- :param follow_task_ids_if_true: task id or task ids to follow if query return true
- :type follow_task_ids_if_true: str or list
- :param follow_task_ids_if_false: task id or task ids to follow if query return true
- :type follow_task_ids_if_false: str or list
- :param conn_id: reference to a specific database
- :type conn_id: str
- :param database: name of database which overwrite defined one in connection
- :param parameters: (optional) the parameters to render the SQL query with.
- :type parameters: mapping or iterable
+ This class is deprecated.
+ Please use `airflow.operators.sql.BranchSQLOperator`.
"""
- template_fields = ("sql",)
- template_ext = (".sql",)
- ui_color = "#a22034"
- ui_fgcolor = "#F7F7F7"
-
- @apply_defaults
- def __init__(
- self,
- sql,
- follow_task_ids_if_true,
- follow_task_ids_if_false,
- conn_id="default_conn_id",
- database=None,
- parameters=None,
- *args,
- **kwargs):
- super(BranchSqlOperator, self).__init__(*args, **kwargs)
- self.conn_id = conn_id
- self.sql = sql
- self.parameters = parameters
- self.follow_task_ids_if_true = follow_task_ids_if_true
- self.follow_task_ids_if_false = follow_task_ids_if_false
- self.database = database
- self._hook = None
-
- def _get_hook(self):
- self.log.debug("Get connection for %s", self.conn_id)
- conn = BaseHook.get_connection(self.conn_id)
-
- if conn.conn_type not in ALLOWED_CONN_TYPE:
- raise AirflowException(
- "The connection type is not supported by BranchSqlOperator. "
- + "Supported connection types: {}".format(list(ALLOWED_CONN_TYPE))
- )
-
- if not self._hook:
- self._hook = conn.get_hook()
- if self.database:
- self._hook.schema = self.database
-
- return self._hook
-
- def execute(self, context):
- # get supported hook
- self._hook = self._get_hook()
-
- if self._hook is None:
- raise AirflowException(
- "Failed to establish connection to '%s'" % self.conn_id
- )
-
- if self.sql is None:
- raise AirflowException("Expected 'sql' parameter is missing.")
-
- if self.follow_task_ids_if_true is None:
- raise AirflowException(
- "Expected 'follow_task_ids_if_true' paramter is missing."
- )
-
- if self.follow_task_ids_if_false is None:
- raise AirflowException(
- "Expected 'follow_task_ids_if_false' parameter is missing."
- )
-
- self.log.info(
- "Executing: %s (with parameters %s) with connection: %s",
- self.sql,
- self.parameters,
- self._hook,
+ def __init__(self, *args, **kwargs):
+ warnings.warn(
+ """This class is deprecated.
+ Please use `airflow.operators.sql.BranchSQLOperator`.""",
+ DeprecationWarning, stacklevel=2
)
- record = self._hook.get_first(self.sql, self.parameters)
- if not record:
- raise AirflowException(
- "No rows returned from sql query. Operator expected True or False return value."
- )
-
- if isinstance(record, list):
- if isinstance(record[0], list):
- query_result = record[0][0]
- else:
- query_result = record[0]
- elif isinstance(record, tuple):
- query_result = record[0]
- else:
- query_result = record
-
- self.log.info("Query returns %s, type '%s'", query_result, type(query_result))
-
- follow_branch = None
- try:
- if isinstance(query_result, bool):
- if query_result:
- follow_branch = self.follow_task_ids_if_true
- elif isinstance(query_result, str):
- # return result is not Boolean, try to convert from String to Boolean
- if bool(strtobool(query_result)):
- follow_branch = self.follow_task_ids_if_true
- elif isinstance(query_result, int):
- if bool(query_result):
- follow_branch = self.follow_task_ids_if_true
- else:
- raise AirflowException(
- "Unexpected query return result '%s' type '%s'"
- % (query_result, type(query_result))
- )
-
- if follow_branch is None:
- follow_branch = self.follow_task_ids_if_false
- except ValueError:
- raise AirflowException(
- "Unexpected query return result '%s' type '%s'"
- % (query_result, type(query_result))
- )
-
- self.skip_all_except(context["ti"], follow_branch)
+ super(BranchSqlOperator, self).__init__(*args, **kwargs)
diff --git a/docs/operators-and-hooks-ref.rst b/docs/operators-and-hooks-ref.rst
index 55176f8..1fd11c3 100644
--- a/docs/operators-and-hooks-ref.rst
+++ b/docs/operators-and-hooks-ref.rst
@@ -57,10 +57,6 @@ Fundamentals
* - :mod:`airflow.operators.branch_operator`
-
-
- * - :mod:`airflow.operators.check_operator`
- -
-
* - :mod:`airflow.operators.dagrun_operator`
-
@@ -76,7 +72,7 @@ Fundamentals
* - :mod:`airflow.operators.subdag_operator`
-
- * - :mod:`airflow.operators.sql_branch_operator`
+ * - :mod:`airflow.operators.sql`
-
**Sensors:**
@@ -90,9 +86,6 @@ Fundamentals
* - :mod:`airflow.sensors.weekday_sensor`
-
- * - :mod:`airflow.sensors.external_task_sensor`
- - :doc:`How to use <howto/operator/external_task_sensor>`
-
* - :mod:`airflow.sensors.sql_sensor`
-
@@ -470,7 +463,7 @@ These integrations allow you to copy data from/to Amazon Web Services.
* - `Amazon Simple Storage Service (S3) <https://aws.amazon.com/s3/>`__
- `Google Cloud Storage (GCS) <https://cloud.google.com/gcs/>`__
- - :doc:`How to use <howto/operator/gcp/cloud_storage_transfer_service>`
+ -
- :mod:`airflow.contrib.operators.s3_to_gcs_operator`,
:mod:`airflow.gcp.operators.cloud_storage_transfer_service`
@@ -551,7 +544,7 @@ These integrations allow you to perform various operations within the Google Clo
- Sensors
* - `AutoML <https://cloud.google.com/automl/>`__
- - :doc:`How to use <howto/operator/gcp/automl>`
+ -
- :mod:`airflow.gcp.hooks.automl`
- :mod:`airflow.gcp.operators.automl`
-
@@ -563,7 +556,7 @@ These integrations allow you to perform various operations within the Google Clo
- :mod:`airflow.gcp.sensors.bigquery`
* - `BigQuery Data Transfer Service <https://cloud.google.com/bigquery/transfer/>`__
- - :doc:`How to use <howto/operator/gcp/bigquery_dts>`
+ -
- :mod:`airflow.gcp.hooks.bigquery_dts`
- :mod:`airflow.gcp.operators.bigquery_dts`
- :mod:`airflow.gcp.sensors.bigquery_dts`
@@ -611,7 +604,7 @@ These integrations allow you to perform various operations within the Google Clo
-
* - `Cloud Functions <https://cloud.google.com/functions/>`__
- - :doc:`How to use <howto/operator/gcp/functions>`
+ - :doc:`How to use <howto/operator/gcp/function>`
- :mod:`airflow.gcp.hooks.functions`
- :mod:`airflow.gcp.operators.functions`
-
@@ -635,7 +628,7 @@ These integrations allow you to perform various operations within the Google Clo
-
* - `Cloud Memorystore <https://cloud.google.com/memorystore/>`__
- - :doc:`How to use <howto/operator/gcp/cloud_memorystore>`
+ -
- :mod:`airflow.gcp.hooks.cloud_memorystore`
- :mod:`airflow.gcp.operators.cloud_memorystore`
-
@@ -677,7 +670,7 @@ These integrations allow you to perform various operations within the Google Clo
- :mod:`airflow.gcp.sensors.gcs`
* - `Storage Transfer Service <https://cloud.google.com/storage/transfer/>`__
- - :doc:`How to use <howto/operator/gcp/cloud_storage_transfer_service>`
+ -
- :mod:`airflow.gcp.hooks.cloud_storage_transfer_service`
- :mod:`airflow.gcp.operators.cloud_storage_transfer_service`
- :mod:`airflow.gcp.sensors.cloud_storage_transfer_service`
@@ -701,7 +694,7 @@ These integrations allow you to perform various operations within the Google Clo
-
* - `Cloud Video Intelligence <https://cloud.google.com/video_intelligence/>`__
- - :doc:`How to use <howto/operator/gcp/video_intelligence>`
+ - :doc:`How to use <howto/operator/gcp/video>`
- :mod:`airflow.gcp.hooks.video_intelligence`
- :mod:`airflow.gcp.operators.video_intelligence`
-
@@ -741,7 +734,7 @@ These integrations allow you to copy data from/to Google Cloud Platform.
* - `Amazon Simple Storage Service (S3) <https://aws.amazon.com/s3/>`__
- `Google Cloud Storage (GCS) <https://cloud.google.com/gcs/>`__
- - :doc:`How to use <howto/operator/gcp/cloud_storage_transfer_service>`
+ -
- :mod:`airflow.contrib.operators.s3_to_gcs_operator`,
:mod:`airflow.gcp.operators.cloud_storage_transfer_service`
@@ -772,8 +765,7 @@ These integrations allow you to copy data from/to Google Cloud Platform.
* - `Google Cloud Storage (GCS) <https://cloud.google.com/gcs/>`__
- `Google Cloud Storage (GCS) <https://cloud.google.com/gcs/>`__
- - :doc:`How to use <howto/operator/gcp/gcs_to_gcs>`,
- :doc:`How to use <howto/operator/gcp/cloud_storage_transfer_service>`
+ -
- :mod:`airflow.operators.gcs_to_gcs`,
:mod:`airflow.gcp.operators.cloud_storage_transfer_service`
@@ -1037,7 +1029,7 @@ These integrations allow you to perform various operations using various softwar
- :mod:`airflow.contrib.sensors.bash_sensor`
* - `Kubernetes <https://kubernetes.io/>`__
- - :doc:`How to use <howto/operator/kubernetes>`
+ -
-
- :mod:`airflow.contrib.operators.kubernetes_pod_operator`
-
diff --git a/tests/api/common/experimental/test_pool.py b/tests/api/common/experimental/test_pool.py
index 29c7105..79944de 100644
--- a/tests/api/common/experimental/test_pool.py
+++ b/tests/api/common/experimental/test_pool.py
@@ -131,8 +131,8 @@ class TestPool(unittest.TestCase):
name=name)
def test_delete_default_pool_not_allowed(self):
- with self.assertRaisesRegex(AirflowBadRequest,
- "^default_pool cannot be deleted$"):
+ with self.assertRaisesRegexp(AirflowBadRequest,
+ "^default_pool cannot be deleted$"):
pool_api.delete_pool(Pool.DEFAULT_POOL_NAME)
diff --git a/tests/contrib/hooks/test_gcp_api_base_hook.py b/tests/contrib/hooks/test_gcp_api_base_hook.py
index e3f99b5..9fc9924 100644
--- a/tests/contrib/hooks/test_gcp_api_base_hook.py
+++ b/tests/contrib/hooks/test_gcp_api_base_hook.py
@@ -98,7 +98,7 @@ class QuotaRetryTestCase(unittest.TestCase): # ptlint: disable=invalid-name
self.assertEqual(5, custom_fn.counter)
def test_raise_exception_on_non_quota_exception(self):
- with six.assertRaisesRegex(self, Forbidden, "Daily Limit Exceeded"):
+ with six.assertRaisesRegexp(self, Forbidden, "Daily Limit Exceeded"):
message = "POST https://translation.googleapis.com/language/translate/v2: Daily Limit Exceeded"
errors = [
{'message': 'Daily Limit Exceeded', 'domain': 'usageLimits', 'reason': 'dailyLimitExceeded'}
diff --git a/tests/contrib/hooks/test_gcp_cloud_build_hook.py b/tests/contrib/hooks/test_gcp_cloud_build_hook.py
index 81c7ae0..6d5aa16 100644
--- a/tests/contrib/hooks/test_gcp_cloud_build_hook.py
+++ b/tests/contrib/hooks/test_gcp_cloud_build_hook.py
@@ -117,7 +117,7 @@ class TestCloudBuildHookWithPassedProjectId(unittest.TestCase):
execute_mock = mock.Mock(**{"side_effect": [TEST_WAITING_OPERATION, TEST_ERROR_OPERATION]})
service_mock.operations.return_value.get.return_value.execute = execute_mock
- with six.assertRaisesRegex(self, AirflowException, "error"):
+ with six.assertRaisesRegexp(self, AirflowException, "error"):
self.hook.create_build(body={})
@@ -186,7 +186,7 @@ class TestGcpComputeHookWithDefaultProjectIdFromConnection(unittest.TestCase):
execute_mock = mock.Mock(**{"side_effect": [TEST_WAITING_OPERATION, TEST_ERROR_OPERATION]})
service_mock.operations.return_value.get.return_value.execute = execute_mock
- with six.assertRaisesRegex(self, AirflowException, "error"):
+ with six.assertRaisesRegexp(self, AirflowException, "error"):
self.hook.create_build(body={})
diff --git a/tests/contrib/hooks/test_gcp_transfer_hook.py b/tests/contrib/hooks/test_gcp_transfer_hook.py
index ab0cfbf..e78d67b 100644
--- a/tests/contrib/hooks/test_gcp_transfer_hook.py
+++ b/tests/contrib/hooks/test_gcp_transfer_hook.py
@@ -265,7 +265,7 @@ class TestGCPTransferServiceHookWithPassedProjectId(unittest.TestCase):
}
get_conn.return_value.transferOperations.return_value.list_next.return_value = None
- with six.assertRaisesRegex(
+ with six.assertRaisesRegexp(
self, AirflowException, "An unexpected operation status was encountered. Expected: SUCCESS"
):
self.gct_hook.wait_for_transfer_job(
@@ -298,7 +298,7 @@ class TestGCPTransferServiceHookWithPassedProjectId(unittest.TestCase):
def test_operations_contain_expected_statuses_red_path(self, statuses, expected_statuses):
operations = [{NAME: TEST_TRANSFER_OPERATION_NAME, METADATA: {STATUS: status}} for status in statuses]
- with six.assertRaisesRegex(
+ with six.assertRaisesRegexp(
self,
AirflowException,
"An unexpected operation status was encountered. Expected: {}".format(
diff --git a/tests/contrib/operators/test_gcp_cloud_build_operator.py b/tests/contrib/operators/test_gcp_cloud_build_operator.py
index 2136757..4a7ca73 100644
--- a/tests/contrib/operators/test_gcp_cloud_build_operator.py
+++ b/tests/contrib/operators/test_gcp_cloud_build_operator.py
@@ -42,7 +42,7 @@ TEST_PROJECT_ID = "example-id"
class BuildProcessorTestCase(TestCase):
def test_verify_source(self):
- with six.assertRaisesRegex(self, AirflowException, "The source could not be determined."):
+ with six.assertRaisesRegexp(self, AirflowException, "The source could not be determined."):
BuildProcessor(body={"source": {"storageSource": {}, "repoSource": {}}}).process_body()
@parameterized.expand(
@@ -77,7 +77,7 @@ class BuildProcessorTestCase(TestCase):
)
def test_convert_repo_url_to_storage_dict_invalid(self, url):
body = {"source": {"repoSource": url}}
- with six.assertRaisesRegex(self, AirflowException, "Invalid URL."):
+ with six.assertRaisesRegexp(self, AirflowException, "Invalid URL."):
BuildProcessor(body=body).process_body()
@parameterized.expand(
@@ -102,7 +102,7 @@ class BuildProcessorTestCase(TestCase):
)
def test_convert_storage_url_to_dict_invalid(self, url):
body = {"source": {"storageSource": url}}
- with six.assertRaisesRegex(self, AirflowException, "Invalid URL."):
+ with six.assertRaisesRegexp(self, AirflowException, "Invalid URL."):
BuildProcessor(body=body).process_body()
@parameterized.expand([("storageSource",), ("repoSource",)])
@@ -128,7 +128,7 @@ class GcpCloudBuildCreateBuildOperatorTestCase(TestCase):
@parameterized.expand([({},), (None,)])
def test_missing_input(self, body):
- with six.assertRaisesRegex(self, AirflowException, "The required parameter 'body' is missing"):
+ with six.assertRaisesRegexp(self, AirflowException, "The required parameter 'body' is missing"):
CloudBuildCreateBuildOperator(body=body, project_id=TEST_PROJECT_ID, task_id="task-id")
@mock.patch( # type: ignore
diff --git a/tests/contrib/operators/test_gcs_to_gdrive.py b/tests/contrib/operators/test_gcs_to_gdrive.py
index 4f49055..03a4e66 100644
--- a/tests/contrib/operators/test_gcs_to_gdrive.py
+++ b/tests/contrib/operators/test_gcs_to_gdrive.py
@@ -147,5 +147,5 @@ class TestGcsToGDriveOperator(unittest.TestCase):
task = GcsToGDriveOperator(
task_id="move_files", source_bucket="data", source_object="sales/*/*.avro", move_object=True
)
- with six.assertRaisesRegex(self, AirflowException, "Only one wildcard"):
+ with six.assertRaisesRegexp(self, AirflowException, "Only one wildcard"):
task.execute(mock.MagicMock())
diff --git a/tests/contrib/operators/test_sftp_operator.py b/tests/contrib/operators/test_sftp_operator.py
index 24db36e..d597d8d 100644
--- a/tests/contrib/operators/test_sftp_operator.py
+++ b/tests/contrib/operators/test_sftp_operator.py
@@ -363,9 +363,9 @@ class SFTPOperatorTest(unittest.TestCase):
# Exception should be raised if neither ssh_hook nor ssh_conn_id is provided
if six.PY2:
- self.assertRaisesRegex = self.assertRaisesRegexp
- with self.assertRaisesRegex(AirflowException,
- "Cannot operate without ssh_hook or ssh_conn_id."):
+ self.assertRaisesRegexp = self.assertRaisesRegexp
+ with self.assertRaisesRegexp(AirflowException,
+ "Cannot operate without ssh_hook or ssh_conn_id."):
task_0 = SFTPOperator(
task_id="test_sftp",
local_filepath=self.test_local_filepath,
diff --git a/tests/contrib/operators/test_ssh_operator.py b/tests/contrib/operators/test_ssh_operator.py
index f2294ba..aa113ec 100644
--- a/tests/contrib/operators/test_ssh_operator.py
+++ b/tests/contrib/operators/test_ssh_operator.py
@@ -153,9 +153,9 @@ class SSHOperatorTest(TestCase):
# Exception should be raised if neither ssh_hook nor ssh_conn_id is provided
if six.PY2:
- self.assertRaisesRegex = self.assertRaisesRegexp
- with self.assertRaisesRegex(AirflowException,
- "Cannot operate without ssh_hook or ssh_conn_id."):
+ self.assertRaisesRegexp = self.assertRaisesRegexp
+ with self.assertRaisesRegexp(AirflowException,
+ "Cannot operate without ssh_hook or ssh_conn_id."):
task_0 = SSHOperator(task_id="test", command=COMMAND,
timeout=TIMEOUT, dag=self.dag)
task_0.execute(None)
diff --git a/tests/contrib/secrets/test_hashicorp_vault.py b/tests/contrib/secrets/test_hashicorp_vault.py
index 0d52c3a..1887db3 100644
--- a/tests/contrib/secrets/test_hashicorp_vault.py
+++ b/tests/contrib/secrets/test_hashicorp_vault.py
@@ -217,7 +217,7 @@ class TestVaultSecrets(TestCase):
"token": "test_wrong_token"
}
- with six.assertRaisesRegex(self, VaultError, "Vault Authentication Error!"):
+ with six.assertRaisesRegexp(self, VaultError, "Vault Authentication Error!"):
VaultBackend(**kwargs).get_connections(conn_id='test')
@mock.patch("airflow.contrib.secrets.hashicorp_vault.hvac")
@@ -232,5 +232,5 @@ class TestVaultSecrets(TestCase):
"url": "http://127.0.0.1:8200",
}
- with six.assertRaisesRegex(self, VaultError, "token cannot be None for auth_type='token'"):
+ with six.assertRaisesRegexp(self, VaultError, "token cannot be None for auth_type='token'"):
VaultBackend(**kwargs).get_connections(conn_id='test')
diff --git a/tests/contrib/utils/test_gcp_credentials_provider.py b/tests/contrib/utils/test_gcp_credentials_provider.py
index 3478a42..40e180b 100644
--- a/tests/contrib/utils/test_gcp_credentials_provider.py
+++ b/tests/contrib/utils/test_gcp_credentials_provider.py
@@ -97,7 +97,7 @@ class TestGetGcpCredentialsAndProjectId(unittest.TestCase):
def test_get_credentials_and_project_id_with_mutually_exclusive_configuration(
self,
):
- with six.assertRaisesRegex(self, AirflowException, re.escape(
+ with six.assertRaisesRegexp(self, AirflowException, re.escape(
'The `keyfile_dict` and `key_path` fields are mutually exclusive.'
)):
get_credentials_and_project_id(key_path='KEY.json', keyfile_dict={'private_key': 'PRIVATE_KEY'})
diff --git a/tests/operators/test_check_operator.py b/tests/operators/test_check_operator.py
deleted file mode 100644
index 22523a4..0000000
--- a/tests/operators/test_check_operator.py
+++ /dev/null
@@ -1,327 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import six
-import unittest
-from datetime import datetime
-
-from airflow.exceptions import AirflowException
-from airflow.models import DAG
-from airflow.operators.check_operator import (
- CheckOperator, IntervalCheckOperator, ThresholdCheckOperator, ValueCheckOperator,
-)
-from tests.compat import mock
-
-
-class TestCheckOperator(unittest.TestCase):
-
- @mock.patch.object(CheckOperator, 'get_db_hook')
- def test_execute_no_records(self, mock_get_db_hook):
- mock_get_db_hook.return_value.get_first.return_value = []
-
- with self.assertRaises(AirflowException):
- CheckOperator(sql='sql').execute()
-
- @mock.patch.object(CheckOperator, 'get_db_hook')
- def test_execute_not_all_records_are_true(self, mock_get_db_hook):
- mock_get_db_hook.return_value.get_first.return_value = ["data", ""]
-
- with self.assertRaises(AirflowException):
- CheckOperator(sql='sql').execute()
-
-
-class TestValueCheckOperator(unittest.TestCase):
-
- def setUp(self):
- self.task_id = 'test_task'
- self.conn_id = 'default_conn'
-
- def _construct_operator(self, sql, pass_value, tolerance=None):
- dag = DAG('test_dag', start_date=datetime(2017, 1, 1))
-
- return ValueCheckOperator(
- dag=dag,
- task_id=self.task_id,
- conn_id=self.conn_id,
- sql=sql,
- pass_value=pass_value,
- tolerance=tolerance)
-
- def test_pass_value_template_string(self):
- pass_value_str = "2018-03-22"
- operator = self._construct_operator('select date from tab1;', "{{ ds }}")
-
- operator.render_template_fields({'ds': pass_value_str})
-
- self.assertEqual(operator.task_id, self.task_id)
- self.assertEqual(operator.pass_value, pass_value_str)
-
- def test_pass_value_template_string_float(self):
- pass_value_float = 4.0
- operator = self._construct_operator('select date from tab1;', pass_value_float)
-
- operator.render_template_fields({})
-
- self.assertEqual(operator.task_id, self.task_id)
- self.assertEqual(operator.pass_value, str(pass_value_float))
-
- @mock.patch.object(ValueCheckOperator, 'get_db_hook')
- def test_execute_pass(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.return_value = [10]
- mock_get_db_hook.return_value = mock_hook
- sql = 'select value from tab1 limit 1;'
- operator = self._construct_operator(sql, 5, 1)
-
- operator.execute(None)
-
- mock_hook.get_first.assert_called_with(sql)
-
- @mock.patch.object(ValueCheckOperator, 'get_db_hook')
- def test_execute_fail(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.return_value = [11]
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator('select value from tab1 limit 1;', 5, 1)
-
- with self.assertRaisesRegexp(AirflowException, 'Tolerance:100.0%'):
- operator.execute()
-
-
-class IntervalCheckOperatorTest(unittest.TestCase):
-
- def _construct_operator(self, table, metric_thresholds,
- ratio_formula, ignore_zero):
- return IntervalCheckOperator(
- task_id='test_task',
- table=table,
- metrics_thresholds=metric_thresholds,
- ratio_formula=ratio_formula,
- ignore_zero=ignore_zero,
- )
-
- def test_invalid_ratio_formula(self):
- with self.assertRaisesRegexp(AirflowException, 'Invalid diff_method'):
- self._construct_operator(
- table='test_table',
- metric_thresholds={
- 'f1': 1,
- },
- ratio_formula='abs',
- ignore_zero=False,
- )
-
- @mock.patch.object(IntervalCheckOperator, 'get_db_hook')
- def test_execute_not_ignore_zero(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.return_value = [0]
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator(
- table='test_table',
- metric_thresholds={
- 'f1': 1,
- },
- ratio_formula='max_over_min',
- ignore_zero=False,
- )
-
- with self.assertRaises(AirflowException):
- operator.execute()
-
- @mock.patch.object(IntervalCheckOperator, 'get_db_hook')
- def test_execute_ignore_zero(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.return_value = [0]
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator(
- table='test_table',
- metric_thresholds={
- 'f1': 1,
- },
- ratio_formula='max_over_min',
- ignore_zero=True,
- )
-
- operator.execute()
-
- @mock.patch.object(IntervalCheckOperator, 'get_db_hook')
- def test_execute_min_max(self, mock_get_db_hook):
- mock_hook = mock.Mock()
-
- def returned_row():
- rows = [
- [2, 2, 2, 2], # reference
- [1, 1, 1, 1], # current
- ]
-
- for r in rows:
- yield r
-
- mock_hook.get_first.side_effect = returned_row()
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator(
- table='test_table',
- metric_thresholds={
- 'f0': 1.0,
- 'f1': 1.5,
- 'f2': 2.0,
- 'f3': 2.5,
- },
- ratio_formula='max_over_min',
- ignore_zero=True,
- )
-
- with self.assertRaisesRegexp(AirflowException, "f0, f1, f2"):
- operator.execute()
-
- @mock.patch.object(IntervalCheckOperator, 'get_db_hook')
- def test_execute_diff(self, mock_get_db_hook):
- mock_hook = mock.Mock()
-
- def returned_row():
- rows = [
- [3, 3, 3, 3], # reference
- [1, 1, 1, 1], # current
- ]
-
- for r in rows:
- yield r
-
- mock_hook.get_first.side_effect = returned_row()
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator(
- table='test_table',
- metric_thresholds={
- 'f0': 0.5,
- 'f1': 0.6,
- 'f2': 0.7,
- 'f3': 0.8,
- },
- ratio_formula='relative_diff',
- ignore_zero=True,
- )
-
- with self.assertRaisesRegexp(AirflowException, "f0, f1"):
- operator.execute()
-
-
-class TestThresholdCheckOperator(unittest.TestCase):
-
- def _construct_operator(self, sql, min_threshold, max_threshold):
- dag = DAG('test_dag', start_date=datetime(2017, 1, 1))
-
- return ThresholdCheckOperator(
- task_id='test_task',
- sql=sql,
- min_threshold=min_threshold,
- max_threshold=max_threshold,
- dag=dag
- )
-
- @mock.patch.object(ThresholdCheckOperator, 'get_db_hook')
- def test_pass_min_value_max_value(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.return_value = [(10,)]
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator(
- 'Select avg(val) from table1 limit 1',
- 1,
- 100
- )
-
- operator.execute()
-
- @mock.patch.object(ThresholdCheckOperator, 'get_db_hook')
- def test_fail_min_value_max_value(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.return_value = [(10,)]
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator(
- 'Select avg(val) from table1 limit 1',
- 20,
- 100
- )
-
- with six.assertRaisesRegex(self, AirflowException, '10.*20.0.*100.0'):
- operator.execute()
-
- @mock.patch.object(ThresholdCheckOperator, 'get_db_hook')
- def test_pass_min_sql_max_sql(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)]
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator(
- 'Select 10',
- 'Select 1',
- 'Select 100'
- )
-
- operator.execute()
-
- @mock.patch.object(ThresholdCheckOperator, 'get_db_hook')
- def test_fail_min_sql_max_sql(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)]
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator(
- 'Select 10',
- 'Select 20',
- 'Select 100'
- )
-
- with six.assertRaisesRegex(self, AirflowException, '10.*20.*100'):
- operator.execute()
-
- @mock.patch.object(ThresholdCheckOperator, 'get_db_hook')
- def test_pass_min_value_max_sql(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)]
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator(
- 'Select 75',
- 45,
- 'Select 100'
- )
-
- operator.execute()
-
- @mock.patch.object(ThresholdCheckOperator, 'get_db_hook')
- def test_fail_min_sql_max_value(self, mock_get_db_hook):
- mock_hook = mock.Mock()
- mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)]
- mock_get_db_hook.return_value = mock_hook
-
- operator = self._construct_operator(
- 'Select 155',
- 'Select 45',
- 100
- )
-
- with six.assertRaisesRegex(self, AirflowException, '155.*45.*100.0'):
- operator.execute()
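The deleted tests above, and their replacements in tests/operators/test_sql.py below, share one mocking pattern: patch get_db_hook on the operator class, then drive the hook's get_first either with a fixed return_value or, for the interval checks, with an iterable side_effect so the first call returns the reference row and the second the current row. A self-contained sketch of that side_effect behaviour (assuming Python 3's unittest.mock and a hypothetical helper, not an Airflow operator):

    from unittest import mock


    def fetch_reference_and_current(hook):
        # Call get_first twice, the way the interval check does.
        reference = hook.get_first("SELECT metrics FROM reference_day")
        current = hook.get_first("SELECT metrics FROM current_day")
        return reference, current


    mock_hook = mock.Mock()
    # An iterable side_effect hands out one item per call, so the mocked
    # hook returns the reference row first and the current row second.
    mock_hook.get_first.side_effect = [
        [2, 2, 2, 2],  # reference
        [1, 1, 1, 1],  # current
    ]

    assert fetch_reference_and_current(mock_hook) == ([2, 2, 2, 2], [1, 1, 1, 1])

This is also why the rewritten returned_row helpers in test_sql.py can simply return the list instead of yielding rows: mock only needs an iterable to produce per-call results.
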
diff --git a/tests/operators/test_sql_branch_operator.py b/tests/operators/test_sql.py
similarity index 57%
rename from tests/operators/test_sql_branch_operator.py
rename to tests/operators/test_sql.py
index 6510609..ca88883 100644
--- a/tests/operators/test_sql_branch_operator.py
+++ b/tests/operators/test_sql.py
@@ -24,8 +24,11 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.models import DAG, DagRun, TaskInstance as TI
+from airflow.operators.check_operator import (
+ CheckOperator, IntervalCheckOperator, ThresholdCheckOperator, ValueCheckOperator,
+)
from airflow.operators.dummy_operator import DummyOperator
-from airflow.operators.sql_branch_operator import BranchSqlOperator
+from airflow.operators.sql import BranchSQLOperator
from airflow.utils import timezone
from airflow.utils.db import create_session
from airflow.utils.state import State
@@ -60,6 +63,266 @@ SUPPORTED_FALSE_VALUES = [
]
+class TestCheckOperator(unittest.TestCase):
+ @mock.patch.object(CheckOperator, "get_db_hook")
+ def test_execute_no_records(self, mock_get_db_hook):
+ mock_get_db_hook.return_value.get_first.return_value = []
+
+ with self.assertRaises(AirflowException):
+ CheckOperator(sql="sql").execute()
+
+ @mock.patch.object(CheckOperator, "get_db_hook")
+ def test_execute_not_all_records_are_true(self, mock_get_db_hook):
+ mock_get_db_hook.return_value.get_first.return_value = ["data", ""]
+
+ with self.assertRaises(AirflowException):
+ CheckOperator(sql="sql").execute()
+
+
+class TestValueCheckOperator(unittest.TestCase):
+ def setUp(self):
+ self.task_id = "test_task"
+ self.conn_id = "default_conn"
+
+ def _construct_operator(self, sql, pass_value, tolerance=None):
+ dag = DAG("test_dag", start_date=datetime.datetime(2017, 1, 1))
+
+ return ValueCheckOperator(
+ dag=dag,
+ task_id=self.task_id,
+ conn_id=self.conn_id,
+ sql=sql,
+ pass_value=pass_value,
+ tolerance=tolerance,
+ )
+
+ def test_pass_value_template_string(self):
+ pass_value_str = "2018-03-22"
+ operator = self._construct_operator(
+ "select date from tab1;", "{{ ds }}")
+
+ operator.render_template_fields({"ds": pass_value_str})
+
+ self.assertEqual(operator.task_id, self.task_id)
+ self.assertEqual(operator.pass_value, pass_value_str)
+
+ def test_pass_value_template_string_float(self):
+ pass_value_float = 4.0
+ operator = self._construct_operator(
+ "select date from tab1;", pass_value_float)
+
+ operator.render_template_fields({})
+
+ self.assertEqual(operator.task_id, self.task_id)
+ self.assertEqual(operator.pass_value, str(pass_value_float))
+
+ @mock.patch.object(ValueCheckOperator, "get_db_hook")
+ def test_execute_pass(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.return_value = [10]
+ mock_get_db_hook.return_value = mock_hook
+ sql = "select value from tab1 limit 1;"
+ operator = self._construct_operator(sql, 5, 1)
+
+ operator.execute(None)
+
+ mock_hook.get_first.assert_called_once_with(sql)
+
+ @mock.patch.object(ValueCheckOperator, "get_db_hook")
+ def test_execute_fail(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.return_value = [11]
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator(
+ "select value from tab1 limit 1;", 5, 1)
+
+ with self.assertRaisesRegexp(AirflowException, "Tolerance:100.0%"):
+ operator.execute()
+
+
+class TestIntervalCheckOperator(unittest.TestCase):
+ def _construct_operator(self, table, metric_thresholds, ratio_formula, ignore_zero):
+ return IntervalCheckOperator(
+ task_id="test_task",
+ table=table,
+ metrics_thresholds=metric_thresholds,
+ ratio_formula=ratio_formula,
+ ignore_zero=ignore_zero,
+ )
+
+ def test_invalid_ratio_formula(self):
+ with self.assertRaisesRegexp(AirflowException, "Invalid diff_method"):
+ self._construct_operator(
+ table="test_table",
+ metric_thresholds={"f1": 1, },
+ ratio_formula="abs",
+ ignore_zero=False,
+ )
+
+ @mock.patch.object(IntervalCheckOperator, "get_db_hook")
+ def test_execute_not_ignore_zero(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.return_value = [0]
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator(
+ table="test_table",
+ metric_thresholds={"f1": 1, },
+ ratio_formula="max_over_min",
+ ignore_zero=False,
+ )
+
+ with self.assertRaises(AirflowException):
+ operator.execute()
+
+ @mock.patch.object(IntervalCheckOperator, "get_db_hook")
+ def test_execute_ignore_zero(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.return_value = [0]
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator(
+ table="test_table",
+ metric_thresholds={"f1": 1, },
+ ratio_formula="max_over_min",
+ ignore_zero=True,
+ )
+
+ operator.execute()
+
+ @mock.patch.object(IntervalCheckOperator, "get_db_hook")
+ def test_execute_min_max(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+
+ def returned_row():
+ rows = [
+ [2, 2, 2, 2], # reference
+ [1, 1, 1, 1], # current
+ ]
+ return rows
+
+ mock_hook.get_first.side_effect = returned_row()
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator(
+ table="test_table",
+ metric_thresholds={"f0": 1.0, "f1": 1.5, "f2": 2.0, "f3": 2.5, },
+ ratio_formula="max_over_min",
+ ignore_zero=True,
+ )
+
+ with self.assertRaisesRegexp(AirflowException, "f0, f1, f2"):
+ operator.execute()
+
+ @mock.patch.object(IntervalCheckOperator, "get_db_hook")
+ def test_execute_diff(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+
+ def returned_row():
+ rows = [
+ [3, 3, 3, 3], # reference
+ [1, 1, 1, 1], # current
+ ]
+
+ return rows
+
+ mock_hook.get_first.side_effect = returned_row()
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator(
+ table="test_table",
+ metric_thresholds={"f0": 0.5, "f1": 0.6, "f2": 0.7, "f3": 0.8, },
+ ratio_formula="relative_diff",
+ ignore_zero=True,
+ )
+
+ with self.assertRaisesRegexp(AirflowException, "f0, f1"):
+ operator.execute()
+
+
+class TestThresholdCheckOperator(unittest.TestCase):
+ def _construct_operator(self, sql, min_threshold, max_threshold):
+ dag = DAG("test_dag", start_date=datetime.datetime(2017, 1, 1))
+
+ return ThresholdCheckOperator(
+ task_id="test_task",
+ sql=sql,
+ min_threshold=min_threshold,
+ max_threshold=max_threshold,
+ dag=dag,
+ )
+
+ @mock.patch.object(ThresholdCheckOperator, "get_db_hook")
+ def test_pass_min_value_max_value(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.return_value = [(10,)]
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator(
+ "Select avg(val) from table1 limit 1", 1, 100
+ )
+
+ operator.execute()
+
+ @mock.patch.object(ThresholdCheckOperator, "get_db_hook")
+ def test_fail_min_value_max_value(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.return_value = [(10,)]
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator(
+ "Select avg(val) from table1 limit 1", 20, 100
+ )
+
+ with self.assertRaisesRegexp(AirflowException, "10.*20.0.*100.0"):
+ operator.execute()
+
+ @mock.patch.object(ThresholdCheckOperator, "get_db_hook")
+ def test_pass_min_sql_max_sql(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)]
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator(
+ "Select 10", "Select 1", "Select 100")
+
+ operator.execute()
+
+ @mock.patch.object(ThresholdCheckOperator, "get_db_hook")
+ def test_fail_min_sql_max_sql(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)]
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator(
+ "Select 10", "Select 20", "Select 100")
+
+ with self.assertRaisesRegexp(AirflowException, "10.*20.*100"):
+ operator.execute()
+
+ @mock.patch.object(ThresholdCheckOperator, "get_db_hook")
+ def test_pass_min_value_max_sql(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)]
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator("Select 75", 45, "Select 100")
+
+ operator.execute()
+
+ @mock.patch.object(ThresholdCheckOperator, "get_db_hook")
+ def test_fail_min_sql_max_value(self, mock_get_db_hook):
+ mock_hook = mock.Mock()
+ mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)]
+ mock_get_db_hook.return_value = mock_hook
+
+ operator = self._construct_operator("Select 155", "Select 45", 100)
+
+ with self.assertRaisesRegexp(AirflowException, "155.*45.*100.0"):
+ operator.execute()
+
+
class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
"""
Test for SQL Branch Operator
@@ -92,8 +355,8 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
session.query(TI).delete()
def test_unsupported_conn_type(self):
- """ Check if BranchSqlOperator throws an exception for unsupported connection type """
- op = BranchSqlOperator(
+ """ Check if BranchSQLOperator throws an exception for unsupported connection type """
+ op = BranchSQLOperator(
task_id="make_choice",
conn_id="redis_default",
sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
@@ -103,11 +366,12 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
)
with self.assertRaises(AirflowException):
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ op.run(start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_invalid_conn(self):
- """ Check if BranchSqlOperator throws an exception for invalid connection """
- op = BranchSqlOperator(
+ """ Check if BranchSQLOperator throws an exception for invalid connection """
+ op = BranchSQLOperator(
task_id="make_choice",
conn_id="invalid_connection",
sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
@@ -117,11 +381,12 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
)
with self.assertRaises(AirflowException):
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ op.run(start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_invalid_follow_task_true(self):
- """ Check if BranchSqlOperator throws an exception for invalid connection """
- op = BranchSqlOperator(
+ """ Check if BranchSQLOperator throws an exception for invalid connection """
+ op = BranchSQLOperator(
task_id="make_choice",
conn_id="invalid_connection",
sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
@@ -131,11 +396,12 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
)
with self.assertRaises(AirflowException):
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ op.run(start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_invalid_follow_task_false(self):
- """ Check if BranchSqlOperator throws an exception for invalid connection """
- op = BranchSqlOperator(
+ """ Check if BranchSQLOperator throws an exception for invalid connection """
+ op = BranchSQLOperator(
task_id="make_choice",
conn_id="invalid_connection",
sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
@@ -145,12 +411,13 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
)
with self.assertRaises(AirflowException):
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ op.run(start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE, ignore_ti_state=True)
@pytest.mark.backend("mysql")
def test_sql_branch_operator_mysql(self):
- """ Check if BranchSqlOperator works with backend """
- branch_op = BranchSqlOperator(
+ """ Check if BranchSQLOperator works with backend """
+ branch_op = BranchSQLOperator(
task_id="make_choice",
conn_id="mysql_default",
sql="SELECT 1",
@@ -164,8 +431,8 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
@pytest.mark.backend("postgres")
def test_sql_branch_operator_postgres(self):
- """ Check if BranchSqlOperator works with backend """
- branch_op = BranchSqlOperator(
+ """ Check if BranchSQLOperator works with backend """
+ branch_op = BranchSQLOperator(
task_id="make_choice",
conn_id="postgres_default",
sql="SELECT 1",
@@ -177,10 +444,10 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True
)
- @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+ @mock.patch("airflow.operators.sql.BaseHook")
def test_branch_single_value_with_dag_run(self, mock_hook):
- """ Check BranchSqlOperator branch operation """
- branch_op = BranchSqlOperator(
+ """ Check BranchSQLOperator branch operation """
+ branch_op = BranchSQLOperator(
task_id="make_choice",
conn_id="mysql_default",
sql="SELECT 1",
@@ -220,10 +487,10 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
else:
raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
- @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+ @mock.patch("airflow.operators.sql.BaseHook")
def test_branch_true_with_dag_run(self, mock_hook):
- """ Check BranchSqlOperator branch operation """
- branch_op = BranchSqlOperator(
+ """ Check BranchSQLOperator branch operation """
+ branch_op = BranchSQLOperator(
task_id="make_choice",
conn_id="mysql_default",
sql="SELECT 1",
@@ -264,10 +531,10 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
else:
raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
- @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+ @mock.patch("airflow.operators.sql.BaseHook")
def test_branch_false_with_dag_run(self, mock_hook):
- """ Check BranchSqlOperator branch operation """
- branch_op = BranchSqlOperator(
+ """ Check BranchSQLOperator branch operation """
+ branch_op = BranchSQLOperator(
task_id="make_choice",
conn_id="mysql_default",
sql="SELECT 1",
@@ -308,10 +575,10 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
else:
raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
- @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+ @mock.patch("airflow.operators.sql.BaseHook")
def test_branch_list_with_dag_run(self, mock_hook):
- """ Checks if the BranchSqlOperator supports branching off to a list of tasks."""
- branch_op = BranchSqlOperator(
+ """ Checks if the BranchSQLOperator supports branching off to a list of tasks."""
+ branch_op = BranchSQLOperator(
task_id="make_choice",
conn_id="mysql_default",
sql="SELECT 1",
@@ -354,10 +621,10 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
else:
raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
- @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+ @mock.patch("airflow.operators.sql.BaseHook")
def test_invalid_query_result_with_dag_run(self, mock_hook):
- """ Check BranchSqlOperator branch operation """
- branch_op = BranchSqlOperator(
+ """ Check BranchSQLOperator branch operation """
+ branch_op = BranchSQLOperator(
task_id="make_choice",
conn_id="mysql_default",
sql="SELECT 1",
@@ -387,10 +654,10 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
with self.assertRaises(AirflowException):
branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+ @mock.patch("airflow.operators.sql.BaseHook")
def test_with_skip_in_branch_downstream_dependencies(self, mock_hook):
""" Test SQL Branch with skipping all downstream dependencies """
- branch_op = BranchSqlOperator(
+ branch_op = BranchSQLOperator(
task_id="make_choice",
conn_id="mysql_default",
sql="SELECT 1",
@@ -431,10 +698,10 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
else:
raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
- @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+ @mock.patch("airflow.operators.sql.BaseHook")
def test_with_skip_in_branch_downstream_dependencies2(self, mock_hook):
""" Test skipping downstream dependency for false condition"""
- branch_op = BranchSqlOperator(
+ branch_op = BranchSQLOperator(
task_id="make_choice",
conn_id="mysql_default",
sql="SELECT 1",
diff --git a/tests/www_rbac/test_validators.py b/tests/www_rbac/test_validators.py
index 4a543ff..415c53f 100644
--- a/tests/www_rbac/test_validators.py
+++ b/tests/www_rbac/test_validators.py
@@ -119,7 +119,7 @@ class TestValidJson(unittest.TestCase):
def test_validation_raises_default_message(self):
self.form_field_mock.data = '2017-05-04'
- six.assertRaisesRegex(
+ six.assertRaisesRegexp(
self,
validators.ValidationError,
"JSON Validation Error:.*",
@@ -129,7 +129,7 @@ class TestValidJson(unittest.TestCase):
def test_validation_raises_custom_message(self):
self.form_field_mock.data = '2017-05-04'
- six.assertRaisesRegex(
+ six.assertRaisesRegexp(
self,
validators.ValidationError,
"Invalid JSON",