You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by di...@apache.org on 2020/06/18 17:25:59 UTC

[airflow] branch v1-10-test updated (d480352 -> 2bdf471)

This is an automated email from the ASF dual-hosted git repository.

dimberman pushed a change to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git.


    omit d480352  Add SQL Branch Operator
     new 2bdf471  Add SQL Branch Operator

This update added new revisions after undoing existing revisions.
That is to say, some revisions that were in the old version of the
branch are not in the new version.  This situation occurs
when a user --force pushes a change and generates a repository
containing something like this:

 * -- * -- B -- O -- O -- O   (d480352)
            \
             N -- N -- N   refs/heads/v1-10-test (2bdf471)

You should already have received notification emails for all of the O
revisions, and so the following emails describe only the N revisions
from the common base, B.

Any revisions marked "omit" are not gone; other references still
refer to them.  Any revisions marked "discard" are gone forever.

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 tests/operators/test_sql_branch_operator.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)


[airflow] 01/01: Add SQL Branch Operator

Posted by di...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

dimberman pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 2bdf471ed55976323cdb67ad82467994a6280c45
Author: samuelkhtu <46...@users.noreply.github.com>
AuthorDate: Mon Jun 1 14:14:13 2020 -0400

    Add SQL Branch Operator
    
    SQL Branch Operator allows users to execute a SQL query in any supported backend to decide which
    branch to follow. The SQL branch operator expects the query to return True/False (Boolean) or
    0/1 (Integer) or true/y/yes/1/on/false/n/no/0/off (String).
    
    (cherry picked from commit 55b9b8f6456a7a46921b0bf7a893c7f08bf8237c)
---
 airflow/operators/sql_branch_operator.py    | 173 ++++++++++
 docs/operators-and-hooks-ref.rst            |  81 +++++
 tests/operators/test_sql_branch_operator.py | 476 ++++++++++++++++++++++++++++
 3 files changed, 730 insertions(+)

diff --git a/airflow/operators/sql_branch_operator.py b/airflow/operators/sql_branch_operator.py
new file mode 100644
index 0000000..072c40c
--- /dev/null
+++ b/airflow/operators/sql_branch_operator.py
@@ -0,0 +1,173 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+ALLOWED_CONN_TYPE = {  # conn_type values accepted by BranchSqlOperator._get_hook()
+    "google_cloud_platform",
+    "jdbc",
+    "mssql",
+    "mysql",
+    "odbc",  # NOTE(review): confirm Connection.get_hook() on v1-10 returns a hook for odbc
+    "oracle",
+    "postgres",
+    "presto",
+    "sqlite",
+    "vertica",
+}
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes a SQL query and follows one set of downstream tasks based on its result.
+
+    :param sql: the SQL code to execute, or a template file path ending in '.sql'. (templated)
+        Only the first column of the first returned row is inspected; it must be a boolean, an
+        integer (0 = False, otherwise True) or a string such as true/yes/on/1 or false/no/off/0.
+    :type sql: str
+    :param follow_task_ids_if_true: task id or task ids to follow if the query returns true
+    :type follow_task_ids_if_true: str or list
+    :param follow_task_ids_if_false: task id or task ids to follow if the query returns false
+    :type follow_task_ids_if_false: str or list
+    :param conn_id: id of the Airflow connection the query runs against
+    :type conn_id: str
+    :param database: (optional) schema/database set on the hook, overriding the connection's
+    :type database: str
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+    ui_fgcolor = "#F7F7F7"
+
+    @apply_defaults
+    def __init__(
+            self,
+            sql,
+            follow_task_ids_if_true,
+            follow_task_ids_if_false,
+            conn_id="default_conn_id",
+            database=None,
+            parameters=None,
+            *args,
+            **kwargs):
+        super(BranchSqlOperator, self).__init__(*args, **kwargs)
+        self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.database = database
+        self._hook = None
+
+    def _get_hook(self):
+        """Build (once) and return the hook for ``conn_id``; reject unsupported connection types."""
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+
+        if conn.conn_type not in ALLOWED_CONN_TYPE:
+            raise AirflowException(
+                "The connection type is not supported by BranchSqlOperator. "
+                + "Supported connection types: {}".format(sorted(ALLOWED_CONN_TYPE))
+            )
+        if not self._hook:
+            self._hook = conn.get_hook()
+            # get_hook() may return None; guard it here and let execute() report it.
+            if self._hook is not None and self.database:
+                self._hook.schema = self.database
+        return self._hook
+
+    def execute(self, context):
+        """Run ``sql`` and pass the selected branch to ``skip_all_except``."""
+        import numbers
+
+        # Validate the arguments before opening a connection to the database.
+        if self.sql is None:
+            raise AirflowException("Expected 'sql' parameter is missing.")
+
+        if self.follow_task_ids_if_true is None:
+            raise AirflowException(
+                "Expected 'follow_task_ids_if_true' parameter is missing."
+            )
+
+        if self.follow_task_ids_if_false is None:
+            raise AirflowException(
+                "Expected 'follow_task_ids_if_false' parameter is missing."
+            )
+
+        # get supported hook
+        self._hook = self._get_hook()
+        if self._hook is None:
+            raise AirflowException(
+                "Failed to establish connection to '%s'" % self.conn_id
+            )
+
+        self.log.info(
+            "Executing: %s (with parameters %s) with connection: %s",
+            self.sql,
+            self.parameters,
+            self._hook,
+        )
+        record = self._hook.get_first(self.sql, self.parameters)
+        # Only an empty result means "no rows"; a bare False or 0 is a valid answer.
+        if record is None or (isinstance(record, (list, tuple)) and not record):
+            raise AirflowException(
+                "No rows returned from sql query. Operator expected True or False return value."
+            )
+
+        # Take the first column of the first row. A list may wrap a whole row
+        # ([[value]] or [(value,)]); an empty inner row becomes None and is rejected below.
+        if isinstance(record, (list, tuple)):
+            query_result = record[0]
+            if isinstance(record, list) and isinstance(query_result, (list, tuple)):
+                query_result = query_result[0] if query_result else None
+        else:
+            query_result = record
+
+        self.log.info("Query returns %s, type '%s'", query_result, type(query_result))
+
+        # Text types: str, plus unicode on Python 2 (on Python 3 type(u"") is str).
+        string_types = (str, type(u""))
+        is_true = None
+        if isinstance(query_result, string_types):
+            try:
+                is_true = bool(strtobool(query_result))
+            except ValueError:
+                is_true = None  # not a recognised boolean string; rejected below
+        elif isinstance(query_result, numbers.Integral):
+            # Integral covers int (and long on Python 2) and bool, a subclass of int:
+            # 0/False selects the false branch, any other value the true branch.
+            is_true = bool(query_result)
+
+        if is_true is None:
+            raise AirflowException(
+                "Unexpected query return result '%s' type '%s'"
+                % (query_result, type(query_result))
+            )
+
+        follow_branch = (
+            self.follow_task_ids_if_true if is_true else self.follow_task_ids_if_false
+        )
+
+        self.skip_all_except(context["ti"], follow_branch)
diff --git a/docs/operators-and-hooks-ref.rst b/docs/operators-and-hooks-ref.rst
index 6c80858..55176f8 100644
--- a/docs/operators-and-hooks-ref.rst
+++ b/docs/operators-and-hooks-ref.rst
@@ -22,6 +22,87 @@ Operators and Hooks Reference
   :local:
   :depth: 1
 
+.. _fundamentals:
+
+Fundamentals
+------------
+
+**Base:**
+
+.. list-table::
+   :header-rows: 1
+
+   * - Module
+     - Guides
+
+   * - :mod:`airflow.hooks.base_hook`
+     -
+
+   * - :mod:`airflow.hooks.dbapi_hook`
+     -
+
+   * - :mod:`airflow.models.baseoperator`
+     -
+
+   * - :mod:`airflow.sensors.base_sensor_operator`
+     -
+
+**Operators:**
+
+.. list-table::
+   :header-rows: 1
+
+   * - Operators
+     - Guides
+
+   * - :mod:`airflow.operators.branch_operator`
+     -
+
+   * - :mod:`airflow.operators.check_operator`
+     -
+
+   * - :mod:`airflow.operators.dagrun_operator`
+     -
+
+   * - :mod:`airflow.operators.dummy_operator`
+     -
+
+   * - :mod:`airflow.operators.generic_transfer`
+     -
+
+   * - :mod:`airflow.operators.latest_only_operator`
+     -
+
+   * - :mod:`airflow.operators.subdag_operator`
+     -
+
+   * - :mod:`airflow.operators.sql_branch_operator`
+     -
+
+**Sensors:**
+
+.. list-table::
+   :header-rows: 1
+
+   * - Sensors
+     - Guides
+
+   * - :mod:`airflow.sensors.weekday_sensor`
+     -
+
+   * - :mod:`airflow.sensors.external_task_sensor`
+     - :doc:`How to use <howto/operator/external_task_sensor>`
+
+   * - :mod:`airflow.sensors.sql_sensor`
+     -
+
+   * - :mod:`airflow.sensors.time_delta_sensor`
+     -
+
+   * - :mod:`airflow.sensors.time_sensor`
+     -
+
+
 .. _Apache:
 
 ASF: Apache Software Foundation
diff --git a/tests/operators/test_sql_branch_operator.py b/tests/operators/test_sql_branch_operator.py
new file mode 100644
index 0000000..6510609
--- /dev/null
+++ b/tests/operators/test_sql_branch_operator.py
@@ -0,0 +1,476 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import unittest
+from tests.compat import mock
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.models import DAG, DagRun, TaskInstance as TI
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.operators.sql_branch_operator import BranchSqlOperator
+from airflow.utils import timezone
+from airflow.utils.db import create_session
+from airflow.utils.state import State
+from tests.hooks.test_hive_hook import TestHiveEnvironment
+
+DEFAULT_DATE = timezone.datetime(2016, 1, 1)  # execution date of every DAG run in these tests
+INTERVAL = datetime.timedelta(hours=12)  # schedule_interval of the test DAG
+
+SUPPORTED_TRUE_VALUES = [  # mocked get_first() results that must select follow_task_ids_if_true
+    ["True"],  # one-element lists stand in for a result row; bare values for a scalar result
+    ["true"],
+    ["1"],
+    ["on"],
+    [1],
+    True,
+    "true",
+    "1",
+    "on",
+    1,
+]
+SUPPORTED_FALSE_VALUES = [  # mocked get_first() results that must select follow_task_ids_if_false
+    ["False"],  # one-element lists stand in for a result row; bare values for a scalar result
+    ["false"],
+    ["0"],
+    ["off"],
+    [0],
+    False,
+    "false",
+    "0",
+    "off",
+    0,
+]
+
+
+class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
+    """
+    Tests for BranchSqlOperator (SQL Branch Operator).
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        super(TestSqlBranch, cls).setUpClass()
+        with create_session() as session:
+            session.query(DagRun).delete()
+            session.query(TI).delete()
+
+    def setUp(self):
+        super(TestSqlBranch, self).setUp()
+        self.dag = DAG(
+            "sql_branch_operator_test",
+            default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+            schedule_interval=INTERVAL,
+        )
+        self.branch_1 = DummyOperator(task_id="branch_1", dag=self.dag)
+        self.branch_2 = DummyOperator(task_id="branch_2", dag=self.dag)
+        self.branch_3 = None
+
+    def tearDown(self):
+        super(TestSqlBranch, self).tearDown()
+        with create_session() as session:
+            session.query(DagRun).delete()
+            session.query(TI).delete()
+
+    def test_unsupported_conn_type(self):
+        """ Check if BranchSqlOperator throws an exception for unsupported connection type """
+        op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="redis_default",
+            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        with self.assertRaises(AirflowException):
+            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+    def test_invalid_conn(self):
+        """ Check if BranchSqlOperator throws an exception for invalid connection """
+        op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="invalid_connection",
+            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        with self.assertRaises(AirflowException):
+            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+    def test_invalid_follow_task_true(self, mock_hook):
+        """ Check if BranchSqlOperator throws an exception when follow_task_ids_if_true is None """
+        op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+            follow_task_ids_if_true=None,
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+        mock_hook.get_connection("mysql_default").conn_type = "mysql"
+        with self.assertRaisesRegexp(AirflowException, "follow_task_ids_if_true"):
+            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+    def test_invalid_follow_task_false(self, mock_hook):
+        """ Check if BranchSqlOperator throws an exception when follow_task_ids_if_false is None """
+        op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false=None,
+            dag=self.dag,
+        )
+        mock_hook.get_connection("mysql_default").conn_type = "mysql"
+        with self.assertRaisesRegexp(AirflowException, "follow_task_ids_if_false"):
+            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+    @pytest.mark.backend("mysql")
+    def test_sql_branch_operator_mysql(self):
+        """ Check if BranchSqlOperator works with backend """
+        branch_op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+        branch_op.run(
+            start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True
+        )
+
+    @pytest.mark.backend("postgres")
+    def test_sql_branch_operator_postgres(self):
+        """ Check if BranchSqlOperator works with backend """
+        branch_op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="postgres_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+        branch_op.run(
+            start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True
+        )
+
+    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+    def test_branch_single_value_with_dag_run(self, mock_hook):
+        """ Check BranchSqlOperator branch operation """
+        branch_op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        self.branch_1.set_upstream(branch_op)
+        self.branch_2.set_upstream(branch_op)
+        self.dag.clear()
+
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_hook.get_connection("mysql_default").conn_type = "mysql"
+        mock_get_records = (
+            mock_hook.get_connection.return_value.get_hook.return_value.get_first
+        )
+        mock_get_records.return_value = 1
+
+        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        tis = dr.get_task_instances()
+        for ti in tis:
+            if ti.task_id == "make_choice":
+                self.assertEqual(ti.state, State.SUCCESS)
+            elif ti.task_id == "branch_1":
+                self.assertEqual(ti.state, State.NONE)
+            elif ti.task_id == "branch_2":
+                self.assertEqual(ti.state, State.SKIPPED)
+            else:
+                raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
+
+    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+    def test_branch_true_with_dag_run(self, mock_hook):
+        """ Check BranchSqlOperator branch operation """
+        branch_op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        self.branch_1.set_upstream(branch_op)
+        self.branch_2.set_upstream(branch_op)
+        self.dag.clear()
+
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_hook.get_connection("mysql_default").conn_type = "mysql"
+        mock_get_records = (
+            mock_hook.get_connection.return_value.get_hook.return_value.get_first
+        )
+
+        for true_value in SUPPORTED_TRUE_VALUES:
+            mock_get_records.return_value = true_value
+            # ignore_ti_state=True re-runs the task: after the first value its TI is already SUCCESS
+            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+            tis = dr.get_task_instances()
+            for ti in tis:
+                if ti.task_id == "make_choice":
+                    self.assertEqual(ti.state, State.SUCCESS)
+                elif ti.task_id == "branch_1":
+                    self.assertEqual(ti.state, State.NONE)
+                elif ti.task_id == "branch_2":
+                    self.assertEqual(ti.state, State.SKIPPED)
+                else:
+                    raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
+
+    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+    def test_branch_false_with_dag_run(self, mock_hook):
+        """ Check BranchSqlOperator branch operation """
+        branch_op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        self.branch_1.set_upstream(branch_op)
+        self.branch_2.set_upstream(branch_op)
+        self.dag.clear()
+
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_hook.get_connection("mysql_default").conn_type = "mysql"
+        mock_get_records = (
+            mock_hook.get_connection.return_value.get_hook.return_value.get_first
+        )
+
+        for false_value in SUPPORTED_FALSE_VALUES:
+            mock_get_records.return_value = false_value
+            # NOTE(review): without ignore_ti_state=True only the first value executes (its TI is
+            # SUCCESS afterwards); running all values needs bare False/0 accepted as results.
+            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+            tis = dr.get_task_instances()
+            for ti in tis:
+                if ti.task_id == "make_choice":
+                    self.assertEqual(ti.state, State.SUCCESS)
+                elif ti.task_id == "branch_1":
+                    self.assertEqual(ti.state, State.SKIPPED)
+                elif ti.task_id == "branch_2":
+                    self.assertEqual(ti.state, State.NONE)
+                else:
+                    raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
+
+    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+    def test_branch_list_with_dag_run(self, mock_hook):
+        """ Checks if the BranchSqlOperator supports branching off to a list of tasks."""
+        branch_op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true=["branch_1", "branch_2"],
+            follow_task_ids_if_false="branch_3",
+            dag=self.dag,
+        )
+
+        self.branch_1.set_upstream(branch_op)
+        self.branch_2.set_upstream(branch_op)
+        self.branch_3 = DummyOperator(task_id="branch_3", dag=self.dag)
+        self.branch_3.set_upstream(branch_op)
+        self.dag.clear()
+
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_hook.get_connection("mysql_default").conn_type = "mysql"
+        mock_get_records = (
+            mock_hook.get_connection.return_value.get_hook.return_value.get_first
+        )
+        mock_get_records.return_value = [["1"]]
+
+        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        tis = dr.get_task_instances()
+        for ti in tis:
+            if ti.task_id == "make_choice":
+                self.assertEqual(ti.state, State.SUCCESS)
+            elif ti.task_id == "branch_1":
+                self.assertEqual(ti.state, State.NONE)
+            elif ti.task_id == "branch_2":
+                self.assertEqual(ti.state, State.NONE)
+            elif ti.task_id == "branch_3":
+                self.assertEqual(ti.state, State.SKIPPED)
+            else:
+                raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
+
+    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+    def test_invalid_query_result_with_dag_run(self, mock_hook):
+        """ Check BranchSqlOperator branch operation """
+        branch_op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        self.branch_1.set_upstream(branch_op)
+        self.branch_2.set_upstream(branch_op)
+        self.dag.clear()
+
+        self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_hook.get_connection("mysql_default").conn_type = "mysql"
+        mock_get_records = (
+            mock_hook.get_connection.return_value.get_hook.return_value.get_first
+        )
+
+        mock_get_records.return_value = ["Invalid Value"]
+
+        with self.assertRaises(AirflowException):
+            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+    def test_with_skip_in_branch_downstream_dependencies(self, mock_hook):
+        """ Test SQL Branch with skipping all downstream dependencies """
+        branch_op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        branch_op >> self.branch_1 >> self.branch_2
+        branch_op >> self.branch_2
+        self.dag.clear()
+
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_hook.get_connection("mysql_default").conn_type = "mysql"
+        mock_get_records = (
+            mock_hook.get_connection.return_value.get_hook.return_value.get_first
+        )
+
+        for true_value in SUPPORTED_TRUE_VALUES:
+            mock_get_records.return_value = [true_value]
+            # ignore_ti_state=True re-runs the task: after the first value its TI is already SUCCESS
+            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+            tis = dr.get_task_instances()
+            for ti in tis:
+                if ti.task_id == "make_choice":
+                    self.assertEqual(ti.state, State.SUCCESS)
+                elif ti.task_id == "branch_1":
+                    self.assertEqual(ti.state, State.NONE)
+                elif ti.task_id == "branch_2":
+                    self.assertEqual(ti.state, State.NONE)
+                else:
+                    raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
+
+    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+    def test_with_skip_in_branch_downstream_dependencies2(self, mock_hook):
+        """ Test skipping downstream dependency for false condition"""
+        branch_op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        branch_op >> self.branch_1 >> self.branch_2
+        branch_op >> self.branch_2
+        self.dag.clear()
+
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_hook.get_connection("mysql_default").conn_type = "mysql"
+        mock_get_records = (
+            mock_hook.get_connection.return_value.get_hook.return_value.get_first
+        )
+
+        for false_value in SUPPORTED_FALSE_VALUES:
+            mock_get_records.return_value = [false_value]
+            # ignore_ti_state=True re-runs the task: after the first value its TI is already SUCCESS
+            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+            tis = dr.get_task_instances()
+            for ti in tis:
+                if ti.task_id == "make_choice":
+                    self.assertEqual(ti.state, State.SUCCESS)
+                elif ti.task_id == "branch_1":
+                    self.assertEqual(ti.state, State.SKIPPED)
+                elif ti.task_id == "branch_2":
+                    self.assertEqual(ti.state, State.NONE)
+                else:
+                    raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))