You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/29 11:25:18 UTC

[airflow] branch main updated: Move all "old" SQL operators to common.sql providers (#25350)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new acab8f52dd Move all "old" SQL operators to common.sql providers (#25350)
acab8f52dd is described below

commit acab8f52dd8d90fd6583779127895dd343780f79
Author: Jarek Potiuk <ja...@polidea.com>
AuthorDate: Fri Jul 29 13:25:11 2022 +0200

    Move all "old" SQL operators to common.sql providers (#25350)
    
    Previously, in #24836, we moved the hooks and added some new operators to the
    common.sql package. Now we are also moving the operators
    and sensors to common.sql.
---
 airflow/operators/check_operator.py                |  24 +-
 airflow/operators/druid_check_operator.py          |   4 +-
 airflow/operators/presto_check_operator.py         |  24 +-
 airflow/operators/sql.py                           | 560 +--------------
 airflow/operators/sql_branch_operator.py           |  12 +-
 .../apache/druid/operators/druid_check.py          |   6 +-
 airflow/providers/apache/druid/provider.yaml       |   2 +-
 airflow/providers/common/sql/CHANGELOG.rst         |   5 +
 airflow/providers/common/sql/operators/sql.py      | 550 ++++++++++++++-
 airflow/providers/common/sql/provider.yaml         |   2 +-
 .../providers/google/cloud/operators/bigquery.py   |   6 +-
 airflow/providers/google/provider.yaml             |   2 +-
 .../google/suite/transfers/sql_to_sheets.py        |   2 +-
 airflow/providers/qubole/operators/qubole_check.py |   2 +-
 airflow/providers/qubole/provider.yaml             |   2 +-
 airflow/providers/snowflake/operators/snowflake.py |   6 +-
 airflow/providers/snowflake/provider.yaml          |   2 +-
 airflow/sensors/sql_sensor.py                      |   2 +-
 docs/apache-airflow/operators-and-hooks-ref.rst    |   6 -
 generated/provider_dependencies.json               |   8 +-
 tests/deprecated_classes.py                        |  18 +-
 tests/operators/test_sql.py                        | 779 ---------------------
 tests/providers/common/sql/operators/test_sql.py   | 776 +++++++++++++++++++-
 .../qubole/operators/test_qubole_check.py          |  14 +-
 tests/providers/slack/operators/test_slack.py      |   8 +-
 .../providers/slack/transfers/test_sql_to_slack.py |   6 +-
 26 files changed, 1424 insertions(+), 1404 deletions(-)

diff --git a/airflow/operators/check_operator.py b/airflow/operators/check_operator.py
index 130eb6e577..0575211a01 100644
--- a/airflow/operators/check_operator.py
+++ b/airflow/operators/check_operator.py
@@ -16,11 +16,11 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""This module is deprecated. Please use :mod:`airflow.operators.sql`."""
+"""This module is deprecated. Please use :mod:`airflow.providers.common.sql.operators.sql`."""
 
 import warnings
 
-from airflow.operators.sql import (
+from airflow.providers.common.sql.operators.sql import (
     SQLCheckOperator,
     SQLIntervalCheckOperator,
     SQLThresholdCheckOperator,
@@ -28,20 +28,22 @@ from airflow.operators.sql import (
 )
 
 warnings.warn(
-    "This module is deprecated. Please use `airflow.operators.sql`.", DeprecationWarning, stacklevel=2
+    "This module is deprecated. Please use `airflow.providers.common.sql.operators.sql`.",
+    DeprecationWarning,
+    stacklevel=2,
 )
 
 
 class CheckOperator(SQLCheckOperator):
     """
     This class is deprecated.
-    Please use `airflow.operators.sql.SQLCheckOperator`.
+    Please use `airflow.providers.common.sql.operators.sql.SQLCheckOperator`.
     """
 
     def __init__(self, **kwargs):
         warnings.warn(
             """This class is deprecated.
-            Please use `airflow.operators.sql.SQLCheckOperator`.""",
+            Please use `airflow.providers.common.sql.operators.sql.SQLCheckOperator`.""",
             DeprecationWarning,
             stacklevel=2,
         )
@@ -51,13 +53,13 @@ class CheckOperator(SQLCheckOperator):
 class IntervalCheckOperator(SQLIntervalCheckOperator):
     """
     This class is deprecated.
-    Please use `airflow.operators.sql.SQLIntervalCheckOperator`.
+    Please use `airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator`.
     """
 
     def __init__(self, **kwargs):
         warnings.warn(
             """This class is deprecated.
-            Please use `airflow.operators.sql.SQLIntervalCheckOperator`.""",
+            Please use `airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator`.""",
             DeprecationWarning,
             stacklevel=2,
         )
@@ -67,13 +69,13 @@ class IntervalCheckOperator(SQLIntervalCheckOperator):
 class ThresholdCheckOperator(SQLThresholdCheckOperator):
     """
     This class is deprecated.
-    Please use `airflow.operators.sql.SQLThresholdCheckOperator`.
+    Please use `airflow.providers.common.sql.operators.sql.SQLThresholdCheckOperator`.
     """
 
     def __init__(self, **kwargs):
         warnings.warn(
             """This class is deprecated.
-            Please use `airflow.operators.sql.SQLThresholdCheckOperator`.""",
+            Please use `airflow.providers.common.sql.operators.sql.SQLThresholdCheckOperator`.""",
             DeprecationWarning,
             stacklevel=2,
         )
@@ -83,13 +85,13 @@ class ThresholdCheckOperator(SQLThresholdCheckOperator):
 class ValueCheckOperator(SQLValueCheckOperator):
     """
     This class is deprecated.
-    Please use `airflow.operators.sql.SQLValueCheckOperator`.
+    Please use `airflow.providers.common.sql.operators.sql.SQLValueCheckOperator`.
     """
 
     def __init__(self, **kwargs):
         warnings.warn(
             """This class is deprecated.
-            Please use `airflow.operators.sql.SQLValueCheckOperator`.""",
+            Please use `airflow.providers.common.sql.operators.sql.SQLValueCheckOperator`.""",
             DeprecationWarning,
             stacklevel=2,
         )
diff --git a/airflow/operators/druid_check_operator.py b/airflow/operators/druid_check_operator.py
index 008a91750c..217c306f6d 100644
--- a/airflow/operators/druid_check_operator.py
+++ b/airflow/operators/druid_check_operator.py
@@ -15,14 +15,14 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.druid.operators.druid_check`."""
+"""This module is deprecated. Please use :mod:`airflow.providers.common.sql.operators.sql`."""
 
 import warnings
 
 from airflow.providers.apache.druid.operators.druid_check import DruidCheckOperator  # noqa
 
 warnings.warn(
-    "This module is deprecated. Please use `airflow.operators.sql.SQLCheckOperator`.",
+    "This module is deprecated. Please use `airflow.providers.common.sql.operators.sql` module.",
     DeprecationWarning,
     stacklevel=2,
 )
diff --git a/airflow/operators/presto_check_operator.py b/airflow/operators/presto_check_operator.py
index 693471f18c..810eef39a4 100644
--- a/airflow/operators/presto_check_operator.py
+++ b/airflow/operators/presto_check_operator.py
@@ -15,27 +15,33 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""This module is deprecated. Please use :mod:`airflow.operators.sql`."""
+"""This module is deprecated. Please use :mod:`airflow.providers.common.sql.operators.sql`."""
 
 import warnings
 
-from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator
+from airflow.providers.common.sql.operators.sql import (
+    SQLCheckOperator,
+    SQLIntervalCheckOperator,
+    SQLValueCheckOperator,
+)
 
 warnings.warn(
-    "This module is deprecated. Please use `airflow.operators.sql`.", DeprecationWarning, stacklevel=2
+    "This module is deprecated. Please use `airflow.providers.common.sql.operators.sql`.",
+    DeprecationWarning,
+    stacklevel=2,
 )
 
 
 class PrestoCheckOperator(SQLCheckOperator):
     """
     This class is deprecated.
-    Please use `airflow.operators.sql.SQLCheckOperator`.
+    Please use `airflow.providers.common.sql.operators.sql.SQLCheckOperator`.
     """
 
     def __init__(self, **kwargs):
         warnings.warn(
             """This class is deprecated.
-            Please use `airflow.operators.sql.SQLCheckOperator`.""",
+            Please use `airflow.providers.common.sql.operators.sql.SQLCheckOperator`.""",
             DeprecationWarning,
             stacklevel=2,
         )
@@ -45,14 +51,14 @@ class PrestoCheckOperator(SQLCheckOperator):
 class PrestoIntervalCheckOperator(SQLIntervalCheckOperator):
     """
     This class is deprecated.
-    Please use `airflow.operators.sql.SQLIntervalCheckOperator`.
+    Please use `airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator`.
     """
 
     def __init__(self, **kwargs):
         warnings.warn(
             """
             This class is deprecated.l
-            Please use `airflow.operators.sql.SQLIntervalCheckOperator`.
+            Please use `airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator`.
             """,
             DeprecationWarning,
             stacklevel=2,
@@ -63,14 +69,14 @@ class PrestoIntervalCheckOperator(SQLIntervalCheckOperator):
 class PrestoValueCheckOperator(SQLValueCheckOperator):
     """
     This class is deprecated.
-    Please use `airflow.operators.sql.SQLValueCheckOperator`.
+    Please use `airflow.providers.common.sql.operators.sql.SQLValueCheckOperator`.
     """
 
     def __init__(self, **kwargs):
         warnings.warn(
             """
             This class is deprecated.l
-            Please use `airflow.operators.sql.SQLValueCheckOperator`.
+            Please use `airflow.providers.common.sql.operators.sql.SQLValueCheckOperator`.
             """,
             DeprecationWarning,
             stacklevel=2,
diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py
index fbce9b85a1..9bbe159c17 100644
--- a/airflow/operators/sql.py
+++ b/airflow/operators/sql.py
@@ -15,543 +15,23 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, SupportsAbs, Union
-
-from airflow.compat.functools import cached_property
-from airflow.exceptions import AirflowException
-from airflow.hooks.base import BaseHook
-from airflow.models import BaseOperator, SkipMixin
-from airflow.providers.common.sql.hooks.sql import DbApiHook
-from airflow.utils.context import Context
-
-
-def parse_boolean(val: str) -> Union[str, bool]:
-    """Try to parse a string into boolean.
-
-    Raises ValueError if the input is not a valid true- or false-like string value.
-    """
-    val = val.lower()
-    if val in ('y', 'yes', 't', 'true', 'on', '1'):
-        return True
-    if val in ('n', 'no', 'f', 'false', 'off', '0'):
-        return False
-    raise ValueError(f"{val!r} is not a boolean-like string value")
-
-
-class BaseSQLOperator(BaseOperator):
-    """
-    This is a base class for generic SQL Operator to get a DB Hook
-
-    The provided method is .get_db_hook(). The default behavior will try to
-    retrieve the DB hook based on connection type.
-    You can custom the behavior by overriding the .get_db_hook() method.
-    """
-
-    def __init__(
-        self,
-        *,
-        conn_id: Optional[str] = None,
-        database: Optional[str] = None,
-        hook_params: Optional[Dict] = None,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        self.conn_id = conn_id
-        self.database = database
-        self.hook_params = {} if hook_params is None else hook_params
-
-    @cached_property
-    def _hook(self):
-        """Get DB Hook based on connection type"""
-        self.log.debug("Get connection for %s", self.conn_id)
-        conn = BaseHook.get_connection(self.conn_id)
-
-        hook = conn.get_hook(hook_params=self.hook_params)
-        if not isinstance(hook, DbApiHook):
-            raise AirflowException(
-                f'The connection type is not supported by {self.__class__.__name__}. '
-                f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}'
-            )
-
-        if self.database:
-            hook.schema = self.database
-
-        return hook
-
-    def get_db_hook(self) -> DbApiHook:
-        """
-        Get the database hook for the connection.
-
-        :return: the database hook object.
-        :rtype: DbApiHook
-        """
-        return self._hook
-
-
-class SQLCheckOperator(BaseSQLOperator):
-    """
-    Performs checks against a db. The ``SQLCheckOperator`` expects
-    a sql query that will return a single row. Each value on that
-    first row is evaluated using python ``bool`` casting. If any of the
-    values return ``False`` the check is failed and errors out.
-
-    Note that Python bool casting evals the following as ``False``:
-
-    * ``False``
-    * ``0``
-    * Empty string (``""``)
-    * Empty list (``[]``)
-    * Empty dictionary or set (``{}``)
-
-    Given a query like ``SELECT COUNT(*) FROM foo``, it will fail only if
-    the count ``== 0``. You can craft much more complex query that could,
-    for instance, check that the table has the same number of rows as
-    the source table upstream, or that the count of today's partition is
-    greater than yesterday's partition, or that a set of metrics are less
-    than 3 standard deviation for the 7 day average.
-
-    This operator can be used as a data quality check in your pipeline, and
-    depending on where you put it in your DAG, you have the choice to
-    stop the critical path, preventing from
-    publishing dubious data, or on the side and receive email alerts
-    without stopping the progress of the DAG.
-
-    :param sql: the sql to be executed. (templated)
-    :param conn_id: the connection ID used to connect to the database.
-    :param database: name of database which overwrite the defined one in connection
-    """
-
-    template_fields: Sequence[str] = ("sql",)
-    template_ext: Sequence[str] = (
-        ".hql",
-        ".sql",
-    )
-    template_fields_renderers = {"sql": "sql"}
-    ui_color = "#fff7e6"
-
-    def __init__(
-        self, *, sql: str, conn_id: Optional[str] = None, database: Optional[str] = None, **kwargs
-    ) -> None:
-        super().__init__(conn_id=conn_id, database=database, **kwargs)
-        self.sql = sql
-
-    def execute(self, context: Context):
-        self.log.info("Executing SQL check: %s", self.sql)
-        records = self.get_db_hook().get_first(self.sql)
-
-        self.log.info("Record: %s", records)
-        if not records:
-            raise AirflowException("The query returned None")
-        elif not all(bool(r) for r in records):
-            raise AirflowException(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}")
-
-        self.log.info("Success.")
-
-
-def _convert_to_float_if_possible(s):
-    """
-    A small helper function to convert a string to a numeric value
-    if appropriate
-
-    :param s: the string to be converted
-    """
-    try:
-        ret = float(s)
-    except (ValueError, TypeError):
-        ret = s
-    return ret
-
-
-class SQLValueCheckOperator(BaseSQLOperator):
-    """
-    Performs a simple value check using sql code.
-
-    :param sql: the sql to be executed. (templated)
-    :param conn_id: the connection ID used to connect to the database.
-    :param database: name of database which overwrite the defined one in connection
-    """
-
-    __mapper_args__ = {"polymorphic_identity": "SQLValueCheckOperator"}
-    template_fields: Sequence[str] = (
-        "sql",
-        "pass_value",
-    )
-    template_ext: Sequence[str] = (
-        ".hql",
-        ".sql",
-    )
-    template_fields_renderers = {"sql": "sql"}
-    ui_color = "#fff7e6"
-
-    def __init__(
-        self,
-        *,
-        sql: str,
-        pass_value: Any,
-        tolerance: Any = None,
-        conn_id: Optional[str] = None,
-        database: Optional[str] = None,
-        **kwargs,
-    ):
-        super().__init__(conn_id=conn_id, database=database, **kwargs)
-        self.sql = sql
-        self.pass_value = str(pass_value)
-        tol = _convert_to_float_if_possible(tolerance)
-        self.tol = tol if isinstance(tol, float) else None
-        self.has_tolerance = self.tol is not None
-
-    def execute(self, context=None):
-        self.log.info("Executing SQL check: %s", self.sql)
-        records = self.get_db_hook().get_first(self.sql)
-
-        if not records:
-            raise AirflowException("The query returned None")
-
-        pass_value_conv = _convert_to_float_if_possible(self.pass_value)
-        is_numeric_value_check = isinstance(pass_value_conv, float)
-
-        tolerance_pct_str = str(self.tol * 100) + "%" if self.has_tolerance else None
-        error_msg = (
-            "Test failed.\nPass value:{pass_value_conv}\n"
-            "Tolerance:{tolerance_pct_str}\n"
-            "Query:\n{sql}\nResults:\n{records!s}"
-        ).format(
-            pass_value_conv=pass_value_conv,
-            tolerance_pct_str=tolerance_pct_str,
-            sql=self.sql,
-            records=records,
-        )
-
-        if not is_numeric_value_check:
-            tests = self._get_string_matches(records, pass_value_conv)
-        elif is_numeric_value_check:
-            try:
-                numeric_records = self._to_float(records)
-            except (ValueError, TypeError):
-                raise AirflowException(f"Converting a result to float failed.\n{error_msg}")
-            tests = self._get_numeric_matches(numeric_records, pass_value_conv)
-        else:
-            tests = []
-
-        if not all(tests):
-            raise AirflowException(error_msg)
-
-    def _to_float(self, records):
-        return [float(record) for record in records]
-
-    def _get_string_matches(self, records, pass_value_conv):
-        return [str(record) == pass_value_conv for record in records]
-
-    def _get_numeric_matches(self, numeric_records, numeric_pass_value_conv):
-        if self.has_tolerance:
-            return [
-                numeric_pass_value_conv * (1 - self.tol) <= record <= numeric_pass_value_conv * (1 + self.tol)
-                for record in numeric_records
-            ]
-
-        return [record == numeric_pass_value_conv for record in numeric_records]
-
-
-class SQLIntervalCheckOperator(BaseSQLOperator):
-    """
-    Checks that the values of metrics given as SQL expressions are within
-    a certain tolerance of the ones from days_back before.
-
-    :param table: the table name
-    :param conn_id: the connection ID used to connect to the database.
-    :param database: name of database which will overwrite the defined one in connection
-    :param days_back: number of days between ds and the ds we want to check
-        against. Defaults to 7 days
-    :param date_filter_column: The column name for the dates to filter on. Defaults to 'ds'
-    :param ratio_formula: which formula to use to compute the ratio between
-        the two metrics. Assuming cur is the metric of today and ref is
-        the metric to today - days_back.
-
-        max_over_min: computes max(cur, ref) / min(cur, ref)
-        relative_diff: computes abs(cur-ref) / ref
-
-        Default: 'max_over_min'
-    :param ignore_zero: whether we should ignore zero metrics
-    :param metrics_thresholds: a dictionary of ratios indexed by metrics
-    """
-
-    __mapper_args__ = {"polymorphic_identity": "SQLIntervalCheckOperator"}
-    template_fields: Sequence[str] = ("sql1", "sql2")
-    template_ext: Sequence[str] = (
-        ".hql",
-        ".sql",
-    )
-    template_fields_renderers = {"sql1": "sql", "sql2": "sql"}
-    ui_color = "#fff7e6"
-
-    ratio_formulas = {
-        "max_over_min": lambda cur, ref: float(max(cur, ref)) / min(cur, ref),
-        "relative_diff": lambda cur, ref: float(abs(cur - ref)) / ref,
-    }
-
-    def __init__(
-        self,
-        *,
-        table: str,
-        metrics_thresholds: Dict[str, int],
-        date_filter_column: Optional[str] = "ds",
-        days_back: SupportsAbs[int] = -7,
-        ratio_formula: Optional[str] = "max_over_min",
-        ignore_zero: bool = True,
-        conn_id: Optional[str] = None,
-        database: Optional[str] = None,
-        **kwargs,
-    ):
-        super().__init__(conn_id=conn_id, database=database, **kwargs)
-        if ratio_formula not in self.ratio_formulas:
-            msg_template = "Invalid diff_method: {diff_method}. Supported diff methods are: {diff_methods}"
-
-            raise AirflowException(
-                msg_template.format(diff_method=ratio_formula, diff_methods=self.ratio_formulas)
-            )
-        self.ratio_formula = ratio_formula
-        self.ignore_zero = ignore_zero
-        self.table = table
-        self.metrics_thresholds = metrics_thresholds
-        self.metrics_sorted = sorted(metrics_thresholds.keys())
-        self.date_filter_column = date_filter_column
-        self.days_back = -abs(days_back)
-        sqlexp = ", ".join(self.metrics_sorted)
-        sqlt = f"SELECT {sqlexp} FROM {table} WHERE {date_filter_column}="
-
-        self.sql1 = sqlt + "'{{ ds }}'"
-        self.sql2 = sqlt + "'{{ macros.ds_add(ds, " + str(self.days_back) + ") }}'"
-
-    def execute(self, context=None):
-        hook = self.get_db_hook()
-        self.log.info("Using ratio formula: %s", self.ratio_formula)
-        self.log.info("Executing SQL check: %s", self.sql2)
-        row2 = hook.get_first(self.sql2)
-        self.log.info("Executing SQL check: %s", self.sql1)
-        row1 = hook.get_first(self.sql1)
-
-        if not row2:
-            raise AirflowException(f"The query {self.sql2} returned None")
-        if not row1:
-            raise AirflowException(f"The query {self.sql1} returned None")
-
-        current = dict(zip(self.metrics_sorted, row1))
-        reference = dict(zip(self.metrics_sorted, row2))
-
-        ratios = {}
-        test_results = {}
-
-        for metric in self.metrics_sorted:
-            cur = current[metric]
-            ref = reference[metric]
-            threshold = self.metrics_thresholds[metric]
-            if cur == 0 or ref == 0:
-                ratios[metric] = None
-                test_results[metric] = self.ignore_zero
-            else:
-                ratios[metric] = self.ratio_formulas[self.ratio_formula](current[metric], reference[metric])
-                test_results[metric] = ratios[metric] < threshold
-
-            self.log.info(
-                (
-                    "Current metric for %s: %s\n"
-                    "Past metric for %s: %s\n"
-                    "Ratio for %s: %s\n"
-                    "Threshold: %s\n"
-                ),
-                metric,
-                cur,
-                metric,
-                ref,
-                metric,
-                ratios[metric],
-                threshold,
-            )
-
-        if not all(test_results.values()):
-            failed_tests = [it[0] for it in test_results.items() if not it[1]]
-            self.log.warning(
-                "The following %s tests out of %s failed:",
-                len(failed_tests),
-                len(self.metrics_sorted),
-            )
-            for k in failed_tests:
-                self.log.warning(
-                    "'%s' check failed. %s is above %s",
-                    k,
-                    ratios[k],
-                    self.metrics_thresholds[k],
-                )
-            raise AirflowException(f"The following tests have failed:\n {', '.join(sorted(failed_tests))}")
-
-        self.log.info("All tests have passed")
-
-
-class SQLThresholdCheckOperator(BaseSQLOperator):
-    """
-    Performs a value check using sql code against a minimum threshold
-    and a maximum threshold. Thresholds can be in the form of a numeric
-    value OR a sql statement that results a numeric.
-
-    :param sql: the sql to be executed. (templated)
-    :param conn_id: the connection ID used to connect to the database.
-    :param database: name of database which overwrite the defined one in connection
-    :param min_threshold: numerical value or min threshold sql to be executed (templated)
-    :param max_threshold: numerical value or max threshold sql to be executed (templated)
-    """
-
-    template_fields: Sequence[str] = ("sql", "min_threshold", "max_threshold")
-    template_ext: Sequence[str] = (
-        ".hql",
-        ".sql",
-    )
-    template_fields_renderers = {"sql": "sql"}
-
-    def __init__(
-        self,
-        *,
-        sql: str,
-        min_threshold: Any,
-        max_threshold: Any,
-        conn_id: Optional[str] = None,
-        database: Optional[str] = None,
-        **kwargs,
-    ):
-        super().__init__(conn_id=conn_id, database=database, **kwargs)
-        self.sql = sql
-        self.min_threshold = _convert_to_float_if_possible(min_threshold)
-        self.max_threshold = _convert_to_float_if_possible(max_threshold)
-
-    def execute(self, context=None):
-        hook = self.get_db_hook()
-        result = hook.get_first(self.sql)[0]
-
-        if isinstance(self.min_threshold, float):
-            lower_bound = self.min_threshold
-        else:
-            lower_bound = hook.get_first(self.min_threshold)[0]
-
-        if isinstance(self.max_threshold, float):
-            upper_bound = self.max_threshold
-        else:
-            upper_bound = hook.get_first(self.max_threshold)[0]
-
-        meta_data = {
-            "result": result,
-            "task_id": self.task_id,
-            "min_threshold": lower_bound,
-            "max_threshold": upper_bound,
-            "within_threshold": lower_bound <= result <= upper_bound,
-        }
-
-        self.push(meta_data)
-        if not meta_data["within_threshold"]:
-            error_msg = (
-                f'Threshold Check: "{meta_data.get("task_id")}" failed.\n'
-                f'DAG: {self.dag_id}\nTask_id: {meta_data.get("task_id")}\n'
-                f'Check description: {meta_data.get("description")}\n'
-                f"SQL: {self.sql}\n"
-                f'Result: {round(meta_data.get("result"), 2)} is not within thresholds '
-                f'{meta_data.get("min_threshold")} and {meta_data.get("max_threshold")}'
-            )
-            raise AirflowException(error_msg)
-
-        self.log.info("Test %s Successful.", self.task_id)
-
-    def push(self, meta_data):
-        """
-        Optional: Send data check info and metadata to an external database.
-        Default functionality will log metadata.
-        """
-        info = "\n".join(f"""{key}: {item}""" for key, item in meta_data.items())
-        self.log.info("Log from %s:\n%s", self.dag_id, info)
-
-
-class BranchSQLOperator(BaseSQLOperator, SkipMixin):
-    """
-    Allows a DAG to "branch" or follow a specified path based on the results of a SQL query.
-
-    :param sql: The SQL code to be executed, should return true or false (templated)
-       Template reference are recognized by str ending in '.sql'.
-       Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1)
-       or string (true/y/yes/1/on/false/n/no/0/off).
-    :param follow_task_ids_if_true: task id or task ids to follow if query returns true
-    :param follow_task_ids_if_false: task id or task ids to follow if query returns false
-    :param conn_id: the connection ID used to connect to the database.
-    :param database: name of database which overwrite the defined one in connection
-    :param parameters: (optional) the parameters to render the SQL query with.
-    """
-
-    template_fields: Sequence[str] = ("sql",)
-    template_ext: Sequence[str] = (".sql",)
-    template_fields_renderers = {"sql": "sql"}
-    ui_color = "#a22034"
-    ui_fgcolor = "#F7F7F7"
-
-    def __init__(
-        self,
-        *,
-        sql: str,
-        follow_task_ids_if_true: List[str],
-        follow_task_ids_if_false: List[str],
-        conn_id: str = "default_conn_id",
-        database: Optional[str] = None,
-        parameters: Optional[Union[Iterable, Mapping]] = None,
-        **kwargs,
-    ) -> None:
-        super().__init__(conn_id=conn_id, database=database, **kwargs)
-        self.sql = sql
-        self.parameters = parameters
-        self.follow_task_ids_if_true = follow_task_ids_if_true
-        self.follow_task_ids_if_false = follow_task_ids_if_false
-
-    def execute(self, context: Context):
-        self.log.info(
-            "Executing: %s (with parameters %s) with connection: %s",
-            self.sql,
-            self.parameters,
-            self.conn_id,
-        )
-        record = self.get_db_hook().get_first(self.sql, self.parameters)
-        if not record:
-            raise AirflowException(
-                "No rows returned from sql query. Operator expected True or False return value."
-            )
-
-        if isinstance(record, list):
-            if isinstance(record[0], list):
-                query_result = record[0][0]
-            else:
-                query_result = record[0]
-        elif isinstance(record, tuple):
-            query_result = record[0]
-        else:
-            query_result = record
-
-        self.log.info("Query returns %s, type '%s'", query_result, type(query_result))
-
-        follow_branch = None
-        try:
-            if isinstance(query_result, bool):
-                if query_result:
-                    follow_branch = self.follow_task_ids_if_true
-            elif isinstance(query_result, str):
-                # return result is not Boolean, try to convert from String to Boolean
-                if parse_boolean(query_result):
-                    follow_branch = self.follow_task_ids_if_true
-            elif isinstance(query_result, int):
-                if bool(query_result):
-                    follow_branch = self.follow_task_ids_if_true
-            else:
-                raise AirflowException(
-                    f"Unexpected query return result '{query_result}' type '{type(query_result)}'"
-                )
-
-            if follow_branch is None:
-                follow_branch = self.follow_task_ids_if_false
-        except ValueError:
-            raise AirflowException(
-                f"Unexpected query return result '{query_result}' type '{type(query_result)}'"
-            )
-
-        self.skip_all_except(context["ti"], follow_branch)
+import warnings
+
+from airflow.providers.common.sql.operators.sql import (  # noqa
+    BaseSQLOperator,
+    BranchSQLOperator,
+    SQLCheckOperator,
+    SQLColumnCheckOperator,
+    SQLIntervalCheckOperator,
+    SQLTableCheckOperator,
+    SQLThresholdCheckOperator,
+    SQLValueCheckOperator,
+    _convert_to_float_if_possible,
+    parse_boolean,
+)
+
+warnings.warn(
+    "This module is deprecated. Please use `airflow.providers.common.sql.operators.sql`.",
+    DeprecationWarning,
+    stacklevel=2,
+)
diff --git a/airflow/operators/sql_branch_operator.py b/airflow/operators/sql_branch_operator.py
index 5987bce610..90d5fcbc1b 100644
--- a/airflow/operators/sql_branch_operator.py
+++ b/airflow/operators/sql_branch_operator.py
@@ -14,26 +14,28 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""This module is deprecated. Please use :mod:`airflow.operators.sql`."""
+"""This module is deprecated. Please use :mod:`airflow.providers.common.sql.operators.sql`."""
 import warnings
 
-from airflow.operators.sql import BranchSQLOperator
+from airflow.providers.common.sql.operators.sql import BranchSQLOperator
 
 warnings.warn(
-    "This module is deprecated. Please use :mod:`airflow.operators.sql`.", DeprecationWarning, stacklevel=2
+    "This module is deprecated. Please use :mod:`airflow.providers.common.sql.operators.sql`.",
+    DeprecationWarning,
+    stacklevel=2,  # intended to attribute the warning to the importing code rather than this shim
 )
 
 
 class BranchSqlOperator(BranchSQLOperator):
     """
     This class is deprecated.
-    Please use `airflow.operators.sql.BranchSQLOperator`.
+    Please use `airflow.providers.common.sql.operators.sql.BranchSQLOperator`.
     """
 
     def __init__(self, **kwargs):
         warnings.warn(
             """This class is deprecated.
-            Please use `airflow.operators.sql.BranchSQLOperator`.""",
+            Please use `airflow.providers.common.sql.operators.sql.BranchSQLOperator`.""",
             DeprecationWarning,
             stacklevel=2,
         )
diff --git a/airflow/providers/apache/druid/operators/druid_check.py b/airflow/providers/apache/druid/operators/druid_check.py
index 33a4151350..75f2b17d0c 100644
--- a/airflow/providers/apache/druid/operators/druid_check.py
+++ b/airflow/providers/apache/druid/operators/druid_check.py
@@ -17,19 +17,19 @@
 # under the License.
 import warnings
 
-from airflow.operators.sql import SQLCheckOperator
+from airflow.providers.common.sql.operators.sql import SQLCheckOperator
 
 
 class DruidCheckOperator(SQLCheckOperator):
     """
     This class is deprecated.
-    Please use `airflow.operators.sql.SQLCheckOperator`.
+    Please use `airflow.providers.common.sql.operators.sql.SQLCheckOperator`.
     """
 
     def __init__(self, druid_broker_conn_id: str = 'druid_broker_default', **kwargs):
         warnings.warn(
             """This class is deprecated.
-            Please use `airflow.operators.sql.SQLCheckOperator`.""",
+            Please use `airflow.providers.common.sql.operators.sql.SQLCheckOperator`.""",
             DeprecationWarning,
             stacklevel=2,
         )
diff --git a/airflow/providers/apache/druid/provider.yaml b/airflow/providers/apache/druid/provider.yaml
index 038ae7fc23..1ccaf98153 100644
--- a/airflow/providers/apache/druid/provider.yaml
+++ b/airflow/providers/apache/druid/provider.yaml
@@ -39,7 +39,7 @@ versions:
 
 dependencies:
   - apache-airflow>=2.2.0
-  - apache-airflow-providers-common-sql
+  - apache-airflow-providers-common-sql>=1.1.0
   - pydruid>=0.4.1
 
 integrations:
diff --git a/airflow/providers/common/sql/CHANGELOG.rst b/airflow/providers/common/sql/CHANGELOG.rst
index d48dafc25d..c575d7a027 100644
--- a/airflow/providers/common/sql/CHANGELOG.rst
+++ b/airflow/providers/common/sql/CHANGELOG.rst
@@ -24,6 +24,11 @@
 Changelog
 ---------
 
+1.1.0
+.....
+
+
+
 1.0.0
 .....
 
diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
index a6883d9f08..f1c872cc54 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -15,10 +15,19 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Sequence, SupportsAbs, Union
 
+from packaging.version import Version
+
+from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException
-from airflow.operators.sql import BaseSQLOperator
+from airflow.hooks.base import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.providers.common.sql.hooks.sql import DbApiHook, _backported_get_hook
+from airflow.version import version
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
 
 
 def parse_boolean(val: str) -> Union[str, bool]:
@@ -48,6 +57,61 @@ def _get_failed_checks(checks, col=None):
     ]
 
 
+class BaseSQLOperator(BaseOperator):
+    """
+    This is a base class for generic SQL Operator to get a DB Hook
+
+    The provided method is .get_db_hook(). The default behavior will try to
+    retrieve the DB hook based on connection type.
+    You can customize the behavior by overriding the .get_db_hook() method.
+    """
+
+    def __init__(
+        self,
+        *,
+        conn_id: Optional[str] = None,  # Airflow connection used to resolve the hook in _hook
+        database: Optional[str] = None,  # when truthy, overrides hook.schema (see _hook)
+        hook_params: Optional[Dict] = None,  # forwarded as hook_params to get_hook() / _backported_get_hook()
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.conn_id = conn_id
+        self.database = database
+        self.hook_params = {} if hook_params is None else hook_params  # fresh dict per instance, no shared default
+
+    @cached_property
+    def _hook(self):  # built on first access, then cached on this operator instance
+        """Get DB Hook based on connection type (computed once, then cached)"""
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+        if Version(version) >= Version('2.3'):
+            # "hook_params" were introduced into "get_hook()" only in Airflow 2.3.
+            hook = conn.get_hook(hook_params=self.hook_params)  # ignore airflow compat check
+        else:
+            # For supporting Airflow versions < 2.3, we use a backported "get_hook()" method. Remove this
+            # once "apache-airflow-providers-common-sql" depends on Airflow >= 2.3.
+            hook = _backported_get_hook(conn, hook_params=self.hook_params)
+        if not isinstance(hook, DbApiHook):  # subclasses rely on DbApiHook methods such as get_first()
+            raise AirflowException(
+                f'The connection type is not supported by {self.__class__.__name__}. '
+                f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}'
+            )
+
+        if self.database:  # a falsy database leaves the hook's schema untouched
+            hook.schema = self.database
+
+        return hook
+
+    def get_db_hook(self) -> DbApiHook:
+        """
+        Get the database hook for the connection.
+
+        :return: the database hook object (resolved once and cached via ``_hook``).
+        :rtype: DbApiHook
+        """
+        return self._hook
+
 class SQLColumnCheckOperator(BaseSQLOperator):
     """
     Performs one or more of the templated checks in the column_checks dictionary.
@@ -125,7 +189,7 @@ class SQLColumnCheckOperator(BaseSQLOperator):
         # OpenLineage needs a valid SQL query with the input/output table(s) to parse
         self.sql = f"SELECT * FROM {self.table};"
 
-    def execute(self, context=None):
+    def execute(self, context: 'Context'):
         hook = self.get_db_hook()
         failed_tests = []
         for column in self.column_mapping:
@@ -307,7 +371,7 @@ class SQLTableCheckOperator(BaseSQLOperator):
         # OpenLineage needs a valid SQL query with the input/output table(s) to parse
         self.sql = f"SELECT * FROM {self.table};"
 
-    def execute(self, context=None):
+    def execute(self, context: 'Context'):
         hook = self.get_db_hook()
         checks_sql = " UNION ALL ".join(
             [
@@ -343,3 +407,481 @@ class SQLTableCheckOperator(BaseSQLOperator):
             )
 
         self.log.info("All tests have passed")
+
+
+class SQLCheckOperator(BaseSQLOperator):
+    """
+    Performs checks against a db. The ``SQLCheckOperator`` expects
+    a sql query that will return a single row. Each value on that
+    first row is evaluated using python ``bool`` casting. If any of the
+    values return ``False`` the check is failed and errors out.
+
+    Note that Python bool casting evals the following as ``False``:
+
+    * ``False``
+    * ``0``
+    * Empty string (``""``)
+    * Empty list (``[]``)
+    * Empty dictionary (``{}``) or empty set (``set()``)
+
+    Given a query like ``SELECT COUNT(*) FROM foo``, it will fail only if
+    the count ``== 0``. You can craft a much more complex query that could,
+    for instance, check that the table has the same number of rows as
+    the source table upstream, or that the count of today's partition is
+    greater than yesterday's partition, or that a set of metrics are less
+    than 3 standard deviations from the 7-day average.
+
+    This operator can be used as a data quality check in your pipeline, and
+    depending on where you put it in your DAG, you have the choice to
+    stop the critical path, preventing dubious data from
+    being published, or on the side and receive email alerts
+    without stopping the progress of the DAG.
+
+    :param sql: the sql to be executed. (templated)
+    :param conn_id: the connection ID used to connect to the database.
+    :param database: name of the database which overrides the one defined in the connection
+    """
+
+    template_fields: Sequence[str] = ("sql",)
+    template_ext: Sequence[str] = (
+        ".hql",
+        ".sql",
+    )
+    template_fields_renderers = {"sql": "sql"}
+    ui_color = "#fff7e6"
+
+    def __init__(
+        self, *, sql: str, conn_id: Optional[str] = None, database: Optional[str] = None, **kwargs
+    ) -> None:
+        super().__init__(conn_id=conn_id, database=database, **kwargs)
+        self.sql = sql
+
+    def execute(self, context: 'Context'):
+        self.log.info("Executing SQL check: %s", self.sql)
+        records = self.get_db_hook().get_first(self.sql)  # only the first row is evaluated
+
+        self.log.info("Record: %s", records)
+        if not records:  # no row returned (None) or an empty row
+            raise AirflowException("The query returned None")
+        elif not all(bool(r) for r in records):  # any falsy column value fails the check
+            raise AirflowException(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}")
+
+        self.log.info("Success.")
+
+
+class SQLValueCheckOperator(BaseSQLOperator):
+    """
+    Performs a simple value check using sql code: every value of the first row must match pass_value.
+
+    :param sql: the sql to be executed. (templated)
+    :param conn_id: the connection ID used to connect to the database.
+    :param database: name of the database which overrides the one defined in the connection
+    """
+
+    __mapper_args__ = {"polymorphic_identity": "SQLValueCheckOperator"}
+    template_fields: Sequence[str] = (
+        "sql",
+        "pass_value",
+    )
+    template_ext: Sequence[str] = (
+        ".hql",
+        ".sql",
+    )
+    template_fields_renderers = {"sql": "sql"}
+    ui_color = "#fff7e6"
+
+    def __init__(
+        self,
+        *,
+        sql: str,
+        pass_value: Any,  # expected value; compared numerically if it converts to float, else as a string
+        tolerance: Any = None,  # relative tolerance (fraction, e.g. 0.1), only used for numeric checks
+        conn_id: Optional[str] = None,
+        database: Optional[str] = None,
+        **kwargs,
+    ):
+        super().__init__(conn_id=conn_id, database=database, **kwargs)
+        self.sql = sql
+        self.pass_value = str(pass_value)  # kept as str (it is a template field); execute() tries float conversion
+        tol = _convert_to_float_if_possible(tolerance)
+        self.tol = tol if isinstance(tol, float) else None  # a non-numeric tolerance is silently ignored
+        self.has_tolerance = self.tol is not None
+
+    def execute(self, context: 'Context'):
+        self.log.info("Executing SQL check: %s", self.sql)
+        records = self.get_db_hook().get_first(self.sql)
+
+        if not records:
+            raise AirflowException("The query returned None")
+
+        pass_value_conv = _convert_to_float_if_possible(self.pass_value)
+        is_numeric_value_check = isinstance(pass_value_conv, float)
+
+        tolerance_pct_str = str(self.tol * 100) + "%" if self.tol is not None else None  # only used in error_msg
+        error_msg = (
+            "Test failed.\nPass value:{pass_value_conv}\n"
+            "Tolerance:{tolerance_pct_str}\n"
+            "Query:\n{sql}\nResults:\n{records!s}"
+        ).format(
+            pass_value_conv=pass_value_conv,
+            tolerance_pct_str=tolerance_pct_str,
+            sql=self.sql,
+            records=records,
+        )
+
+        if not is_numeric_value_check:
+            tests = self._get_string_matches(records, pass_value_conv)
+        elif is_numeric_value_check:
+            try:
+                numeric_records = self._to_float(records)
+            except (ValueError, TypeError):
+                raise AirflowException(f"Converting a result to float failed.\n{error_msg}")
+            tests = self._get_numeric_matches(numeric_records, pass_value_conv)
+        else:  # NOTE(review): unreachable - the two branches above are exhaustive
+            tests = []
+
+        if not all(tests):
+            raise AirflowException(error_msg)
+
+    def _to_float(self, records):
+        return [float(record) for record in records]
+
+    def _get_string_matches(self, records, pass_value_conv):
+        return [str(record) == pass_value_conv for record in records]  # exact match on the str() of each value
+
+    def _get_numeric_matches(self, numeric_records, numeric_pass_value_conv):
+        if self.has_tolerance:
+            return [  # inclusive band: pass_value * (1 - tol) .. pass_value * (1 + tol)
+                numeric_pass_value_conv * (1 - self.tol) <= record <= numeric_pass_value_conv * (1 + self.tol)
+                for record in numeric_records  # NOTE(review): for a negative pass_value the bounds invert, so nothing passes
+            ]
+
+        return [record == numeric_pass_value_conv for record in numeric_records]  # exact float equality without tolerance
+
+
+class SQLIntervalCheckOperator(BaseSQLOperator):
+    """
+    Checks that the values of metrics given as SQL expressions are within
+    a certain tolerance of the ones from days_back before.
+
+    :param table: the table name
+    :param conn_id: the connection ID used to connect to the database.
+    :param database: name of the database which will override the one defined in the connection
+    :param days_back: number of days between ds and the ds we want to check
+        against. Defaults to 7 days
+    :param date_filter_column: The column name for the dates to filter on. Defaults to 'ds'
+    :param ratio_formula: which formula to use to compute the ratio between
+        the two metrics. Assuming cur is the metric of today and ref is
+        the metric of today - days_back.
+
+        max_over_min: computes max(cur, ref) / min(cur, ref)
+        relative_diff: computes abs(cur-ref) / ref
+
+        Default: 'max_over_min'
+    :param ignore_zero: whether we should ignore zero metrics
+    :param metrics_thresholds: a dictionary of ratios indexed by metrics
+    """
+
+    __mapper_args__ = {"polymorphic_identity": "SQLIntervalCheckOperator"}
+    template_fields: Sequence[str] = ("sql1", "sql2")
+    template_ext: Sequence[str] = (
+        ".hql",
+        ".sql",
+    )
+    template_fields_renderers = {"sql1": "sql", "sql2": "sql"}
+    ui_color = "#fff7e6"
+
+    ratio_formulas = {
+        "max_over_min": lambda cur, ref: float(max(cur, ref)) / min(cur, ref),
+        "relative_diff": lambda cur, ref: float(abs(cur - ref)) / ref,
+    }
+
+    def __init__(
+        self,
+        *,
+        table: str,
+        metrics_thresholds: Dict[str, int],  # metric SQL expression -> ratio that must not be reached
+        date_filter_column: Optional[str] = "ds",
+        days_back: SupportsAbs[int] = -7,  # sign is ignored: normalised to -abs(days_back) below
+        ratio_formula: Optional[str] = "max_over_min",
+        ignore_zero: bool = True,
+        conn_id: Optional[str] = None,
+        database: Optional[str] = None,
+        **kwargs,
+    ):
+        super().__init__(conn_id=conn_id, database=database, **kwargs)
+        if ratio_formula not in self.ratio_formulas:
+            msg_template = "Invalid diff_method: {diff_method}. Supported diff methods are: {diff_methods}"
+
+            raise AirflowException(
+                msg_template.format(diff_method=ratio_formula, diff_methods=self.ratio_formulas)
+            )
+        self.ratio_formula = ratio_formula
+        self.ignore_zero = ignore_zero
+        self.table = table
+        self.metrics_thresholds = metrics_thresholds
+        self.metrics_sorted = sorted(metrics_thresholds.keys())  # fixes the column order of both queries
+        self.date_filter_column = date_filter_column
+        self.days_back = -abs(days_back)
+        sqlexp = ", ".join(self.metrics_sorted)
+        sqlt = f"SELECT {sqlexp} FROM {table} WHERE {date_filter_column}="  # NOTE(review): names are interpolated verbatim - trusted input only
+
+        self.sql1 = sqlt + "'{{ ds }}'"  # current ds; Jinja expression rendered via the sql1 template field
+        self.sql2 = sqlt + "'{{ macros.ds_add(ds, " + str(self.days_back) + ") }}'"  # reference ds, days_back earlier
+
+    def execute(self, context: 'Context'):
+        hook = self.get_db_hook()
+        self.log.info("Using ratio formula: %s", self.ratio_formula)
+        self.log.info("Executing SQL check: %s", self.sql2)
+        row2 = hook.get_first(self.sql2)
+        self.log.info("Executing SQL check: %s", self.sql1)
+        row1 = hook.get_first(self.sql1)
+
+        if not row2:
+            raise AirflowException(f"The query {self.sql2} returned None")
+        if not row1:
+            raise AirflowException(f"The query {self.sql1} returned None")
+
+        current = dict(zip(self.metrics_sorted, row1))
+        reference = dict(zip(self.metrics_sorted, row2))
+
+        ratios: Dict[str, Optional[int]] = {}
+        test_results = {}
+
+        for metric in self.metrics_sorted:
+            cur = current[metric]
+            ref = reference[metric]
+            threshold = self.metrics_thresholds[metric]
+            if cur == 0 or ref == 0:  # ratio undefined (division by zero); outcome decided by ignore_zero
+                ratios[metric] = None
+                test_results[metric] = self.ignore_zero
+            else:
+                ratio_metric = self.ratio_formulas[self.ratio_formula](current[metric], reference[metric])
+                ratios[metric] = ratio_metric
+                if ratio_metric is not None:  # NOTE(review): both formulas return a float, so the else looks unreachable
+                    test_results[metric] = ratio_metric < threshold  # strict: a ratio equal to threshold fails
+                else:
+                    test_results[metric] = self.ignore_zero
+
+            self.log.info(
+                (
+                    "Current metric for %s: %s\n"
+                    "Past metric for %s: %s\n"
+                    "Ratio for %s: %s\n"
+                    "Threshold: %s\n"
+                ),
+                metric,
+                cur,
+                metric,
+                ref,
+                metric,
+                ratios[metric],
+                threshold,
+            )
+
+        if not all(test_results.values()):
+            failed_tests = [it[0] for it in test_results.items() if not it[1]]
+            self.log.warning(
+                "The following %s tests out of %s failed:",
+                len(failed_tests),
+                len(self.metrics_sorted),
+            )
+            for k in failed_tests:
+                self.log.warning(
+                    "'%s' check failed. %s is above %s",
+                    k,
+                    ratios[k],
+                    self.metrics_thresholds[k],
+                )
+            raise AirflowException(f"The following tests have failed:\n {', '.join(sorted(failed_tests))}")
+
+        self.log.info("All tests have passed")
+
+
+class SQLThresholdCheckOperator(BaseSQLOperator):
+    """
+    Performs a value check using sql code against a minimum threshold
+    and a maximum threshold. Thresholds can be in the form of a numeric
+    value OR a sql statement that returns a numeric value.
+
+    :param sql: the sql to be executed. (templated)
+    :param conn_id: the connection ID used to connect to the database.
+    :param database: name of the database which overrides the one defined in the connection
+    :param min_threshold: numerical value or min threshold sql to be executed (templated)
+    :param max_threshold: numerical value or max threshold sql to be executed (templated)
+    """
+
+    template_fields: Sequence[str] = ("sql", "min_threshold", "max_threshold")
+    template_ext: Sequence[str] = (
+        ".hql",
+        ".sql",
+    )
+    template_fields_renderers = {"sql": "sql"}
+
+    def __init__(
+        self,
+        *,
+        sql: str,
+        min_threshold: Any,
+        max_threshold: Any,
+        conn_id: Optional[str] = None,
+        database: Optional[str] = None,
+        **kwargs,
+    ):
+        super().__init__(conn_id=conn_id, database=database, **kwargs)
+        self.sql = sql
+        self.min_threshold = _convert_to_float_if_possible(min_threshold)  # float, or kept as SQL run in execute()
+        self.max_threshold = _convert_to_float_if_possible(max_threshold)  # float, or kept as SQL run in execute()
+
+    def execute(self, context: 'Context'):
+        hook = self.get_db_hook()
+        result = hook.get_first(self.sql)[0]  # NOTE(review): no empty-result guard here, unlike SQLCheckOperator
+
+        if isinstance(self.min_threshold, float):
+            lower_bound = self.min_threshold
+        else:
+            lower_bound = hook.get_first(self.min_threshold)[0]
+
+        if isinstance(self.max_threshold, float):
+            upper_bound = self.max_threshold
+        else:
+            upper_bound = hook.get_first(self.max_threshold)[0]
+
+        meta_data = {
+            "result": result,
+            "task_id": self.task_id,
+            "min_threshold": lower_bound,
+            "max_threshold": upper_bound,
+            "within_threshold": lower_bound <= result <= upper_bound,  # both bounds inclusive
+        }
+
+        self.push(meta_data)  # extension point; the default implementation below only logs
+        if not meta_data["within_threshold"]:
+            result = (
+                round(meta_data.get("result"), 2)  # type: ignore[arg-type]
+                if meta_data.get("result") is not None
+                else "<None>"
+            )
+            error_msg = (
+                f'Threshold Check: "{meta_data.get("task_id")}" failed.\n'
+                f'DAG: {self.dag_id}\nTask_id: {meta_data.get("task_id")}\n'
+                f'Check description: {meta_data.get("description")}\n'  # NOTE(review): meta_data has no "description" key, so this is always None
+                f"SQL: {self.sql}\n"
+                f'Result: {result} is not within thresholds '
+                f'{meta_data.get("min_threshold")} and {meta_data.get("max_threshold")}'
+            )
+            raise AirflowException(error_msg)
+
+        self.log.info("Test %s Successful.", self.task_id)
+
+    def push(self, meta_data):
+        """
+        Optional: Send data check info and metadata to an external database.
+        Default functionality will log metadata.
+        """
+        info = "\n".join(f"""{key}: {item}""" for key, item in meta_data.items())
+        self.log.info("Log from %s:\n%s", self.dag_id, info)
+
+
+class BranchSQLOperator(BaseSQLOperator, SkipMixin):
+    """
+    Allows a DAG to "branch" or follow a specified path based on the results of a SQL query.
+
+    :param sql: The SQL code to be executed, should return true or false (templated)
+       Template references are recognized by str ending in '.sql'.
+       Expected SQL query to return Boolean (True/False), integer (0 = False, any other value = True)
+       or string (true/y/yes/1/on/false/n/no/0/off).
+    :param follow_task_ids_if_true: task id or task ids to follow if query returns true
+    :param follow_task_ids_if_false: task id or task ids to follow if query returns false
+    :param conn_id: the connection ID used to connect to the database.
+    :param database: name of the database which overrides the one defined in the connection
+    :param parameters: (optional) the parameters to render the SQL query with.
+    """
+
+    template_fields: Sequence[str] = ("sql",)
+    template_ext: Sequence[str] = (".sql",)
+    template_fields_renderers = {"sql": "sql"}
+    ui_color = "#a22034"
+    ui_fgcolor = "#F7F7F7"
+
+    def __init__(
+        self,
+        *,
+        sql: str,
+        follow_task_ids_if_true: List[str],
+        follow_task_ids_if_false: List[str],
+        conn_id: str = "default_conn_id",  # NOTE(review): placeholder-looking default; callers likely need a real id
+        database: Optional[str] = None,
+        parameters: Optional[Union[Iterable, Mapping]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(conn_id=conn_id, database=database, **kwargs)
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+
+    def execute(self, context: 'Context'):
+        self.log.info(
+            "Executing: %s (with parameters %s) with connection: %s",
+            self.sql,
+            self.parameters,
+            self.conn_id,
+        )
+        record = self.get_db_hook().get_first(self.sql, self.parameters)
+        if not record:
+            raise AirflowException(
+                "No rows returned from sql query. Operator expected True or False return value."
+            )
+
+        if isinstance(record, list):  # accept both [value] and [[value]] shapes
+            if isinstance(record[0], list):
+                query_result = record[0][0]
+            else:
+                query_result = record[0]
+        elif isinstance(record, tuple):  # (value, ...) -> first column
+            query_result = record[0]
+        else:  # anything else is treated as the scalar result itself
+            query_result = record
+
+        self.log.info("Query returns %s, type '%s'", query_result, type(query_result))
+
+        follow_branch = None
+        try:
+            if isinstance(query_result, bool):  # tested before int because bool is a subclass of int
+                if query_result:
+                    follow_branch = self.follow_task_ids_if_true
+            elif isinstance(query_result, str):
+                # return result is not Boolean, try to convert from String to Boolean
+                if parse_boolean(query_result):  # presumably raises ValueError on unrecognised strings (caught below)
+                    follow_branch = self.follow_task_ids_if_true
+            elif isinstance(query_result, int):
+                if bool(query_result):
+                    follow_branch = self.follow_task_ids_if_true
+            else:
+                raise AirflowException(
+                    f"Unexpected query return result '{query_result}' type '{type(query_result)}'"
+                )
+
+            if follow_branch is None:  # every recognised non-true result goes to the false branch
+                follow_branch = self.follow_task_ids_if_false
+        except ValueError:
+            raise AirflowException(
+                f"Unexpected query return result '{query_result}' type '{type(query_result)}'"
+            )
+
+        self.skip_all_except(context["ti"], follow_branch)  # SkipMixin helper; skips tasks outside follow_branch
+
+
+def _convert_to_float_if_possible(s):
+    """
+    A small helper function to convert a value (typically a string) to a float
+    if possible, returning the input unchanged otherwise
+
+    :param s: the value to be converted
+    """
+    try:
+        ret = float(s)
+    except (ValueError, TypeError):  # e.g. non-numeric text or None: fall back to the original value
+        ret = s
+    return ret
diff --git a/airflow/providers/common/sql/provider.yaml b/airflow/providers/common/sql/provider.yaml
index 39c8d483e4..30bc1258b3 100644
--- a/airflow/providers/common/sql/provider.yaml
+++ b/airflow/providers/common/sql/provider.yaml
@@ -22,7 +22,7 @@ description: |
     `Common SQL Provider <https://en.wikipedia.org/wiki/SQL>`__
 
 versions:
-  - 1.0.0
+  - 1.1.0
 
 dependencies:
   - sqlparse>=0.4.2
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 550c317406..827f82f5f6 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -32,7 +32,11 @@ from google.cloud.bigquery import DEFAULT_RETRY, CopyJob, ExtractJob, LoadJob, Q
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator, BaseOperatorLink
 from airflow.models.xcom import XCom
-from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator
+from airflow.providers.common.sql.operators.sql import (
+    SQLCheckOperator,
+    SQLIntervalCheckOperator,
+    SQLValueCheckOperator,
+)
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob
 from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url
 from airflow.providers.google.cloud.links.bigquery import BigQueryDatasetLink, BigQueryTableLink
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 1533cf4612..c3589140aa 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -53,7 +53,7 @@ versions:
 
 dependencies:
   - apache-airflow>=2.2.0
-  - apache-airflow-providers-common-sql
+  - apache-airflow-providers-common-sql>=1.1.0
   # Google has very clear rules on what dependencies should be used. All the limits below
   # follow strict guidelines of Google Libraries as quoted here:
   # While this issue is open, dependents of google-api-core, google-cloud-core. and google-auth
diff --git a/airflow/providers/google/suite/transfers/sql_to_sheets.py b/airflow/providers/google/suite/transfers/sql_to_sheets.py
index 8626fb7227..b448aafa17 100644
--- a/airflow/providers/google/suite/transfers/sql_to_sheets.py
+++ b/airflow/providers/google/suite/transfers/sql_to_sheets.py
@@ -22,7 +22,7 @@ import numbers
 from contextlib import closing
 from typing import Any, Iterable, Mapping, Optional, Sequence, Union
 
-from airflow.operators.sql import BaseSQLOperator
+from airflow.providers.common.sql.operators.sql import BaseSQLOperator
 from airflow.providers.google.suite.hooks.sheets import GSheetsHook
 
 
diff --git a/airflow/providers/qubole/operators/qubole_check.py b/airflow/providers/qubole/operators/qubole_check.py
index e63ff308b3..950c0d62fc 100644
--- a/airflow/providers/qubole/operators/qubole_check.py
+++ b/airflow/providers/qubole/operators/qubole_check.py
@@ -19,7 +19,7 @@
 from typing import Callable, Optional, Sequence, Union
 
 from airflow.exceptions import AirflowException
-from airflow.operators.sql import SQLCheckOperator, SQLValueCheckOperator
+from airflow.providers.common.sql.operators.sql import SQLCheckOperator, SQLValueCheckOperator
 from airflow.providers.qubole.hooks.qubole_check import QuboleCheckHook
 from airflow.providers.qubole.operators.qubole import QuboleOperator
 
diff --git a/airflow/providers/qubole/provider.yaml b/airflow/providers/qubole/provider.yaml
index 9b854da308..b45b2e1d3a 100644
--- a/airflow/providers/qubole/provider.yaml
+++ b/airflow/providers/qubole/provider.yaml
@@ -36,7 +36,7 @@ versions:
 
 dependencies:
   - apache-airflow>=2.2.0
-  - apache-airflow-providers-common-sql
+  - apache-airflow-providers-common-sql>=1.1.0
   - qds-sdk>=1.10.4
 
 integrations:
diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py
index dd996cc526..697cc9a344 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -18,8 +18,12 @@
 from typing import Any, Iterable, List, Mapping, Optional, Sequence, SupportsAbs, Union
 
 from airflow.models import BaseOperator
-from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator
 from airflow.providers.common.sql.hooks.sql import fetch_all_handler
+from airflow.providers.common.sql.operators.sql import (
+    SQLCheckOperator,
+    SQLIntervalCheckOperator,
+    SQLValueCheckOperator,
+)
 from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
 
 
diff --git a/airflow/providers/snowflake/provider.yaml b/airflow/providers/snowflake/provider.yaml
index 6ce66b1c66..cbf8fed30c 100644
--- a/airflow/providers/snowflake/provider.yaml
+++ b/airflow/providers/snowflake/provider.yaml
@@ -44,7 +44,7 @@ versions:
 
 dependencies:
   - apache-airflow>=2.2.0
-  - apache-airflow-providers-common-sql
+  - apache-airflow-providers-common-sql>=1.1.0
   - snowflake-connector-python>=2.4.1
   - snowflake-sqlalchemy>=1.1.0
 
diff --git a/airflow/sensors/sql_sensor.py b/airflow/sensors/sql_sensor.py
index fafc5335a2..6f7b1e46c0 100644
--- a/airflow/sensors/sql_sensor.py
+++ b/airflow/sensors/sql_sensor.py
@@ -15,7 +15,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""This module is deprecated. Please use :mod:`airflow.sensors.sql`."""
+"""This module is deprecated. Please use :mod:`airflow.providers.common.sql.sensors.sql`."""
 
 import warnings
 
diff --git a/docs/apache-airflow/operators-and-hooks-ref.rst b/docs/apache-airflow/operators-and-hooks-ref.rst
index c08a5deb00..7b5e148c5f 100644
--- a/docs/apache-airflow/operators-and-hooks-ref.rst
+++ b/docs/apache-airflow/operators-and-hooks-ref.rst
@@ -77,9 +77,6 @@ For details see: :doc:`apache-airflow-providers:operators-and-hooks-ref/index`.
    * - :mod:`airflow.operators.python`
      - :doc:`How to use <howto/operator/python>`
 
-   * - :mod:`airflow.operators.sql`
-     -
-
    * - :mod:`airflow.operators.subdag`
      -
 
@@ -112,9 +109,6 @@ For details see: :doc:`apache-airflow-providers:operators-and-hooks-ref/index`.
    * - :mod:`airflow.sensors.smart_sensor`
      - :doc:`concepts/smart-sensors`
 
-   * - :mod:`airflow.sensors.sql`
-     -
-
    * - :mod:`airflow.sensors.time_delta`
      -
 
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 51e3ea26f8..bf1b18a598 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -70,7 +70,7 @@
   },
   "apache.druid": {
     "deps": [
-      "apache-airflow-providers-common-sql",
+      "apache-airflow-providers-common-sql>=1.1.0",
       "apache-airflow>=2.2.0",
       "pydruid>=0.4.1"
     ],
@@ -292,7 +292,7 @@
   "google": {
     "deps": [
       "PyOpenSSL",
-      "apache-airflow-providers-common-sql",
+      "apache-airflow-providers-common-sql>=1.1.0",
       "apache-airflow>=2.2.0",
       "google-ads>=15.1.1",
       "google-api-core>=2.7.0,<3.0.0",
@@ -576,7 +576,7 @@
   },
   "qubole": {
     "deps": [
-      "apache-airflow-providers-common-sql",
+      "apache-airflow-providers-common-sql>=1.1.0",
       "apache-airflow>=2.2.0",
       "qds-sdk>=1.10.4"
     ],
@@ -650,7 +650,7 @@
   },
   "snowflake": {
     "deps": [
-      "apache-airflow-providers-common-sql",
+      "apache-airflow-providers-common-sql>=1.1.0",
       "apache-airflow>=2.2.0",
       "snowflake-connector-python>=2.4.1",
       "snowflake-sqlalchemy>=1.1.0"
diff --git a/tests/deprecated_classes.py b/tests/deprecated_classes.py
index 5b76af905d..36ac69a2b9 100644
--- a/tests/deprecated_classes.py
+++ b/tests/deprecated_classes.py
@@ -1028,7 +1028,7 @@ OPERATORS = [
         'airflow.contrib.operators.sqoop_operator.SqoopOperator',
     ),
     (
-        'airflow.operators.sql.SQLCheckOperator',
+        'airflow.providers.common.sql.operators.sql.SQLCheckOperator',
         'airflow.operators.druid_check_operator.DruidCheckOperator',
     ),
     (
@@ -1232,35 +1232,35 @@ OPERATORS = [
         'airflow.operators.papermill_operator.PapermillOperator',
     ),
     (
-        'airflow.operators.sql.SQLCheckOperator',
+        'airflow.providers.common.sql.operators.sql.SQLCheckOperator',
         'airflow.operators.presto_check_operator.PrestoCheckOperator',
     ),
     (
-        'airflow.operators.sql.SQLIntervalCheckOperator',
+        'airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator',
         'airflow.operators.presto_check_operator.PrestoIntervalCheckOperator',
     ),
     (
-        'airflow.operators.sql.SQLValueCheckOperator',
+        'airflow.providers.common.sql.operators.sql.SQLValueCheckOperator',
         'airflow.operators.presto_check_operator.PrestoValueCheckOperator',
     ),
     (
-        'airflow.operators.sql.SQLCheckOperator',
+        'airflow.providers.common.sql.operators.sql.SQLCheckOperator',
         'airflow.operators.check_operator.CheckOperator',
     ),
     (
-        'airflow.operators.sql.SQLIntervalCheckOperator',
+        'airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator',
         'airflow.operators.check_operator.IntervalCheckOperator',
     ),
     (
-        'airflow.operators.sql.SQLValueCheckOperator',
+        'airflow.providers.common.sql.operators.sql.SQLValueCheckOperator',
         'airflow.operators.check_operator.ValueCheckOperator',
     ),
     (
-        'airflow.operators.sql.SQLThresholdCheckOperator',
+        'airflow.providers.common.sql.operators.sql.SQLThresholdCheckOperator',
         'airflow.operators.check_operator.ThresholdCheckOperator',
     ),
     (
-        'airflow.operators.sql.BranchSQLOperator',
+        'airflow.providers.common.sql.operators.sql.BranchSQLOperator',
         'airflow.operators.sql_branch_operator.BranchSqlOperator',
     ),
     (
diff --git a/tests/operators/test_sql.py b/tests/operators/test_sql.py
deleted file mode 100644
index 43d202819a..0000000000
--- a/tests/operators/test_sql.py
+++ /dev/null
@@ -1,779 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import datetime
-import unittest
-from unittest import mock
-
-import pytest
-
-from airflow.exceptions import AirflowException
-from airflow.models import DAG, Connection, DagRun, TaskInstance as TI, XCom
-from airflow.operators.empty import EmptyOperator
-from airflow.operators.sql import (
-    BranchSQLOperator,
-    SQLCheckOperator,
-    SQLIntervalCheckOperator,
-    SQLThresholdCheckOperator,
-    SQLValueCheckOperator,
-)
-from airflow.providers.postgres.hooks.postgres import PostgresHook
-from airflow.utils import timezone
-from airflow.utils.session import create_session
-from airflow.utils.state import State
-from tests.providers.apache.hive import TestHiveEnvironment
-
-DEFAULT_DATE = timezone.datetime(2016, 1, 1)
-INTERVAL = datetime.timedelta(hours=12)
-
-SUPPORTED_TRUE_VALUES = [
-    ["True"],
-    ["true"],
-    ["1"],
-    ["on"],
-    [1],
-    True,
-    "true",
-    "1",
-    "on",
-    1,
-]
-SUPPORTED_FALSE_VALUES = [
-    ["False"],
-    ["false"],
-    ["0"],
-    ["off"],
-    [0],
-    False,
-    "false",
-    "0",
-    "off",
-    0,
-]
-
-
-@mock.patch(
-    'airflow.operators.sql.BaseHook.get_connection',
-    return_value=Connection(conn_id='sql_default', conn_type='postgres'),
-)
-class TestSQLCheckOperatorDbHook:
-    def setup_method(self):
-        self.task_id = "test_task"
-        self.conn_id = "sql_default"
-        self._operator = SQLCheckOperator(task_id=self.task_id, conn_id=self.conn_id, sql="sql")
-
-    @pytest.mark.parametrize('database', [None, 'test-db'])
-    def test_get_hook(self, mock_get_conn, database):
-        if database:
-            self._operator.database = database
-        assert isinstance(self._operator._hook, PostgresHook)
-        assert self._operator._hook.schema == database
-        mock_get_conn.assert_called_once_with(self.conn_id)
-
-    def test_not_allowed_conn_type(self, mock_get_conn):
-        mock_get_conn.return_value = Connection(conn_id='sql_default', conn_type='s3')
-        with pytest.raises(AirflowException, match=r"The connection type is not supported"):
-            self._operator._hook
-
-    def test_sql_operator_hook_params_snowflake(self, mock_get_conn):
-        mock_get_conn.return_value = Connection(conn_id='snowflake_default', conn_type='snowflake')
-        self._operator.hook_params = {
-            'warehouse': 'warehouse',
-            'database': 'database',
-            'role': 'role',
-            'schema': 'schema',
-            'log_sql': False,
-        }
-        assert self._operator._hook.conn_type == 'snowflake'
-        assert self._operator._hook.warehouse == 'warehouse'
-        assert self._operator._hook.database == 'database'
-        assert self._operator._hook.role == 'role'
-        assert self._operator._hook.schema == 'schema'
-        assert not self._operator._hook.log_sql
-
-    def test_sql_operator_hook_params_biguery(self, mock_get_conn):
-        mock_get_conn.return_value = Connection(
-            conn_id='google_cloud_bigquery_default', conn_type='gcpbigquery'
-        )
-        self._operator.hook_params = {'use_legacy_sql': True, 'location': 'us-east1'}
-        assert self._operator._hook.conn_type == 'gcpbigquery'
-        assert self._operator._hook.use_legacy_sql
-        assert self._operator._hook.location == 'us-east1'
-
-
-class TestCheckOperator(unittest.TestCase):
-    def setUp(self):
-        self._operator = SQLCheckOperator(task_id="test_task", sql="sql")
-
-    @mock.patch.object(SQLCheckOperator, "get_db_hook")
-    def test_execute_no_records(self, mock_get_db_hook):
-        mock_get_db_hook.return_value.get_first.return_value = []
-
-        with pytest.raises(AirflowException, match=r"The query returned None"):
-            self._operator.execute({})
-
-    @mock.patch.object(SQLCheckOperator, "get_db_hook")
-    def test_execute_not_all_records_are_true(self, mock_get_db_hook):
-        mock_get_db_hook.return_value.get_first.return_value = ["data", ""]
-
-        with pytest.raises(AirflowException, match=r"Test failed."):
-            self._operator.execute({})
-
-
-class TestValueCheckOperator(unittest.TestCase):
-    def setUp(self):
-        self.task_id = "test_task"
-        self.conn_id = "default_conn"
-
-    def _construct_operator(self, sql, pass_value, tolerance=None):
-        dag = DAG("test_dag", start_date=datetime.datetime(2017, 1, 1))
-
-        return SQLValueCheckOperator(
-            dag=dag,
-            task_id=self.task_id,
-            conn_id=self.conn_id,
-            sql=sql,
-            pass_value=pass_value,
-            tolerance=tolerance,
-        )
-
-    def test_pass_value_template_string(self):
-        pass_value_str = "2018-03-22"
-        operator = self._construct_operator("select date from tab1;", "{{ ds }}")
-
-        operator.render_template_fields({"ds": pass_value_str})
-
-        assert operator.task_id == self.task_id
-        assert operator.pass_value == pass_value_str
-
-    def test_pass_value_template_string_float(self):
-        pass_value_float = 4.0
-        operator = self._construct_operator("select date from tab1;", pass_value_float)
-
-        operator.render_template_fields({})
-
-        assert operator.task_id == self.task_id
-        assert operator.pass_value == str(pass_value_float)
-
-    @mock.patch.object(SQLValueCheckOperator, "get_db_hook")
-    def test_execute_pass(self, mock_get_db_hook):
-        mock_hook = mock.Mock()
-        mock_hook.get_first.return_value = [10]
-        mock_get_db_hook.return_value = mock_hook
-        sql = "select value from tab1 limit 1;"
-        operator = self._construct_operator(sql, 5, 1)
-
-        operator.execute(None)
-
-        mock_hook.get_first.assert_called_once_with(sql)
-
-    @mock.patch.object(SQLValueCheckOperator, "get_db_hook")
-    def test_execute_fail(self, mock_get_db_hook):
-        mock_hook = mock.Mock()
-        mock_hook.get_first.return_value = [11]
-        mock_get_db_hook.return_value = mock_hook
-
-        operator = self._construct_operator("select value from tab1 limit 1;", 5, 1)
-
-        with pytest.raises(AirflowException, match="Tolerance:100.0%"):
-            operator.execute()
-
-
-class TestIntervalCheckOperator(unittest.TestCase):
-    def _construct_operator(self, table, metric_thresholds, ratio_formula, ignore_zero):
-        return SQLIntervalCheckOperator(
-            task_id="test_task",
-            table=table,
-            metrics_thresholds=metric_thresholds,
-            ratio_formula=ratio_formula,
-            ignore_zero=ignore_zero,
-        )
-
-    def test_invalid_ratio_formula(self):
-        with pytest.raises(AirflowException, match="Invalid diff_method"):
-            self._construct_operator(
-                table="test_table",
-                metric_thresholds={
-                    "f1": 1,
-                },
-                ratio_formula="abs",
-                ignore_zero=False,
-            )
-
-    @mock.patch.object(SQLIntervalCheckOperator, "get_db_hook")
-    def test_execute_not_ignore_zero(self, mock_get_db_hook):
-        mock_hook = mock.Mock()
-        mock_hook.get_first.return_value = [0]
-        mock_get_db_hook.return_value = mock_hook
-
-        operator = self._construct_operator(
-            table="test_table",
-            metric_thresholds={
-                "f1": 1,
-            },
-            ratio_formula="max_over_min",
-            ignore_zero=False,
-        )
-
-        with pytest.raises(AirflowException):
-            operator.execute()
-
-    @mock.patch.object(SQLIntervalCheckOperator, "get_db_hook")
-    def test_execute_ignore_zero(self, mock_get_db_hook):
-        mock_hook = mock.Mock()
-        mock_hook.get_first.return_value = [0]
-        mock_get_db_hook.return_value = mock_hook
-
-        operator = self._construct_operator(
-            table="test_table",
-            metric_thresholds={
-                "f1": 1,
-            },
-            ratio_formula="max_over_min",
-            ignore_zero=True,
-        )
-
-        operator.execute()
-
-    @mock.patch.object(SQLIntervalCheckOperator, "get_db_hook")
-    def test_execute_min_max(self, mock_get_db_hook):
-        mock_hook = mock.Mock()
-
-        def returned_row():
-            rows = [
-                [2, 2, 2, 2],  # reference
-                [1, 1, 1, 1],  # current
-            ]
-
-            yield from rows
-
-        mock_hook.get_first.side_effect = returned_row()
-        mock_get_db_hook.return_value = mock_hook
-
-        operator = self._construct_operator(
-            table="test_table",
-            metric_thresholds={
-                "f0": 1.0,
-                "f1": 1.5,
-                "f2": 2.0,
-                "f3": 2.5,
-            },
-            ratio_formula="max_over_min",
-            ignore_zero=True,
-        )
-
-        with pytest.raises(AirflowException, match="f0, f1, f2"):
-            operator.execute()
-
-    @mock.patch.object(SQLIntervalCheckOperator, "get_db_hook")
-    def test_execute_diff(self, mock_get_db_hook):
-        mock_hook = mock.Mock()
-
-        def returned_row():
-            rows = [
-                [3, 3, 3, 3],  # reference
-                [1, 1, 1, 1],  # current
-            ]
-
-            yield from rows
-
-        mock_hook.get_first.side_effect = returned_row()
-        mock_get_db_hook.return_value = mock_hook
-
-        operator = self._construct_operator(
-            table="test_table",
-            metric_thresholds={
-                "f0": 0.5,
-                "f1": 0.6,
-                "f2": 0.7,
-                "f3": 0.8,
-            },
-            ratio_formula="relative_diff",
-            ignore_zero=True,
-        )
-
-        with pytest.raises(AirflowException, match="f0, f1"):
-            operator.execute()
-
-
-class TestThresholdCheckOperator(unittest.TestCase):
-    def _construct_operator(self, sql, min_threshold, max_threshold):
-        dag = DAG("test_dag", start_date=datetime.datetime(2017, 1, 1))
-
-        return SQLThresholdCheckOperator(
-            task_id="test_task",
-            sql=sql,
-            min_threshold=min_threshold,
-            max_threshold=max_threshold,
-            dag=dag,
-        )
-
-    @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
-    def test_pass_min_value_max_value(self, mock_get_db_hook):
-        mock_hook = mock.Mock()
-        mock_hook.get_first.return_value = (10,)
-        mock_get_db_hook.return_value = mock_hook
-
-        operator = self._construct_operator("Select avg(val) from table1 limit 1", 1, 100)
-
-        operator.execute()
-
-    @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
-    def test_fail_min_value_max_value(self, mock_get_db_hook):
-        mock_hook = mock.Mock()
-        mock_hook.get_first.return_value = (10,)
-        mock_get_db_hook.return_value = mock_hook
-
-        operator = self._construct_operator("Select avg(val) from table1 limit 1", 20, 100)
-
-        with pytest.raises(AirflowException, match="10.*20.0.*100.0"):
-            operator.execute()
-
-    @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
-    def test_pass_min_sql_max_sql(self, mock_get_db_hook):
-        mock_hook = mock.Mock()
-        mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),)
-        mock_get_db_hook.return_value = mock_hook
-
-        operator = self._construct_operator("Select 10", "Select 1", "Select 100")
-
-        operator.execute()
-
-    @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
-    def test_fail_min_sql_max_sql(self, mock_get_db_hook):
-        mock_hook = mock.Mock()
-        mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),)
-        mock_get_db_hook.return_value = mock_hook
-
-        operator = self._construct_operator("Select 10", "Select 20", "Select 100")
-
-        with pytest.raises(AirflowException, match="10.*20.*100"):
-            operator.execute()
-
-    @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
-    def test_pass_min_value_max_sql(self, mock_get_db_hook):
-        mock_hook = mock.Mock()
-        mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),)
-        mock_get_db_hook.return_value = mock_hook
-
-        operator = self._construct_operator("Select 75", 45, "Select 100")
-
-        operator.execute()
-
-    @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
-    def test_fail_min_sql_max_value(self, mock_get_db_hook):
-        mock_hook = mock.Mock()
-        mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),)
-        mock_get_db_hook.return_value = mock_hook
-
-        operator = self._construct_operator("Select 155", "Select 45", 100)
-
-        with pytest.raises(AirflowException, match="155.*45.*100.0"):
-            operator.execute()
-
-
-class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
-    """
-    Test for SQL Branch Operator
-    """
-
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-
-        with create_session() as session:
-            session.query(DagRun).delete()
-            session.query(TI).delete()
-            session.query(XCom).delete()
-
-    def setUp(self):
-        super().setUp()
-        self.dag = DAG(
-            "sql_branch_operator_test",
-            default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
-            schedule_interval=INTERVAL,
-        )
-        self.branch_1 = EmptyOperator(task_id="branch_1", dag=self.dag)
-        self.branch_2 = EmptyOperator(task_id="branch_2", dag=self.dag)
-        self.branch_3 = None
-
-    def tearDown(self):
-        super().tearDown()
-
-        with create_session() as session:
-            session.query(DagRun).delete()
-            session.query(TI).delete()
-            session.query(XCom).delete()
-
-    def test_unsupported_conn_type(self):
-        """Check if BranchSQLOperator throws an exception for unsupported connection type"""
-        op = BranchSQLOperator(
-            task_id="make_choice",
-            conn_id="redis_default",
-            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
-            follow_task_ids_if_true="branch_1",
-            follow_task_ids_if_false="branch_2",
-            dag=self.dag,
-        )
-
-        with pytest.raises(AirflowException):
-            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
-
-    def test_invalid_conn(self):
-        """Check if BranchSQLOperator throws an exception for invalid connection"""
-        op = BranchSQLOperator(
-            task_id="make_choice",
-            conn_id="invalid_connection",
-            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
-            follow_task_ids_if_true="branch_1",
-            follow_task_ids_if_false="branch_2",
-            dag=self.dag,
-        )
-
-        with pytest.raises(AirflowException):
-            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
-
-    def test_invalid_follow_task_true(self):
-        """Check if BranchSQLOperator throws an exception for invalid connection"""
-        op = BranchSQLOperator(
-            task_id="make_choice",
-            conn_id="invalid_connection",
-            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
-            follow_task_ids_if_true=None,
-            follow_task_ids_if_false="branch_2",
-            dag=self.dag,
-        )
-
-        with pytest.raises(AirflowException):
-            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
-
-    def test_invalid_follow_task_false(self):
-        """Check if BranchSQLOperator throws an exception for invalid connection"""
-        op = BranchSQLOperator(
-            task_id="make_choice",
-            conn_id="invalid_connection",
-            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
-            follow_task_ids_if_true="branch_1",
-            follow_task_ids_if_false=None,
-            dag=self.dag,
-        )
-
-        with pytest.raises(AirflowException):
-            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
-
-    @pytest.mark.backend("mysql")
-    def test_sql_branch_operator_mysql(self):
-        """Check if BranchSQLOperator works with backend"""
-        branch_op = BranchSQLOperator(
-            task_id="make_choice",
-            conn_id="mysql_default",
-            sql="SELECT 1",
-            follow_task_ids_if_true="branch_1",
-            follow_task_ids_if_false="branch_2",
-            dag=self.dag,
-        )
-        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
-
-    @pytest.mark.backend("postgres")
-    def test_sql_branch_operator_postgres(self):
-        """Check if BranchSQLOperator works with backend"""
-        branch_op = BranchSQLOperator(
-            task_id="make_choice",
-            conn_id="postgres_default",
-            sql="SELECT 1",
-            follow_task_ids_if_true="branch_1",
-            follow_task_ids_if_false="branch_2",
-            dag=self.dag,
-        )
-        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
-
-    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
-    def test_branch_single_value_with_dag_run(self, mock_get_db_hook):
-        """Check BranchSQLOperator branch operation"""
-        branch_op = BranchSQLOperator(
-            task_id="make_choice",
-            conn_id="mysql_default",
-            sql="SELECT 1",
-            follow_task_ids_if_true="branch_1",
-            follow_task_ids_if_false="branch_2",
-            dag=self.dag,
-        )
-
-        self.branch_1.set_upstream(branch_op)
-        self.branch_2.set_upstream(branch_op)
-        self.dag.clear()
-
-        dr = self.dag.create_dagrun(
-            run_id="manual__",
-            start_date=timezone.utcnow(),
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-
-        mock_get_records = mock_get_db_hook.return_value.get_first
-
-        mock_get_records.return_value = 1
-
-        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-        tis = dr.get_task_instances()
-        for ti in tis:
-            if ti.task_id == "make_choice":
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == "branch_1":
-                assert ti.state == State.NONE
-            elif ti.task_id == "branch_2":
-                assert ti.state == State.SKIPPED
-            else:
-                raise ValueError(f"Invalid task id {ti.task_id} found!")
-
-    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
-    def test_branch_true_with_dag_run(self, mock_get_db_hook):
-        """Check BranchSQLOperator branch operation"""
-        branch_op = BranchSQLOperator(
-            task_id="make_choice",
-            conn_id="mysql_default",
-            sql="SELECT 1",
-            follow_task_ids_if_true="branch_1",
-            follow_task_ids_if_false="branch_2",
-            dag=self.dag,
-        )
-
-        self.branch_1.set_upstream(branch_op)
-        self.branch_2.set_upstream(branch_op)
-        self.dag.clear()
-
-        dr = self.dag.create_dagrun(
-            run_id="manual__",
-            start_date=timezone.utcnow(),
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-
-        mock_get_records = mock_get_db_hook.return_value.get_first
-
-        for true_value in SUPPORTED_TRUE_VALUES:
-            mock_get_records.return_value = true_value
-
-            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-            tis = dr.get_task_instances()
-            for ti in tis:
-                if ti.task_id == "make_choice":
-                    assert ti.state == State.SUCCESS
-                elif ti.task_id == "branch_1":
-                    assert ti.state == State.NONE
-                elif ti.task_id == "branch_2":
-                    assert ti.state == State.SKIPPED
-                else:
-                    raise ValueError(f"Invalid task id {ti.task_id} found!")
-
-    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
-    def test_branch_false_with_dag_run(self, mock_get_db_hook):
-        """Check BranchSQLOperator branch operation"""
-        branch_op = BranchSQLOperator(
-            task_id="make_choice",
-            conn_id="mysql_default",
-            sql="SELECT 1",
-            follow_task_ids_if_true="branch_1",
-            follow_task_ids_if_false="branch_2",
-            dag=self.dag,
-        )
-
-        self.branch_1.set_upstream(branch_op)
-        self.branch_2.set_upstream(branch_op)
-        self.dag.clear()
-
-        dr = self.dag.create_dagrun(
-            run_id="manual__",
-            start_date=timezone.utcnow(),
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-
-        mock_get_records = mock_get_db_hook.return_value.get_first
-
-        for false_value in SUPPORTED_FALSE_VALUES:
-            mock_get_records.return_value = false_value
-            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-            tis = dr.get_task_instances()
-            for ti in tis:
-                if ti.task_id == "make_choice":
-                    assert ti.state == State.SUCCESS
-                elif ti.task_id == "branch_1":
-                    assert ti.state == State.SKIPPED
-                elif ti.task_id == "branch_2":
-                    assert ti.state == State.NONE
-                else:
-                    raise ValueError(f"Invalid task id {ti.task_id} found!")
-
-    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
-    def test_branch_list_with_dag_run(self, mock_get_db_hook):
-        """Checks if the BranchSQLOperator supports branching off to a list of tasks."""
-        branch_op = BranchSQLOperator(
-            task_id="make_choice",
-            conn_id="mysql_default",
-            sql="SELECT 1",
-            follow_task_ids_if_true=["branch_1", "branch_2"],
-            follow_task_ids_if_false="branch_3",
-            dag=self.dag,
-        )
-
-        self.branch_1.set_upstream(branch_op)
-        self.branch_2.set_upstream(branch_op)
-        self.branch_3 = EmptyOperator(task_id="branch_3", dag=self.dag)
-        self.branch_3.set_upstream(branch_op)
-        self.dag.clear()
-
-        dr = self.dag.create_dagrun(
-            run_id="manual__",
-            start_date=timezone.utcnow(),
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-
-        mock_get_records = mock_get_db_hook.return_value.get_first
-        mock_get_records.return_value = [["1"]]
-
-        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-        tis = dr.get_task_instances()
-        for ti in tis:
-            if ti.task_id == "make_choice":
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == "branch_1":
-                assert ti.state == State.NONE
-            elif ti.task_id == "branch_2":
-                assert ti.state == State.NONE
-            elif ti.task_id == "branch_3":
-                assert ti.state == State.SKIPPED
-            else:
-                raise ValueError(f"Invalid task id {ti.task_id} found!")
-
-    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
-    def test_invalid_query_result_with_dag_run(self, mock_get_db_hook):
-        """Check BranchSQLOperator branch operation"""
-        branch_op = BranchSQLOperator(
-            task_id="make_choice",
-            conn_id="mysql_default",
-            sql="SELECT 1",
-            follow_task_ids_if_true="branch_1",
-            follow_task_ids_if_false="branch_2",
-            dag=self.dag,
-        )
-
-        self.branch_1.set_upstream(branch_op)
-        self.branch_2.set_upstream(branch_op)
-        self.dag.clear()
-
-        self.dag.create_dagrun(
-            run_id="manual__",
-            start_date=timezone.utcnow(),
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-
-        mock_get_records = mock_get_db_hook.return_value.get_first
-
-        mock_get_records.return_value = ["Invalid Value"]
-
-        with pytest.raises(AirflowException):
-            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
-    def test_with_skip_in_branch_downstream_dependencies(self, mock_get_db_hook):
-        """Test SQL Branch with skipping all downstream dependencies"""
-        branch_op = BranchSQLOperator(
-            task_id="make_choice",
-            conn_id="mysql_default",
-            sql="SELECT 1",
-            follow_task_ids_if_true="branch_1",
-            follow_task_ids_if_false="branch_2",
-            dag=self.dag,
-        )
-
-        branch_op >> self.branch_1 >> self.branch_2
-        branch_op >> self.branch_2
-        self.dag.clear()
-
-        dr = self.dag.create_dagrun(
-            run_id="manual__",
-            start_date=timezone.utcnow(),
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-
-        mock_get_records = mock_get_db_hook.return_value.get_first
-
-        for true_value in SUPPORTED_TRUE_VALUES:
-            mock_get_records.return_value = [true_value]
-
-            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-            tis = dr.get_task_instances()
-            for ti in tis:
-                if ti.task_id == "make_choice":
-                    assert ti.state == State.SUCCESS
-                elif ti.task_id == "branch_1":
-                    assert ti.state == State.NONE
-                elif ti.task_id == "branch_2":
-                    assert ti.state == State.NONE
-                else:
-                    raise ValueError(f"Invalid task id {ti.task_id} found!")
-
-    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
-    def test_with_skip_in_branch_downstream_dependencies2(self, mock_get_db_hook):
-        """Test skipping downstream dependency for false condition"""
-        branch_op = BranchSQLOperator(
-            task_id="make_choice",
-            conn_id="mysql_default",
-            sql="SELECT 1",
-            follow_task_ids_if_true="branch_1",
-            follow_task_ids_if_false="branch_2",
-            dag=self.dag,
-        )
-
-        branch_op >> self.branch_1 >> self.branch_2
-        branch_op >> self.branch_2
-        self.dag.clear()
-
-        dr = self.dag.create_dagrun(
-            run_id="manual__",
-            start_date=timezone.utcnow(),
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-
-        mock_get_records = mock_get_db_hook.return_value.get_first
-
-        for false_value in SUPPORTED_FALSE_VALUES:
-            mock_get_records.return_value = [false_value]
-
-            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-            tis = dr.get_task_instances()
-            for ti in tis:
-                if ti.task_id == "make_choice":
-                    assert ti.state == State.SUCCESS
-                elif ti.task_id == "branch_1":
-                    assert ti.state == State.SKIPPED
-                elif ti.task_id == "branch_2":
-                    assert ti.state == State.NONE
-                else:
-                    raise ValueError(f"Invalid task id {ti.task_id} found!")
diff --git a/tests/providers/common/sql/operators/test_sql.py b/tests/providers/common/sql/operators/test_sql.py
index da53e6134f..c6dc3d0757 100644
--- a/tests/providers/common/sql/operators/test_sql.py
+++ b/tests/providers/common/sql/operators/test_sql.py
@@ -15,12 +15,32 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import datetime
+import unittest
+from unittest import mock
+from unittest.mock import MagicMock
 
 import pandas as pd
 import pytest
 
+from airflow import DAG
 from airflow.exceptions import AirflowException
-from airflow.providers.common.sql.operators.sql import SQLColumnCheckOperator, SQLTableCheckOperator
+from airflow.models import Connection, DagRun, TaskInstance as TI, XCom
+from airflow.operators.empty import EmptyOperator
+from airflow.providers.common.sql.operators.sql import (
+    BranchSQLOperator,
+    SQLCheckOperator,
+    SQLColumnCheckOperator,
+    SQLIntervalCheckOperator,
+    SQLTableCheckOperator,
+    SQLThresholdCheckOperator,
+    SQLValueCheckOperator,
+)
+from airflow.providers.postgres.hooks.postgres import PostgresHook
+from airflow.utils import timezone
+from airflow.utils.session import create_session
+from airflow.utils.state import State
+from tests.providers.apache.hive import TestHiveEnvironment
 
 
 class MockHook:
@@ -66,30 +86,30 @@ class TestColumnCheckOperator:
 
     def test_pass_all_checks_exact_check(self, monkeypatch):
         operator = self._construct_operator(monkeypatch, self.valid_column_mapping, (0, 10, 10, 1, 19))
-        operator.execute()
+        operator.execute(context=MagicMock())
 
     def test_max_less_than_fails_check(self, monkeypatch):
         with pytest.raises(AirflowException):
             operator = self._construct_operator(monkeypatch, self.valid_column_mapping, (0, 10, 10, 1, 21))
-            operator.execute()
+            operator.execute(context=MagicMock())
             assert operator.column_mapping["X"]["max"]["success"] is False
 
     def test_max_greater_than_fails_check(self, monkeypatch):
         with pytest.raises(AirflowException):
             operator = self._construct_operator(monkeypatch, self.valid_column_mapping, (0, 10, 10, 1, 9))
-            operator.execute()
+            operator.execute(context=MagicMock())
             assert operator.column_mapping["X"]["max"]["success"] is False
 
     def test_pass_all_checks_inexact_check(self, monkeypatch):
         operator = self._construct_operator(monkeypatch, self.valid_column_mapping, (0, 9, 12, 0, 15))
-        operator.execute()
+        operator.execute(context=MagicMock())
 
     def test_fail_all_checks_check(self, monkeypatch):
         operator = operator = self._construct_operator(
             monkeypatch, self.valid_column_mapping, (1, 12, 11, -1, 20)
         )
         with pytest.raises(AirflowException):
-            operator.execute()
+            operator.execute(context=MagicMock())
 
 
 class TestTableCheckOperator:
@@ -119,7 +139,7 @@ class TestTableCheckOperator:
             }
         )
         operator = self._construct_operator(monkeypatch, self.checks, df)
-        operator.execute()
+        operator.execute(context=MagicMock())
 
     def test_fail_all_checks_check(self, monkeypatch):
         df = pd.DataFrame(
@@ -127,4 +147,744 @@ class TestTableCheckOperator:
         )
         operator = self._construct_operator(monkeypatch, self.checks, df)
         with pytest.raises(AirflowException):
-            operator.execute()
+            operator.execute(context=MagicMock())
+
+
+DEFAULT_DATE = timezone.datetime(2016, 1, 1)
+INTERVAL = datetime.timedelta(hours=12)
+SUPPORTED_TRUE_VALUES = [
+    ["True"],
+    ["true"],
+    ["1"],
+    ["on"],
+    [1],
+    True,
+    "true",
+    "1",
+    "on",
+    1,
+]
+SUPPORTED_FALSE_VALUES = [
+    ["False"],
+    ["false"],
+    ["0"],
+    ["off"],
+    [0],
+    False,
+    "false",
+    "0",
+    "off",
+    0,
+]
+
+
+@mock.patch(
+    'airflow.providers.common.sql.operators.sql.BaseHook.get_connection',
+    return_value=Connection(conn_id='sql_default', conn_type='postgres'),
+)
+class TestSQLCheckOperatorDbHook:
+    def setup_method(self):
+        self.task_id = "test_task"
+        self.conn_id = "sql_default"
+        self._operator = SQLCheckOperator(task_id=self.task_id, conn_id=self.conn_id, sql="sql")
+
+    @pytest.mark.parametrize('database', [None, 'test-db'])
+    def test_get_hook(self, mock_get_conn, database):
+        if database:
+            self._operator.database = database
+        assert isinstance(self._operator._hook, PostgresHook)
+        assert self._operator._hook.schema == database
+        mock_get_conn.assert_called_once_with(self.conn_id)
+
+    def test_not_allowed_conn_type(self, mock_get_conn):
+        mock_get_conn.return_value = Connection(conn_id='sql_default', conn_type='s3')
+        with pytest.raises(AirflowException, match=r"The connection type is not supported"):
+            self._operator._hook
+
+    def test_sql_operator_hook_params_snowflake(self, mock_get_conn):
+        mock_get_conn.return_value = Connection(conn_id='snowflake_default', conn_type='snowflake')
+        self._operator.hook_params = {
+            'warehouse': 'warehouse',
+            'database': 'database',
+            'role': 'role',
+            'schema': 'schema',
+            'log_sql': False,
+        }
+        assert self._operator._hook.conn_type == 'snowflake'
+        assert self._operator._hook.warehouse == 'warehouse'
+        assert self._operator._hook.database == 'database'
+        assert self._operator._hook.role == 'role'
+        assert self._operator._hook.schema == 'schema'
+        assert not self._operator._hook.log_sql
+
+    def test_sql_operator_hook_params_bigquery(self, mock_get_conn):
+        mock_get_conn.return_value = Connection(
+            conn_id='google_cloud_bigquery_default', conn_type='gcpbigquery'
+        )
+        self._operator.hook_params = {'use_legacy_sql': True, 'location': 'us-east1'}
+        assert self._operator._hook.conn_type == 'gcpbigquery'
+        assert self._operator._hook.use_legacy_sql
+        assert self._operator._hook.location == 'us-east1'
+
+
+class TestCheckOperator(unittest.TestCase):
+    def setUp(self):
+        self._operator = SQLCheckOperator(task_id="test_task", sql="sql")
+
+    @mock.patch.object(SQLCheckOperator, "get_db_hook")
+    def test_execute_no_records(self, mock_get_db_hook):
+        mock_get_db_hook.return_value.get_first.return_value = []
+
+        with pytest.raises(AirflowException, match=r"The query returned None"):
+            self._operator.execute({})
+
+    @mock.patch.object(SQLCheckOperator, "get_db_hook")
+    def test_execute_not_all_records_are_true(self, mock_get_db_hook):
+        mock_get_db_hook.return_value.get_first.return_value = ["data", ""]
+
+        with pytest.raises(AirflowException, match=r"Test failed\."):
+            self._operator.execute({})
+
+
+class TestValueCheckOperator(unittest.TestCase):
+    def setUp(self):
+        self.task_id = "test_task"
+        self.conn_id = "default_conn"
+
+    def _construct_operator(self, sql, pass_value, tolerance=None):
+        dag = DAG("test_dag", start_date=datetime.datetime(2017, 1, 1))
+
+        return SQLValueCheckOperator(
+            dag=dag,
+            task_id=self.task_id,
+            conn_id=self.conn_id,
+            sql=sql,
+            pass_value=pass_value,
+            tolerance=tolerance,
+        )
+
+    def test_pass_value_template_string(self):
+        pass_value_str = "2018-03-22"
+        operator = self._construct_operator("select date from tab1;", "{{ ds }}")
+
+        operator.render_template_fields({"ds": pass_value_str})
+
+        assert operator.task_id == self.task_id
+        assert operator.pass_value == pass_value_str
+
+    def test_pass_value_template_string_float(self):
+        pass_value_float = 4.0
+        operator = self._construct_operator("select date from tab1;", pass_value_float)
+
+        operator.render_template_fields({})
+
+        assert operator.task_id == self.task_id
+        assert operator.pass_value == str(pass_value_float)
+
+    @mock.patch.object(SQLValueCheckOperator, "get_db_hook")
+    def test_execute_pass(self, mock_get_db_hook):
+        mock_hook = mock.Mock()
+        mock_hook.get_first.return_value = [10]
+        mock_get_db_hook.return_value = mock_hook
+        sql = "select value from tab1 limit 1;"
+        operator = self._construct_operator(sql, 5, 1)
+
+        operator.execute(context=MagicMock())
+
+        mock_hook.get_first.assert_called_once_with(sql)
+
+    @mock.patch.object(SQLValueCheckOperator, "get_db_hook")
+    def test_execute_fail(self, mock_get_db_hook):
+        mock_hook = mock.Mock()
+        mock_hook.get_first.return_value = [11]
+        mock_get_db_hook.return_value = mock_hook
+
+        operator = self._construct_operator("select value from tab1 limit 1;", 5, 1)
+
+        with pytest.raises(AirflowException, match=r"Tolerance:100\.0%"):
+            operator.execute(context=MagicMock())
+
+
+class TestIntervalCheckOperator(unittest.TestCase):
+    def _construct_operator(self, table, metric_thresholds, ratio_formula, ignore_zero):
+        return SQLIntervalCheckOperator(
+            task_id="test_task",
+            table=table,
+            metrics_thresholds=metric_thresholds,
+            ratio_formula=ratio_formula,
+            ignore_zero=ignore_zero,
+        )
+
+    def test_invalid_ratio_formula(self):
+        with pytest.raises(AirflowException, match="Invalid diff_method"):
+            self._construct_operator(
+                table="test_table",
+                metric_thresholds={
+                    "f1": 1,
+                },
+                ratio_formula="abs",
+                ignore_zero=False,
+            )
+
+    @mock.patch.object(SQLIntervalCheckOperator, "get_db_hook")
+    def test_execute_not_ignore_zero(self, mock_get_db_hook):
+        mock_hook = mock.Mock()
+        mock_hook.get_first.return_value = [0]
+        mock_get_db_hook.return_value = mock_hook
+
+        operator = self._construct_operator(
+            table="test_table",
+            metric_thresholds={
+                "f1": 1,
+            },
+            ratio_formula="max_over_min",
+            ignore_zero=False,
+        )
+
+        with pytest.raises(AirflowException, match="f1"):
+            operator.execute(context=MagicMock())
+
+    @mock.patch.object(SQLIntervalCheckOperator, "get_db_hook")
+    def test_execute_ignore_zero(self, mock_get_db_hook):
+        mock_hook = mock.Mock()
+        mock_hook.get_first.return_value = [0]
+        mock_get_db_hook.return_value = mock_hook
+
+        operator = self._construct_operator(
+            table="test_table",
+            metric_thresholds={
+                "f1": 1,
+            },
+            ratio_formula="max_over_min",
+            ignore_zero=True,
+        )
+
+        operator.execute(context=MagicMock())
+
+    @mock.patch.object(SQLIntervalCheckOperator, "get_db_hook")
+    def test_execute_min_max(self, mock_get_db_hook):
+        mock_hook = mock.Mock()
+
+        def returned_row():
+            rows = [
+                [2, 2, 2, 2],  # reference
+                [1, 1, 1, 1],  # current
+            ]
+
+            yield from rows
+
+        mock_hook.get_first.side_effect = returned_row()
+        mock_get_db_hook.return_value = mock_hook
+
+        operator = self._construct_operator(
+            table="test_table",
+            metric_thresholds={
+                "f0": 1.0,
+                "f1": 1.5,
+                "f2": 2.0,
+                "f3": 2.5,
+            },
+            ratio_formula="max_over_min",
+            ignore_zero=True,
+        )
+
+        with pytest.raises(AirflowException, match="f0, f1, f2"):
+            operator.execute(context=MagicMock())
+
+    @mock.patch.object(SQLIntervalCheckOperator, "get_db_hook")
+    def test_execute_diff(self, mock_get_db_hook):
+        mock_hook = mock.Mock()
+
+        def returned_row():
+            rows = [
+                [3, 3, 3, 3],  # reference
+                [1, 1, 1, 1],  # current
+            ]
+
+            yield from rows
+
+        mock_hook.get_first.side_effect = returned_row()
+        mock_get_db_hook.return_value = mock_hook
+
+        operator = self._construct_operator(
+            table="test_table",
+            metric_thresholds={
+                "f0": 0.5,
+                "f1": 0.6,
+                "f2": 0.7,
+                "f3": 0.8,
+            },
+            ratio_formula="relative_diff",
+            ignore_zero=True,
+        )
+
+        with pytest.raises(AirflowException, match="f0, f1"):
+            operator.execute(context=MagicMock())
+
+
+class TestThresholdCheckOperator(unittest.TestCase):
+    def _construct_operator(self, sql, min_threshold, max_threshold):
+        dag = DAG("test_dag", start_date=datetime.datetime(2017, 1, 1))
+
+        return SQLThresholdCheckOperator(
+            task_id="test_task",
+            sql=sql,
+            min_threshold=min_threshold,
+            max_threshold=max_threshold,
+            dag=dag,
+        )
+
+    @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
+    def test_pass_min_value_max_value(self, mock_get_db_hook):
+        mock_hook = mock.Mock()
+        mock_hook.get_first.return_value = (10,)
+        mock_get_db_hook.return_value = mock_hook
+
+        operator = self._construct_operator("Select avg(val) from table1 limit 1", 1, 100)
+
+        operator.execute(context=MagicMock())
+
+    @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
+    def test_fail_min_value_max_value(self, mock_get_db_hook):
+        mock_hook = mock.Mock()
+        mock_hook.get_first.return_value = (10,)
+        mock_get_db_hook.return_value = mock_hook
+
+        operator = self._construct_operator("Select avg(val) from table1 limit 1", 20, 100)
+
+        with pytest.raises(AirflowException, match=r"10.*20\.0.*100\.0"):
+            operator.execute(context=MagicMock())
+
+    @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
+    def test_pass_min_sql_max_sql(self, mock_get_db_hook):
+        mock_hook = mock.Mock()
+        mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),)
+        mock_get_db_hook.return_value = mock_hook
+
+        operator = self._construct_operator("Select 10", "Select 1", "Select 100")
+
+        operator.execute(context=MagicMock())
+
+    @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
+    def test_fail_min_sql_max_sql(self, mock_get_db_hook):
+        mock_hook = mock.Mock()
+        mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),)
+        mock_get_db_hook.return_value = mock_hook
+
+        operator = self._construct_operator("Select 10", "Select 20", "Select 100")
+
+        with pytest.raises(AirflowException, match="10.*20.*100"):
+            operator.execute(context=MagicMock())
+
+    @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
+    def test_pass_min_value_max_sql(self, mock_get_db_hook):
+        mock_hook = mock.Mock()
+        mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),)
+        mock_get_db_hook.return_value = mock_hook
+
+        operator = self._construct_operator("Select 75", 45, "Select 100")
+
+        operator.execute(context=MagicMock())
+
+    @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
+    def test_fail_min_sql_max_value(self, mock_get_db_hook):
+        mock_hook = mock.Mock()
+        mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),)
+        mock_get_db_hook.return_value = mock_hook
+
+        operator = self._construct_operator("Select 155", "Select 45", 100)
+
+        with pytest.raises(AirflowException, match=r"155.*45.*100\.0"):
+            operator.execute(context=MagicMock())
+
+
+class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
+    """
+    Tests for BranchSQLOperator from the common.sql provider
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+
+        with create_session() as session:
+            session.query(DagRun).delete()
+            session.query(TI).delete()
+            session.query(XCom).delete()
+
+    def setUp(self):
+        super().setUp()
+        self.dag = DAG(
+            "sql_branch_operator_test",
+            default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+            schedule_interval=INTERVAL,
+        )
+        self.branch_1 = EmptyOperator(task_id="branch_1", dag=self.dag)
+        self.branch_2 = EmptyOperator(task_id="branch_2", dag=self.dag)
+        self.branch_3 = None
+
+    def tearDown(self):
+        super().tearDown()
+
+        with create_session() as session:
+            session.query(DagRun).delete()
+            session.query(TI).delete()
+            session.query(XCom).delete()
+
+    def test_unsupported_conn_type(self):
+        """Check if BranchSQLOperator throws an exception for unsupported connection type"""
+        op = BranchSQLOperator(
+            task_id="make_choice",
+            conn_id="redis_default",
+            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        with pytest.raises(AirflowException):
+            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+    def test_invalid_conn(self):
+        """Check if BranchSQLOperator throws an exception for invalid connection"""
+        op = BranchSQLOperator(
+            task_id="make_choice",
+            conn_id="invalid_connection",
+            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        with pytest.raises(AirflowException):
+            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+    def test_invalid_follow_task_true(self):
+        """Check BranchSQLOperator raises with follow_task_ids_if_true=None (conn_id is also invalid)"""
+        op = BranchSQLOperator(
+            task_id="make_choice",
+            conn_id="invalid_connection",
+            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+            follow_task_ids_if_true=None,
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        with pytest.raises(AirflowException):
+            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+    def test_invalid_follow_task_false(self):
+        """Check BranchSQLOperator raises with follow_task_ids_if_false=None (conn_id is also invalid)"""
+        op = BranchSQLOperator(
+            task_id="make_choice",
+            conn_id="invalid_connection",
+            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false=None,
+            dag=self.dag,
+        )
+
+        with pytest.raises(AirflowException):
+            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+    @pytest.mark.backend("mysql")
+    def test_sql_branch_operator_mysql(self):
+        """Check if BranchSQLOperator works with backend"""
+        branch_op = BranchSQLOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+    @pytest.mark.backend("postgres")
+    def test_sql_branch_operator_postgres(self):
+        """Check if BranchSQLOperator works with backend"""
+        branch_op = BranchSQLOperator(
+            task_id="make_choice",
+            conn_id="postgres_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+    @mock.patch.object(BranchSQLOperator, "get_db_hook")
+    def test_branch_single_value_with_dag_run(self, mock_get_db_hook):
+        """Check a single truthy scalar result follows branch_1 and skips branch_2"""
+        branch_op = BranchSQLOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        self.branch_1.set_upstream(branch_op)
+        self.branch_2.set_upstream(branch_op)
+        self.dag.clear()
+
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_get_records = mock_get_db_hook.return_value.get_first
+
+        mock_get_records.return_value = 1
+
+        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        tis = dr.get_task_instances()
+        for ti in tis:
+            if ti.task_id == "make_choice":
+                assert ti.state == State.SUCCESS
+            elif ti.task_id == "branch_1":
+                assert ti.state == State.NONE
+            elif ti.task_id == "branch_2":
+                assert ti.state == State.SKIPPED
+            else:
+                raise ValueError(f"Invalid task id {ti.task_id} found!")
+
+    @mock.patch.object(BranchSQLOperator, "get_db_hook")
+    def test_branch_true_with_dag_run(self, mock_get_db_hook):
+        """Check every supported truthy result follows branch_1 and skips branch_2"""
+        branch_op = BranchSQLOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        self.branch_1.set_upstream(branch_op)
+        self.branch_2.set_upstream(branch_op)
+        self.dag.clear()
+
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_get_records = mock_get_db_hook.return_value.get_first
+
+        for true_value in SUPPORTED_TRUE_VALUES:
+            mock_get_records.return_value = true_value
+
+            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+            tis = dr.get_task_instances()
+            for ti in tis:
+                if ti.task_id == "make_choice":
+                    assert ti.state == State.SUCCESS
+                elif ti.task_id == "branch_1":
+                    assert ti.state == State.NONE
+                elif ti.task_id == "branch_2":
+                    assert ti.state == State.SKIPPED
+                else:
+                    raise ValueError(f"Invalid task id {ti.task_id} found!")
+
+    @mock.patch.object(BranchSQLOperator, "get_db_hook")
+    def test_branch_false_with_dag_run(self, mock_get_db_hook):
+        """Check every supported falsy result follows branch_2 and skips branch_1"""
+        branch_op = BranchSQLOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        self.branch_1.set_upstream(branch_op)
+        self.branch_2.set_upstream(branch_op)
+        self.dag.clear()
+
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_get_records = mock_get_db_hook.return_value.get_first
+
+        for false_value in SUPPORTED_FALSE_VALUES:
+            mock_get_records.return_value = false_value
+            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+            tis = dr.get_task_instances()
+            for ti in tis:
+                if ti.task_id == "make_choice":
+                    assert ti.state == State.SUCCESS
+                elif ti.task_id == "branch_1":
+                    assert ti.state == State.SKIPPED
+                elif ti.task_id == "branch_2":
+                    assert ti.state == State.NONE
+                else:
+                    raise ValueError(f"Invalid task id {ti.task_id} found!")
+
+    @mock.patch.object(BranchSQLOperator, "get_db_hook")
+    def test_branch_list_with_dag_run(self, mock_get_db_hook):
+        """Checks if the BranchSQLOperator supports branching off to a list of tasks."""
+        branch_op = BranchSQLOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true=["branch_1", "branch_2"],
+            follow_task_ids_if_false="branch_3",
+            dag=self.dag,
+        )
+
+        self.branch_1.set_upstream(branch_op)
+        self.branch_2.set_upstream(branch_op)
+        self.branch_3 = EmptyOperator(task_id="branch_3", dag=self.dag)
+        self.branch_3.set_upstream(branch_op)
+        self.dag.clear()
+
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_get_records = mock_get_db_hook.return_value.get_first
+        mock_get_records.return_value = [["1"]]
+
+        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        tis = dr.get_task_instances()
+        for ti in tis:
+            if ti.task_id == "make_choice":
+                assert ti.state == State.SUCCESS
+            elif ti.task_id == "branch_1":
+                assert ti.state == State.NONE
+            elif ti.task_id == "branch_2":
+                assert ti.state == State.NONE
+            elif ti.task_id == "branch_3":
+                assert ti.state == State.SKIPPED
+            else:
+                raise ValueError(f"Invalid task id {ti.task_id} found!")
+
+    @mock.patch.object(BranchSQLOperator, "get_db_hook")
+    def test_invalid_query_result_with_dag_run(self, mock_get_db_hook):
+        """Check a query result that is not a recognised boolean makes BranchSQLOperator raise"""
+        branch_op = BranchSQLOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        self.branch_1.set_upstream(branch_op)
+        self.branch_2.set_upstream(branch_op)
+        self.dag.clear()
+
+        self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_get_records = mock_get_db_hook.return_value.get_first
+
+        mock_get_records.return_value = ["Invalid Value"]
+
+        with pytest.raises(AirflowException):
+            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+    @mock.patch.object(BranchSQLOperator, "get_db_hook")
+    def test_with_skip_in_branch_downstream_dependencies(self, mock_get_db_hook):
+        """Check a true result does not skip branch_2, which is also downstream of branch_1"""
+        branch_op = BranchSQLOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        branch_op >> self.branch_1 >> self.branch_2
+        branch_op >> self.branch_2
+        self.dag.clear()
+
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_get_records = mock_get_db_hook.return_value.get_first
+
+        for true_value in SUPPORTED_TRUE_VALUES:
+            mock_get_records.return_value = [true_value]
+
+            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+            tis = dr.get_task_instances()
+            for ti in tis:
+                if ti.task_id == "make_choice":
+                    assert ti.state == State.SUCCESS
+                elif ti.task_id == "branch_1":
+                    assert ti.state == State.NONE
+                elif ti.task_id == "branch_2":
+                    assert ti.state == State.NONE
+                else:
+                    raise ValueError(f"Invalid task id {ti.task_id} found!")
+
+    @mock.patch.object(BranchSQLOperator, "get_db_hook")
+    def test_with_skip_in_branch_downstream_dependencies2(self, mock_get_db_hook):
+        """Check a false result follows branch_2 and skips only branch_1"""
+        branch_op = BranchSQLOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        branch_op >> self.branch_1 >> self.branch_2
+        branch_op >> self.branch_2
+        self.dag.clear()
+
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_get_records = mock_get_db_hook.return_value.get_first
+
+        for false_value in SUPPORTED_FALSE_VALUES:
+            mock_get_records.return_value = [false_value]
+
+            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+            tis = dr.get_task_instances()
+            for ti in tis:
+                if ti.task_id == "make_choice":
+                    assert ti.state == State.SUCCESS
+                elif ti.task_id == "branch_1":
+                    assert ti.state == State.SKIPPED
+                elif ti.task_id == "branch_2":
+                    assert ti.state == State.NONE
+                else:
+                    raise ValueError(f"Invalid task id {ti.task_id} found!")
diff --git a/tests/providers/qubole/operators/test_qubole_check.py b/tests/providers/qubole/operators/test_qubole_check.py
index ac8f1edb18..3d2c9cbe1d 100644
--- a/tests/providers/qubole/operators/test_qubole_check.py
+++ b/tests/providers/qubole/operators/test_qubole_check.py
@@ -19,19 +19,19 @@
 import unittest
 from datetime import datetime
 from unittest import mock
+from unittest.mock import MagicMock
 
 import pytest
 from qds_sdk.commands import HiveCommand
 
 from airflow.exceptions import AirflowException
 from airflow.models import DAG
+from airflow.providers.common.sql.operators.sql import SQLCheckOperator, SQLValueCheckOperator
 from airflow.providers.qubole.hooks.qubole import QuboleHook
 from airflow.providers.qubole.hooks.qubole_check import QuboleCheckHook
 from airflow.providers.qubole.operators.qubole_check import (
     QuboleCheckOperator,
     QuboleValueCheckOperator,
-    SQLCheckOperator,
-    SQLValueCheckOperator,
     _QuboleCheckOperatorMixin,
 )
 
@@ -80,7 +80,7 @@ class TestQuboleCheckMixin:
         operator = self.__construct_operator(operator_class=operator_class, **kwargs)
 
         with mock.patch.object(parent_check_operator, 'execute') as mock_execute:
-            operator.execute()
+            operator.execute(context=MagicMock())
             mock_execute.assert_called_once()
 
     @mock.patch('airflow.providers.qubole.operators.qubole_check.handle_airflow_exception')
@@ -89,7 +89,7 @@ class TestQuboleCheckMixin:
 
         with mock.patch.object(parent_check_operator, 'execute') as mock_execute:
             mock_execute.side_effect = AirflowException()
-            operator.execute()
+            operator.execute(context=MagicMock())
             mock_execute.assert_called_once()
             mock_handle_airflow_exception.assert_called_once()
 
@@ -153,7 +153,7 @@ class TestQuboleValueCheckOperator(unittest.TestCase):
         operator = self.__construct_operator('select value from tab1 limit 1;', 5, 1)
 
         with pytest.raises(AirflowException, match='Qubole Command Id: ' + str(mock_cmd.id)):
-            operator.execute()
+            operator.execute(context=MagicMock())
 
         mock_cmd.is_success.assert_called_once_with(mock_cmd.status)
 
@@ -173,7 +173,7 @@ class TestQuboleValueCheckOperator(unittest.TestCase):
         operator = self.__construct_operator('select value from tab1 limit 1;', 5, 1)
 
         with pytest.raises(AirflowException) as ctx:
-            operator.execute()
+            operator.execute(context=MagicMock())
 
         assert 'Qubole Command Id: ' not in str(ctx.value)
         mock_cmd.is_success.assert_called_once_with(mock_cmd.status)
@@ -193,5 +193,5 @@ class TestQuboleValueCheckOperator(unittest.TestCase):
         operator = self.__construct_operator(
             'select value from tab1 limit 1;', pass_value, None, results_parser_callable
         )
-        operator.execute()
+        operator.execute(context=MagicMock())
         results_parser_callable.assert_called_once_with([pass_value])
diff --git a/tests/providers/slack/operators/test_slack.py b/tests/providers/slack/operators/test_slack.py
index 8b8ac66441..495dc7f6ba 100644
--- a/tests/providers/slack/operators/test_slack.py
+++ b/tests/providers/slack/operators/test_slack.py
@@ -15,10 +15,10 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
 import json
 import unittest
 from unittest import mock
+from unittest.mock import MagicMock
 
 from airflow.providers.slack.operators.slack import SlackAPIFileOperator, SlackAPIPostOperator
 
@@ -119,7 +119,7 @@ class TestSlackAPIPostOperator(unittest.TestCase):
             slack_conn_id=test_slack_conn_id,
         )
 
-        slack_api_post_operator.execute()
+        slack_api_post_operator.execute(context=MagicMock())
 
         expected_api_params = {
             'channel': "#general",
@@ -195,7 +195,7 @@ class TestSlackAPIFileOperator(unittest.TestCase):
             task_id='slack', slack_conn_id=test_slack_conn_id, content='test-content'
         )
 
-        slack_api_post_operator.execute()
+        slack_api_post_operator.execute(context=MagicMock())
 
         expected_api_params = {
             'channels': '#general',
@@ -221,7 +221,7 @@ class TestSlackAPIFileOperator(unittest.TestCase):
             task_id='slack', slack_conn_id=test_slack_conn_id, filename=file_path, filetype='csv'
         )
 
-        slack_api_post_operator.execute()
+        slack_api_post_operator.execute(context=MagicMock())
 
         expected_api_params = {
             'channels': '#general',
diff --git a/tests/providers/slack/transfers/test_sql_to_slack.py b/tests/providers/slack/transfers/test_sql_to_slack.py
index 0390a56b18..23b791fb7b 100644
--- a/tests/providers/slack/transfers/test_sql_to_slack.py
+++ b/tests/providers/slack/transfers/test_sql_to_slack.py
@@ -149,7 +149,7 @@ class TestSqlToSlackOperator:
         # Test that the Slack hook's execute method gets run once
         slack_webhook_hook.execute.assert_called_once()
 
-    @mock.patch('airflow.operators.sql.BaseHook.get_connection')
+    @mock.patch('airflow.providers.common.sql.operators.sql.BaseHook.get_connection')
     def test_hook_params_building(self, mock_get_conn):
         mock_get_conn.return_value = Connection(conn_id='snowflake_connection', conn_type='snowflake')
         hook_params = {
@@ -172,7 +172,7 @@ class TestSqlToSlackOperator:
 
         assert sql_to_slack_operator.sql_hook_params == hook_params
 
-    @mock.patch('airflow.operators.sql.BaseHook.get_connection')
+    @mock.patch('airflow.providers.common.sql.operators.sql.BaseHook.get_connection')
     def test_hook_params(self, mock_get_conn):
         mock_get_conn.return_value = Connection(conn_id='postgres_test', conn_type='postgres')
         op = SqlToSlackOperator(
@@ -188,7 +188,7 @@ class TestSqlToSlackOperator:
         hook = op._get_hook()
         assert hook.schema == 'public'
 
-    @mock.patch('airflow.operators.sql.BaseHook.get_connection')
+    @mock.patch('airflow.providers.common.sql.operators.sql.BaseHook.get_connection')
     def test_hook_params_snowflake(self, mock_get_conn):
         mock_get_conn.return_value = Connection(conn_id='snowflake_default', conn_type='snowflake')
         op = SqlToSlackOperator(