Posted to commits@airflow.apache.org by di...@apache.org on 2020/06/18 16:57:08 UTC

[airflow] 01/01: Add SQL Branch Operator

This is an automated email from the ASF dual-hosted git repository.

dimberman pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit f837c5937a4fd4d0a2b33516c60a0e8241b7640b
Author: samuelkhtu <46...@users.noreply.github.com>
AuthorDate: Mon Jun 1 14:14:13 2020 -0400

    Add SQL Branch Operator
    
    The SQL Branch Operator allows users to execute a SQL query in any supported backend to
    decide which branch to follow. The operator expects the query to return True/False (Boolean),
    0/1 (Integer), or true/y/yes/1/on/false/n/no/0/off (String).
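
    A minimal usage sketch (the DAG, task ids, connection id and query below are
    illustrative, not part of this commit):

        from datetime import datetime

        from airflow import DAG
        from airflow.operators.dummy_operator import DummyOperator
        from airflow.operators.sql_branch_operator import BranchSqlOperator

        dag = DAG(
            "sql_branch_example",
            start_date=datetime(2020, 1, 1),
            schedule_interval=None,
        )

        # branch on the row count: a non-zero integer result is treated as "true"
        check = BranchSqlOperator(
            task_id="check_for_new_rows",
            conn_id="my_database",  # illustrative connection id of a supported type
            sql="SELECT COUNT(*) FROM new_rows",
            follow_task_ids_if_true="process_new_rows",
            follow_task_ids_if_false="skip_processing",
            dag=dag,
        )
        process = DummyOperator(task_id="process_new_rows", dag=dag)
        skip = DummyOperator(task_id="skip_processing", dag=dag)
        check >> [process, skip]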
    
    (cherry picked from commit 55b9b8f6456a7a46921b0bf7a893c7f08bf8237c)
---
 airflow/operators/sql_branch_operator.py    | 173 ++++++++++
 docs/operators-and-hooks-ref.rst            |  81 +++++
 tests/operators/test_sql_branch_operator.py | 476 ++++++++++++++++++++++++++++
 3 files changed, 730 insertions(+)

diff --git a/airflow/operators/sql_branch_operator.py b/airflow/operators/sql_branch_operator.py
new file mode 100644
index 0000000..072c40c
--- /dev/null
+++ b/airflow/operators/sql_branch_operator.py
@@ -0,0 +1,173 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
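+# Connection types accepted by BranchSqlOperator; any other type raises an AirflowException.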
+ALLOWED_CONN_TYPE = {
+    "google_cloud_platform",
+    "jdbc",
+    "mssql",
+    "mysql",
+    "odbc",
+    "oracle",
+    "postgres",
+    "presto",
+    "sqlite",
+    "vertica",
+}
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes a SQL query in a specific database and branches the DAG based on the result.
+
+    :param sql: the SQL query to be executed. (templated)
+    :type sql: Can receive a str representing a SQL statement or a reference to a template file.
+               Template references are recognized by str ending in '.sql'.
+               The query is expected to return a Boolean (True/False), an integer (0 = False,
+               otherwise = True) or a string (true/y/yes/1/on/false/n/no/0/off).
+    :param follow_task_ids_if_true: task id or task ids to follow if the query returns true
+    :type follow_task_ids_if_true: str or list
+    :param follow_task_ids_if_false: task id or task ids to follow if the query returns false
+    :type follow_task_ids_if_false: str or list
+    :param conn_id: reference to a specific database connection
+    :type conn_id: str
+    :param database: name of the database which overrides the one defined in the connection
+    :type database: str
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
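+
+    A minimal usage sketch (the task ids, connection id and query file below are
+    illustrative only)::
+
+        check = BranchSqlOperator(
+            task_id="check_partition",
+            conn_id="my_database",  # illustrative connection id of a supported type
+            sql="sql/check_partition.sql",  # '.sql' suffix: rendered as a Jinja template file
+            follow_task_ids_if_true=["load_partition", "update_stats"],
+            follow_task_ids_if_false="no_data",
+            dag=dag,
+        )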
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+    ui_fgcolor = "#F7F7F7"
+
+    @apply_defaults
+    def __init__(
+            self,
+            sql,
+            follow_task_ids_if_true,
+            follow_task_ids_if_false,
+            conn_id="default_conn_id",
+            database=None,
+            parameters=None,
+            *args,
+            **kwargs):
+        super(BranchSqlOperator, self).__init__(*args, **kwargs)
+        self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.database = database
+        self._hook = None
+
+    def _get_hook(self):
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+
+        if conn.conn_type not in ALLOWED_CONN_TYPE:
+            raise AirflowException(
+                "The connection type is not supported by BranchSqlOperator. "
+                + "Supported connection types: {}".format(list(ALLOWED_CONN_TYPE))
+            )
+
+        if not self._hook:
+            self._hook = conn.get_hook()
+            if self.database:
+                self._hook.schema = self.database
+
+        return self._hook
+
+    def execute(self, context):
+        # get supported hook
+        self._hook = self._get_hook()
+
+        if self._hook is None:
+            raise AirflowException(
+                "Failed to establish connection to '%s'" % self.conn_id
+            )
+
+        if self.sql is None:
+            raise AirflowException("Expected 'sql' parameter is missing.")
+
+        if self.follow_task_ids_if_true is None:
+            raise AirflowException(
+                "Expected 'follow_task_ids_if_true' paramter is missing."
+            )
+
+        if self.follow_task_ids_if_false is None:
+            raise AirflowException(
+                "Expected 'follow_task_ids_if_false' parameter is missing."
+            )
+
+        self.log.info(
+            "Executing: %s (with parameters %s) with connection: %s",
+            self.sql,
+            self.parameters,
+            self._hook,
+        )
+        record = self._hook.get_first(self.sql, self.parameters)
+        if not record:
+            raise AirflowException(
+                "No rows returned from sql query. Operator expected True or False return value."
+            )
+
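+        # get_first() may return a nested row (list of lists), a flat row (list or tuple)
+        # or a bare scalar depending on the hook; reduce it to a single scalar value.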
+        if isinstance(record, list):
+            if isinstance(record[0], list):
+                query_result = record[0][0]
+            else:
+                query_result = record[0]
+        elif isinstance(record, tuple):
+            query_result = record[0]
+        else:
+            query_result = record
+
+        self.log.info("Query returns %s, type '%s'", query_result, type(query_result))
+
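+        # Truthy results select follow_task_ids_if_true; falsy results fall through to
+        # follow_task_ids_if_false below. Unparsable results raise an AirflowException.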
+        follow_branch = None
+        try:
+            if isinstance(query_result, bool):
+                if query_result:
+                    follow_branch = self.follow_task_ids_if_true
+            elif isinstance(query_result, str):
+                # the result is not a Boolean; try to convert the string to a Boolean
+                if bool(strtobool(query_result)):
+                    follow_branch = self.follow_task_ids_if_true
+            elif isinstance(query_result, int):
+                if bool(query_result):
+                    follow_branch = self.follow_task_ids_if_true
+            else:
+                raise AirflowException(
+                    "Unexpected query return result '%s' type '%s'"
+                    % (query_result, type(query_result))
+                )
+
+            if follow_branch is None:
+                follow_branch = self.follow_task_ids_if_false
+        except ValueError:
+            raise AirflowException(
+                "Unexpected query return result '%s' type '%s'"
+                % (query_result, type(query_result))
+            )
+
+        self.skip_all_except(context["ti"], follow_branch)
diff --git a/docs/operators-and-hooks-ref.rst b/docs/operators-and-hooks-ref.rst
index 6c80858..55176f8 100644
--- a/docs/operators-and-hooks-ref.rst
+++ b/docs/operators-and-hooks-ref.rst
@@ -22,6 +22,87 @@ Operators and Hooks Reference
   :local:
   :depth: 1
 
+.. _fundamentals:
+
+Fundamentals
+------------
+
+**Base:**
+
+.. list-table::
+   :header-rows: 1
+
+   * - Module
+     - Guides
+
+   * - :mod:`airflow.hooks.base_hook`
+     -
+
+   * - :mod:`airflow.hooks.dbapi_hook`
+     -
+
+   * - :mod:`airflow.models.baseoperator`
+     -
+
+   * - :mod:`airflow.sensors.base_sensor_operator`
+     -
+
+**Operators:**
+
+.. list-table::
+   :header-rows: 1
+
+   * - Operators
+     - Guides
+
+   * - :mod:`airflow.operators.branch_operator`
+     -
+
+   * - :mod:`airflow.operators.check_operator`
+     -
+
+   * - :mod:`airflow.operators.dagrun_operator`
+     -
+
+   * - :mod:`airflow.operators.dummy_operator`
+     -
+
+   * - :mod:`airflow.operators.generic_transfer`
+     -
+
+   * - :mod:`airflow.operators.latest_only_operator`
+     -
+
+   * - :mod:`airflow.operators.subdag_operator`
+     -
+
+   * - :mod:`airflow.operators.sql_branch_operator`
+     -
+
+**Sensors:**
+
+.. list-table::
+   :header-rows: 1
+
+   * - Sensors
+     - Guides
+
+   * - :mod:`airflow.sensors.weekday_sensor`
+     -
+
+   * - :mod:`airflow.sensors.external_task_sensor`
+     - :doc:`How to use <howto/operator/external_task_sensor>`
+
+   * - :mod:`airflow.sensors.sql_sensor`
+     -
+
+   * - :mod:`airflow.sensors.time_delta_sensor`
+     -
+
+   * - :mod:`airflow.sensors.time_sensor`
+     -
+
+
 .. _Apache:
 
 ASF: Apache Software Foundation
diff --git a/tests/operators/test_sql_branch_operator.py b/tests/operators/test_sql_branch_operator.py
new file mode 100644
index 0000000..b7a0885
--- /dev/null
+++ b/tests/operators/test_sql_branch_operator.py
@@ -0,0 +1,476 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import unittest
+from unittest import mock
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.models import DAG, DagRun, TaskInstance as TI
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.operators.sql_branch_operator import BranchSqlOperator
+from airflow.utils import timezone
+from airflow.utils.session import create_session
+from airflow.utils.state import State
+from tests.providers.apache.hive import TestHiveEnvironment
+
+DEFAULT_DATE = timezone.datetime(2016, 1, 1)
+INTERVAL = datetime.timedelta(hours=12)
+
+SUPPORTED_TRUE_VALUES = [
+    ["True"],
+    ["true"],
+    ["1"],
+    ["on"],
+    [1],
+    True,
+    "true",
+    "1",
+    "on",
+    1,
+]
+SUPPORTED_FALSE_VALUES = [
+    ["False"],
+    ["false"],
+    ["0"],
+    ["off"],
+    [0],
+    False,
+    "false",
+    "0",
+    "off",
+    0,
+]
+
+
+class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
+    """
+    Test for SQL Branch Operator
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        super(TestSqlBranch, cls).setUpClass()
+
+        with create_session() as session:
+            session.query(DagRun).delete()
+            session.query(TI).delete()
+
+    def setUp(self):
+        super(TestSqlBranch, self).setUp()
+        self.dag = DAG(
+            "sql_branch_operator_test",
+            default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+            schedule_interval=INTERVAL,
+        )
+        self.branch_1 = DummyOperator(task_id="branch_1", dag=self.dag)
+        self.branch_2 = DummyOperator(task_id="branch_2", dag=self.dag)
+        self.branch_3 = None
+
+    def tearDown(self):
+        super(TestSqlBranch, self).tearDown()
+
+        with create_session() as session:
+            session.query(DagRun).delete()
+            session.query(TI).delete()
+
+    def test_unsupported_conn_type(self):
+        """ Check if BranchSqlOperator throws an exception for unsupported connection type """
+        op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="redis_default",
+            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        with self.assertRaises(AirflowException):
+            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+    def test_invalid_conn(self):
+        """ Check if BranchSqlOperator throws an exception for invalid connection """
+        op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="invalid_connection",
+            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        with self.assertRaises(AirflowException):
+            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+    def test_invalid_follow_task_true(self):
+        """ Check if BranchSqlOperator throws an exception for invalid connection """
+        op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="invalid_connection",
+            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+            follow_task_ids_if_true=None,
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        with self.assertRaises(AirflowException):
+            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+    def test_invalid_follow_task_false(self):
+        """ Check if BranchSqlOperator throws an exception for invalid connection """
+        op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="invalid_connection",
+            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false=None,
+            dag=self.dag,
+        )
+
+        with self.assertRaises(AirflowException):
+            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+    @pytest.mark.backend("mysql")
+    def test_sql_branch_operator_mysql(self):
+        """ Check if BranchSqlOperator works with backend """
+        branch_op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+        branch_op.run(
+            start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True
+        )
+
+    @pytest.mark.backend("postgres")
+    def test_sql_branch_operator_postgres(self):
+        """ Check if BranchSqlOperator works with backend """
+        branch_op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="postgres_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+        branch_op.run(
+            start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True
+        )
+
+    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+    def test_branch_single_value_with_dag_run(self, mock_hook):
+        """ Check BranchSqlOperator branch operation """
+        branch_op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        self.branch_1.set_upstream(branch_op)
+        self.branch_2.set_upstream(branch_op)
+        self.dag.clear()
+
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_hook.get_connection("mysql_default").conn_type = "mysql"
+        mock_get_records = (
+            mock_hook.get_connection.return_value.get_hook.return_value.get_first
+        )
+
+        mock_get_records.return_value = 1
+
+        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        tis = dr.get_task_instances()
+        for ti in tis:
+            if ti.task_id == "make_choice":
+                self.assertEqual(ti.state, State.SUCCESS)
+            elif ti.task_id == "branch_1":
+                self.assertEqual(ti.state, State.NONE)
+            elif ti.task_id == "branch_2":
+                self.assertEqual(ti.state, State.SKIPPED)
+            else:
+                raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
+
+    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+    def test_branch_true_with_dag_run(self, mock_hook):
+        """ Check BranchSqlOperator branch operation """
+        branch_op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        self.branch_1.set_upstream(branch_op)
+        self.branch_2.set_upstream(branch_op)
+        self.dag.clear()
+
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_hook.get_connection("mysql_default").conn_type = "mysql"
+        mock_get_records = (
+            mock_hook.get_connection.return_value.get_hook.return_value.get_first
+        )
+
+        for true_value in SUPPORTED_TRUE_VALUES:
+            mock_get_records.return_value = true_value
+
+            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+            tis = dr.get_task_instances()
+            for ti in tis:
+                if ti.task_id == "make_choice":
+                    self.assertEqual(ti.state, State.SUCCESS)
+                elif ti.task_id == "branch_1":
+                    self.assertEqual(ti.state, State.NONE)
+                elif ti.task_id == "branch_2":
+                    self.assertEqual(ti.state, State.SKIPPED)
+                else:
+                    raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
+
+    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+    def test_branch_false_with_dag_run(self, mock_hook):
+        """ Check BranchSqlOperator branch operation """
+        branch_op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        self.branch_1.set_upstream(branch_op)
+        self.branch_2.set_upstream(branch_op)
+        self.dag.clear()
+
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_hook.get_connection("mysql_default").conn_type = "mysql"
+        mock_get_records = (
+            mock_hook.get_connection.return_value.get_hook.return_value.get_first
+        )
+
+        for false_value in SUPPORTED_FALSE_VALUES:
+            mock_get_records.return_value = false_value
+
+            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+            tis = dr.get_task_instances()
+            for ti in tis:
+                if ti.task_id == "make_choice":
+                    self.assertEqual(ti.state, State.SUCCESS)
+                elif ti.task_id == "branch_1":
+                    self.assertEqual(ti.state, State.SKIPPED)
+                elif ti.task_id == "branch_2":
+                    self.assertEqual(ti.state, State.NONE)
+                else:
+                    raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
+
+    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+    def test_branch_list_with_dag_run(self, mock_hook):
+        """ Checks if the BranchSqlOperator supports branching off to a list of tasks."""
+        branch_op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true=["branch_1", "branch_2"],
+            follow_task_ids_if_false="branch_3",
+            dag=self.dag,
+        )
+
+        self.branch_1.set_upstream(branch_op)
+        self.branch_2.set_upstream(branch_op)
+        self.branch_3 = DummyOperator(task_id="branch_3", dag=self.dag)
+        self.branch_3.set_upstream(branch_op)
+        self.dag.clear()
+
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_hook.get_connection("mysql_default").conn_type = "mysql"
+        mock_get_records = (
+            mock_hook.get_connection.return_value.get_hook.return_value.get_first
+        )
+        mock_get_records.return_value = [["1"]]
+
+        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        tis = dr.get_task_instances()
+        for ti in tis:
+            if ti.task_id == "make_choice":
+                self.assertEqual(ti.state, State.SUCCESS)
+            elif ti.task_id == "branch_1":
+                self.assertEqual(ti.state, State.NONE)
+            elif ti.task_id == "branch_2":
+                self.assertEqual(ti.state, State.NONE)
+            elif ti.task_id == "branch_3":
+                self.assertEqual(ti.state, State.SKIPPED)
+            else:
+                raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
+
+    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+    def test_invalid_query_result_with_dag_run(self, mock_hook):
+        """ Check BranchSqlOperator branch operation """
+        branch_op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        self.branch_1.set_upstream(branch_op)
+        self.branch_2.set_upstream(branch_op)
+        self.dag.clear()
+
+        self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_hook.get_connection("mysql_default").conn_type = "mysql"
+        mock_get_records = (
+            mock_hook.get_connection.return_value.get_hook.return_value.get_first
+        )
+
+        mock_get_records.return_value = ["Invalid Value"]
+
+        with self.assertRaises(AirflowException):
+            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+    def test_with_skip_in_branch_downstream_dependencies(self, mock_hook):
+        """ Test SQL Branch with skipping all downstream dependencies """
+        branch_op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        branch_op >> self.branch_1 >> self.branch_2
+        branch_op >> self.branch_2
+        self.dag.clear()
+
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_hook.get_connection("mysql_default").conn_type = "mysql"
+        mock_get_records = (
+            mock_hook.get_connection.return_value.get_hook.return_value.get_first
+        )
+
+        for true_value in SUPPORTED_TRUE_VALUES:
+            mock_get_records.return_value = [true_value]
+
+            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+            tis = dr.get_task_instances()
+            for ti in tis:
+                if ti.task_id == "make_choice":
+                    self.assertEqual(ti.state, State.SUCCESS)
+                elif ti.task_id == "branch_1":
+                    self.assertEqual(ti.state, State.NONE)
+                elif ti.task_id == "branch_2":
+                    self.assertEqual(ti.state, State.NONE)
+                else:
+                    raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
+
+    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+    def test_with_skip_in_branch_downstream_dependencies2(self, mock_hook):
+        """ Test skipping downstream dependency for false condition"""
+        branch_op = BranchSqlOperator(
+            task_id="make_choice",
+            conn_id="mysql_default",
+            sql="SELECT 1",
+            follow_task_ids_if_true="branch_1",
+            follow_task_ids_if_false="branch_2",
+            dag=self.dag,
+        )
+
+        branch_op >> self.branch_1 >> self.branch_2
+        branch_op >> self.branch_2
+        self.dag.clear()
+
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        mock_hook.get_connection("mysql_default").conn_type = "mysql"
+        mock_get_records = (
+            mock_hook.get_connection.return_value.get_hook.return_value.get_first
+        )
+
+        for false_value in SUPPORTED_FALSE_VALUES:
+            mock_get_records.return_value = [false_value]
+
+            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+            tis = dr.get_task_instances()
+            for ti in tis:
+                if ti.task_id == "make_choice":
+                    self.assertEqual(ti.state, State.SUCCESS)
+                elif ti.task_id == "branch_1":
+                    self.assertEqual(ti.state, State.SKIPPED)
+                elif ti.task_id == "branch_2":
+                    self.assertEqual(ti.state, State.NONE)
+                else:
+                    raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))