You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by di...@apache.org on 2020/06/18 16:57:08 UTC
[airflow] 01/01: Add SQL Branch Operator
This is an automated email from the ASF dual-hosted git repository.
dimberman pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit f837c5937a4fd4d0a2b33516c60a0e8241b7640b
Author: samuelkhtu <46...@users.noreply.github.com>
AuthorDate: Mon Jun 1 14:14:13 2020 -0400
Add SQL Branch Operator
SQL Branch Operator allows the user to execute a SQL query in any supported backend to decide which
branch to follow. The SQL branch operator expects the query to return True/False (Boolean) or
0/1 (Integer) or true/y/yes/1/on/false/n/no/0/off (String).
(cherry picked from commit 55b9b8f6456a7a46921b0bf7a893c7f08bf8237c)
---
airflow/operators/sql_branch_operator.py | 173 ++++++++++
docs/operators-and-hooks-ref.rst | 81 +++++
tests/operators/test_sql_branch_operator.py | 476 ++++++++++++++++++++++++++++
3 files changed, 730 insertions(+)
diff --git a/airflow/operators/sql_branch_operator.py b/airflow/operators/sql_branch_operator.py
new file mode 100644
index 0000000..072c40c
--- /dev/null
+++ b/airflow/operators/sql_branch_operator.py
@@ -0,0 +1,173 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+ALLOWED_CONN_TYPE = {
+ "google_cloud_platform",
+ "jdbc",
+ "mssql",
+ "mysql",
+ "odbc",
+ "oracle",
+ "postgres",
+ "presto",
+ "sqlite",
+ "vertica",
+}
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+ """
+ Executes sql code in a specific database
+
+ :param sql: the sql code to be executed. (templated)
+ :type sql: Can receive a str representing a sql statement or reference to a template file.
+ Template reference are recognized by str ending in '.sql'.
+ Expected SQL query to return Boolean (True/False), integer (0 = False, non-zero = True)
+ or string (true/y/yes/1/on/false/n/no/0/off).
+ :param follow_task_ids_if_true: task id or task ids to follow if query returns true
+ :type follow_task_ids_if_true: str or list
+ :param follow_task_ids_if_false: task id or task ids to follow if query returns false
+ :type follow_task_ids_if_false: str or list
+ :param conn_id: reference to a specific database
+ :type conn_id: str
+ :param database: name of database which overwrite defined one in connection
+ :param parameters: (optional) the parameters to render the SQL query with.
+ :type parameters: mapping or iterable
+ """
+
+ template_fields = ("sql",)
+ template_ext = (".sql",)
+ ui_color = "#a22034"
+ ui_fgcolor = "#F7F7F7"
+
+ @apply_defaults
+ def __init__(
+ self,
+ sql,
+ follow_task_ids_if_true,
+ follow_task_ids_if_false,
+ conn_id="default_conn_id",
+ database=None,
+ parameters=None,
+ *args,
+ **kwargs):
+ super(BranchSqlOperator, self).__init__(*args, **kwargs)
+ self.conn_id = conn_id
+ self.sql = sql
+ self.parameters = parameters
+ self.follow_task_ids_if_true = follow_task_ids_if_true
+ self.follow_task_ids_if_false = follow_task_ids_if_false
+ self.database = database
+ self._hook = None
+
+ def _get_hook(self):
+ self.log.debug("Get connection for %s", self.conn_id)
+ conn = BaseHook.get_connection(self.conn_id)
+
+ if conn.conn_type not in ALLOWED_CONN_TYPE:
+ raise AirflowException(
+ "The connection type is not supported by BranchSqlOperator. "
+ + "Supported connection types: {}".format(list(ALLOWED_CONN_TYPE))
+ )
+
+ if not self._hook:
+ self._hook = conn.get_hook()
+ if self.database:
+ self._hook.schema = self.database
+
+ return self._hook
+
+ def execute(self, context):
+ # get supported hook
+ self._hook = self._get_hook()
+
+ if self._hook is None:
+ raise AirflowException(
+ "Failed to establish connection to '%s'" % self.conn_id
+ )
+
+ if self.sql is None:
+ raise AirflowException("Expected 'sql' parameter is missing.")
+
+ if self.follow_task_ids_if_true is None:
+ raise AirflowException(
+ "Expected 'follow_task_ids_if_true' parameter is missing."
+ )
+
+ if self.follow_task_ids_if_false is None:
+ raise AirflowException(
+ "Expected 'follow_task_ids_if_false' parameter is missing."
+ )
+
+ self.log.info(
+ "Executing: %s (with parameters %s) with connection: %s",
+ self.sql,
+ self.parameters,
+ self._hook,
+ )
+ record = self._hook.get_first(self.sql, self.parameters)
+ if not record:
+ raise AirflowException(
+ "No rows returned from sql query. Operator expected True or False return value."
+ )
+
+ if isinstance(record, list):
+ if isinstance(record[0], list):
+ query_result = record[0][0]
+ else:
+ query_result = record[0]
+ elif isinstance(record, tuple):
+ query_result = record[0]
+ else:
+ query_result = record
+
+ self.log.info("Query returns %s, type '%s'", query_result, type(query_result))
+
+ follow_branch = None
+ try:
+ if isinstance(query_result, bool):
+ if query_result:
+ follow_branch = self.follow_task_ids_if_true
+ elif isinstance(query_result, str):
+ # return result is not Boolean, try to convert from String to Boolean
+ if bool(strtobool(query_result)):
+ follow_branch = self.follow_task_ids_if_true
+ elif isinstance(query_result, int):
+ if bool(query_result):
+ follow_branch = self.follow_task_ids_if_true
+ else:
+ raise AirflowException(
+ "Unexpected query return result '%s' type '%s'"
+ % (query_result, type(query_result))
+ )
+
+ if follow_branch is None:
+ follow_branch = self.follow_task_ids_if_false
+ except ValueError:
+ raise AirflowException(
+ "Unexpected query return result '%s' type '%s'"
+ % (query_result, type(query_result))
+ )
+
+ self.skip_all_except(context["ti"], follow_branch)
diff --git a/docs/operators-and-hooks-ref.rst b/docs/operators-and-hooks-ref.rst
index 6c80858..55176f8 100644
--- a/docs/operators-and-hooks-ref.rst
+++ b/docs/operators-and-hooks-ref.rst
@@ -22,6 +22,87 @@ Operators and Hooks Reference
:local:
:depth: 1
+.. _fundamentals:
+
+Fundamentals
+------------
+
+**Base:**
+
+.. list-table::
+ :header-rows: 1
+
+ * - Module
+ - Guides
+
+ * - :mod:`airflow.hooks.base_hook`
+ -
+
+ * - :mod:`airflow.hooks.dbapi_hook`
+ -
+
+ * - :mod:`airflow.models.baseoperator`
+ -
+
+ * - :mod:`airflow.sensors.base_sensor_operator`
+ -
+
+**Operators:**
+
+.. list-table::
+ :header-rows: 1
+
+ * - Operators
+ - Guides
+
+ * - :mod:`airflow.operators.branch_operator`
+ -
+
+ * - :mod:`airflow.operators.check_operator`
+ -
+
+ * - :mod:`airflow.operators.dagrun_operator`
+ -
+
+ * - :mod:`airflow.operators.dummy_operator`
+ -
+
+ * - :mod:`airflow.operators.generic_transfer`
+ -
+
+ * - :mod:`airflow.operators.latest_only_operator`
+ -
+
+ * - :mod:`airflow.operators.subdag_operator`
+ -
+
+ * - :mod:`airflow.operators.sql_branch_operator`
+ -
+
+**Sensors:**
+
+.. list-table::
+ :header-rows: 1
+
+ * - Sensors
+ - Guides
+
+ * - :mod:`airflow.sensors.weekday_sensor`
+ -
+
+ * - :mod:`airflow.sensors.external_task_sensor`
+ - :doc:`How to use <howto/operator/external_task_sensor>`
+
+ * - :mod:`airflow.sensors.sql_sensor`
+ -
+
+ * - :mod:`airflow.sensors.time_delta_sensor`
+ -
+
+ * - :mod:`airflow.sensors.time_sensor`
+ -
+
+
.. _Apache:
ASF: Apache Software Foundation
diff --git a/tests/operators/test_sql_branch_operator.py b/tests/operators/test_sql_branch_operator.py
new file mode 100644
index 0000000..b7a0885
--- /dev/null
+++ b/tests/operators/test_sql_branch_operator.py
@@ -0,0 +1,476 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import unittest
+from unittest import mock
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.models import DAG, DagRun, TaskInstance as TI
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.operators.sql_branch_operator import BranchSqlOperator
+from airflow.utils import timezone
+from airflow.utils.session import create_session
+from airflow.utils.state import State
+from tests.providers.apache.hive import TestHiveEnvironment
+
+DEFAULT_DATE = timezone.datetime(2016, 1, 1)
+INTERVAL = datetime.timedelta(hours=12)
+
+SUPPORTED_TRUE_VALUES = [
+ ["True"],
+ ["true"],
+ ["1"],
+ ["on"],
+ [1],
+ True,
+ "true",
+ "1",
+ "on",
+ 1,
+]
+SUPPORTED_FALSE_VALUES = [
+ ["False"],
+ ["false"],
+ ["0"],
+ ["off"],
+ [0],
+ False,
+ "false",
+ "0",
+ "off",
+ 0,
+]
+
+
+class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
+ """
+ Test for SQL Branch Operator
+ """
+
+ @classmethod
+ def setUpClass(cls):
+ super(TestSqlBranch, cls).setUpClass()
+
+ with create_session() as session:
+ session.query(DagRun).delete()
+ session.query(TI).delete()
+
+ def setUp(self):
+ super(TestSqlBranch, self).setUp()
+ self.dag = DAG(
+ "sql_branch_operator_test",
+ default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+ schedule_interval=INTERVAL,
+ )
+ self.branch_1 = DummyOperator(task_id="branch_1", dag=self.dag)
+ self.branch_2 = DummyOperator(task_id="branch_2", dag=self.dag)
+ self.branch_3 = None
+
+ def tearDown(self):
+ super(TestSqlBranch, self).tearDown()
+
+ with create_session() as session:
+ session.query(DagRun).delete()
+ session.query(TI).delete()
+
+ def test_unsupported_conn_type(self):
+ """ Check if BranchSqlOperator throws an exception for unsupported connection type """
+ op = BranchSqlOperator(
+ task_id="make_choice",
+ conn_id="redis_default",
+ sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+
+ with self.assertRaises(AirflowException):
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+ def test_invalid_conn(self):
+ """ Check if BranchSqlOperator throws an exception for invalid connection """
+ op = BranchSqlOperator(
+ task_id="make_choice",
+ conn_id="invalid_connection",
+ sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+
+ with self.assertRaises(AirflowException):
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+ def test_invalid_follow_task_true(self):
+ """ Check if BranchSqlOperator throws an exception when follow_task_ids_if_true is missing """
+ op = BranchSqlOperator(
+ task_id="make_choice",
+ conn_id="invalid_connection",
+ sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+ follow_task_ids_if_true=None,
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+
+ with self.assertRaises(AirflowException):
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+ def test_invalid_follow_task_false(self):
+ """ Check if BranchSqlOperator throws an exception when follow_task_ids_if_false is missing """
+ op = BranchSqlOperator(
+ task_id="make_choice",
+ conn_id="invalid_connection",
+ sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false=None,
+ dag=self.dag,
+ )
+
+ with self.assertRaises(AirflowException):
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+ @pytest.mark.backend("mysql")
+ def test_sql_branch_operator_mysql(self):
+ """ Check if BranchSqlOperator works with backend """
+ branch_op = BranchSqlOperator(
+ task_id="make_choice",
+ conn_id="mysql_default",
+ sql="SELECT 1",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+ branch_op.run(
+ start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True
+ )
+
+ @pytest.mark.backend("postgres")
+ def test_sql_branch_operator_postgres(self):
+ """ Check if BranchSqlOperator works with backend """
+ branch_op = BranchSqlOperator(
+ task_id="make_choice",
+ conn_id="postgres_default",
+ sql="SELECT 1",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+ branch_op.run(
+ start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True
+ )
+
+ @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+ def test_branch_single_value_with_dag_run(self, mock_hook):
+ """ Check BranchSqlOperator branch operation """
+ branch_op = BranchSqlOperator(
+ task_id="make_choice",
+ conn_id="mysql_default",
+ sql="SELECT 1",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+
+ self.branch_1.set_upstream(branch_op)
+ self.branch_2.set_upstream(branch_op)
+ self.dag.clear()
+
+ dr = self.dag.create_dagrun(
+ run_id="manual__",
+ start_date=timezone.utcnow(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+
+ mock_hook.get_connection("mysql_default").conn_type = "mysql"
+ mock_get_records = (
+ mock_hook.get_connection.return_value.get_hook.return_value.get_first
+ )
+
+ mock_get_records.return_value = 1
+
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ tis = dr.get_task_instances()
+ for ti in tis:
+ if ti.task_id == "make_choice":
+ self.assertEqual(ti.state, State.SUCCESS)
+ elif ti.task_id == "branch_1":
+ self.assertEqual(ti.state, State.NONE)
+ elif ti.task_id == "branch_2":
+ self.assertEqual(ti.state, State.SKIPPED)
+ else:
+ raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
+
+ @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+ def test_branch_true_with_dag_run(self, mock_hook):
+ """ Check BranchSqlOperator branch operation """
+ branch_op = BranchSqlOperator(
+ task_id="make_choice",
+ conn_id="mysql_default",
+ sql="SELECT 1",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+
+ self.branch_1.set_upstream(branch_op)
+ self.branch_2.set_upstream(branch_op)
+ self.dag.clear()
+
+ dr = self.dag.create_dagrun(
+ run_id="manual__",
+ start_date=timezone.utcnow(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+
+ mock_hook.get_connection("mysql_default").conn_type = "mysql"
+ mock_get_records = (
+ mock_hook.get_connection.return_value.get_hook.return_value.get_first
+ )
+
+ for true_value in SUPPORTED_TRUE_VALUES:
+ mock_get_records.return_value = true_value
+
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ tis = dr.get_task_instances()
+ for ti in tis:
+ if ti.task_id == "make_choice":
+ self.assertEqual(ti.state, State.SUCCESS)
+ elif ti.task_id == "branch_1":
+ self.assertEqual(ti.state, State.NONE)
+ elif ti.task_id == "branch_2":
+ self.assertEqual(ti.state, State.SKIPPED)
+ else:
+ raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
+
+ @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+ def test_branch_false_with_dag_run(self, mock_hook):
+ """ Check BranchSqlOperator branch operation """
+ branch_op = BranchSqlOperator(
+ task_id="make_choice",
+ conn_id="mysql_default",
+ sql="SELECT 1",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+
+ self.branch_1.set_upstream(branch_op)
+ self.branch_2.set_upstream(branch_op)
+ self.dag.clear()
+
+ dr = self.dag.create_dagrun(
+ run_id="manual__",
+ start_date=timezone.utcnow(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+
+ mock_hook.get_connection("mysql_default").conn_type = "mysql"
+ mock_get_records = (
+ mock_hook.get_connection.return_value.get_hook.return_value.get_first
+ )
+
+ for false_value in SUPPORTED_FALSE_VALUES:
+ mock_get_records.return_value = false_value
+
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ tis = dr.get_task_instances()
+ for ti in tis:
+ if ti.task_id == "make_choice":
+ self.assertEqual(ti.state, State.SUCCESS)
+ elif ti.task_id == "branch_1":
+ self.assertEqual(ti.state, State.SKIPPED)
+ elif ti.task_id == "branch_2":
+ self.assertEqual(ti.state, State.NONE)
+ else:
+ raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
+
+ @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+ def test_branch_list_with_dag_run(self, mock_hook):
+ """ Checks if the BranchSqlOperator supports branching off to a list of tasks."""
+ branch_op = BranchSqlOperator(
+ task_id="make_choice",
+ conn_id="mysql_default",
+ sql="SELECT 1",
+ follow_task_ids_if_true=["branch_1", "branch_2"],
+ follow_task_ids_if_false="branch_3",
+ dag=self.dag,
+ )
+
+ self.branch_1.set_upstream(branch_op)
+ self.branch_2.set_upstream(branch_op)
+ self.branch_3 = DummyOperator(task_id="branch_3", dag=self.dag)
+ self.branch_3.set_upstream(branch_op)
+ self.dag.clear()
+
+ dr = self.dag.create_dagrun(
+ run_id="manual__",
+ start_date=timezone.utcnow(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+
+ mock_hook.get_connection("mysql_default").conn_type = "mysql"
+ mock_get_records = (
+ mock_hook.get_connection.return_value.get_hook.return_value.get_first
+ )
+ mock_get_records.return_value = [["1"]]
+
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ tis = dr.get_task_instances()
+ for ti in tis:
+ if ti.task_id == "make_choice":
+ self.assertEqual(ti.state, State.SUCCESS)
+ elif ti.task_id == "branch_1":
+ self.assertEqual(ti.state, State.NONE)
+ elif ti.task_id == "branch_2":
+ self.assertEqual(ti.state, State.NONE)
+ elif ti.task_id == "branch_3":
+ self.assertEqual(ti.state, State.SKIPPED)
+ else:
+ raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
+
+ @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+ def test_invalid_query_result_with_dag_run(self, mock_hook):
+ """ Check BranchSqlOperator branch operation """
+ branch_op = BranchSqlOperator(
+ task_id="make_choice",
+ conn_id="mysql_default",
+ sql="SELECT 1",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+
+ self.branch_1.set_upstream(branch_op)
+ self.branch_2.set_upstream(branch_op)
+ self.dag.clear()
+
+ self.dag.create_dagrun(
+ run_id="manual__",
+ start_date=timezone.utcnow(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+
+ mock_hook.get_connection("mysql_default").conn_type = "mysql"
+ mock_get_records = (
+ mock_hook.get_connection.return_value.get_hook.return_value.get_first
+ )
+
+ mock_get_records.return_value = ["Invalid Value"]
+
+ with self.assertRaises(AirflowException):
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+ def test_with_skip_in_branch_downstream_dependencies(self, mock_hook):
+ """ Test SQL Branch with skipping all downstream dependencies """
+ branch_op = BranchSqlOperator(
+ task_id="make_choice",
+ conn_id="mysql_default",
+ sql="SELECT 1",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+
+ branch_op >> self.branch_1 >> self.branch_2
+ branch_op >> self.branch_2
+ self.dag.clear()
+
+ dr = self.dag.create_dagrun(
+ run_id="manual__",
+ start_date=timezone.utcnow(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+
+ mock_hook.get_connection("mysql_default").conn_type = "mysql"
+ mock_get_records = (
+ mock_hook.get_connection.return_value.get_hook.return_value.get_first
+ )
+
+ for true_value in SUPPORTED_TRUE_VALUES:
+ mock_get_records.return_value = [true_value]
+
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ tis = dr.get_task_instances()
+ for ti in tis:
+ if ti.task_id == "make_choice":
+ self.assertEqual(ti.state, State.SUCCESS)
+ elif ti.task_id == "branch_1":
+ self.assertEqual(ti.state, State.NONE)
+ elif ti.task_id == "branch_2":
+ self.assertEqual(ti.state, State.NONE)
+ else:
+ raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
+
+ @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
+ def test_with_skip_in_branch_downstream_dependencies2(self, mock_hook):
+ """ Test skipping downstream dependency for false condition"""
+ branch_op = BranchSqlOperator(
+ task_id="make_choice",
+ conn_id="mysql_default",
+ sql="SELECT 1",
+ follow_task_ids_if_true="branch_1",
+ follow_task_ids_if_false="branch_2",
+ dag=self.dag,
+ )
+
+ branch_op >> self.branch_1 >> self.branch_2
+ branch_op >> self.branch_2
+ self.dag.clear()
+
+ dr = self.dag.create_dagrun(
+ run_id="manual__",
+ start_date=timezone.utcnow(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+
+ mock_hook.get_connection("mysql_default").conn_type = "mysql"
+ mock_get_records = (
+ mock_hook.get_connection.return_value.get_hook.return_value.get_first
+ )
+
+ for false_value in SUPPORTED_FALSE_VALUES:
+ mock_get_records.return_value = [false_value]
+
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ tis = dr.get_task_instances()
+ for ti in tis:
+ if ti.task_id == "make_choice":
+ self.assertEqual(ti.state, State.SUCCESS)
+ elif ti.task_id == "branch_1":
+ self.assertEqual(ti.state, State.SKIPPED)
+ elif ti.task_id == "branch_2":
+ self.assertEqual(ti.state, State.NONE)
+ else:
+ raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))