You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by bo...@apache.org on 2017/04/05 17:17:29 UTC
incubator-airflow git commit: Merge pull request #2195 from bolkedebruin/AIRFLOW-719
Repository: incubator-airflow
Updated Branches:
refs/heads/v1-8-test 9070a8277 -> dff6d21bf
Merge pull request #2195 from bolkedebruin/AIRFLOW-719
(cherry picked from commit 4a6bef69d1817a5fc3ddd6ffe14c2578eaa49cf0)
Signed-off-by: Bolke de Bruin <bo...@xs4all.nl>
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/dff6d21b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/dff6d21b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/dff6d21b
Branch: refs/heads/v1-8-test
Commit: dff6d21bfd9a2585ca484fc8fd56aa100f640908
Parents: 9070a82
Author: Bolke de Bruin <bo...@xs4all.nl>
Authored: Tue Apr 4 17:04:12 2017 +0200
Committer: Bolke de Bruin <bo...@xs4all.nl>
Committed: Wed Apr 5 19:16:22 2017 +0200
----------------------------------------------------------------------
airflow/operators/latest_only_operator.py | 30 ++-
airflow/operators/python_operator.py | 82 +++++--
airflow/ti_deps/deps/trigger_rule_dep.py | 6 +-
scripts/ci/requirements.txt | 1 +
tests/dags/test_dagrun_short_circuit_false.py | 38 ----
tests/models.py | 77 +++----
tests/operators/__init__.py | 2 +
tests/operators/latest_only_operator.py | 12 +-
tests/operators/python_operator.py | 244 +++++++++++++++++++++
9 files changed, 384 insertions(+), 108 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/airflow/operators/latest_only_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/latest_only_operator.py b/airflow/operators/latest_only_operator.py
index 8b4e614..9d5defb 100644
--- a/airflow/operators/latest_only_operator.py
+++ b/airflow/operators/latest_only_operator.py
@@ -34,7 +34,7 @@ class LatestOnlyOperator(BaseOperator):
def execute(self, context):
# If the DAG Run is externally triggered, then return without
# skipping downstream tasks
- if context['dag_run'].external_trigger:
+ if context['dag_run'] and context['dag_run'].external_trigger:
logging.info("""Externally triggered DAG_Run:
allowing execution to proceed.""")
return
@@ -46,17 +46,39 @@ class LatestOnlyOperator(BaseOperator):
logging.info(
'Checking latest only with left_window: %s right_window: %s '
'now: %s', left_window, right_window, now)
+
if not left_window < now <= right_window:
logging.info('Not latest execution, skipping downstream.')
session = settings.Session()
- for task in context['task'].downstream_list:
- ti = TaskInstance(
- task, execution_date=context['ti'].execution_date)
+
+ TI = TaskInstance
+ tis = session.query(TI).filter(
+ TI.execution_date == context['ti'].execution_date,
+ TI.task_id.in_(context['task'].downstream_task_ids)
+ ).with_for_update().all()
+
+ for ti in tis:
logging.info('Skipping task: %s', ti.task_id)
ti.state = State.SKIPPED
ti.start_date = now
ti.end_date = now
session.merge(ti)
+
+ # this is defensive against dag runs that are not complete
+ for task in context['task'].downstream_list:
+ if task.task_id in tis:
+ continue
+
+ logging.warning("Task {} was not part of a dag run. "
+ "This should not happen."
+ .format(task))
+ now = datetime.datetime.now()
+ ti = TaskInstance(task, execution_date=context['ti'].execution_date)
+ ti.state = State.SKIPPED
+ ti.start_date = now
+ ti.end_date = now
+ session.merge(ti)
+
session.commit()
session.close()
logging.info('Done.')
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/airflow/operators/python_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py
index b5f6386..114bc7e 100644
--- a/airflow/operators/python_operator.py
+++ b/airflow/operators/python_operator.py
@@ -106,14 +106,36 @@ class BranchPythonOperator(PythonOperator):
logging.info("Following branch " + branch)
logging.info("Marking other directly downstream tasks as skipped")
session = settings.Session()
+
+ TI = TaskInstance
+ tis = session.query(TI).filter(
+ TI.execution_date == context['ti'].execution_date,
+ TI.task_id.in_(context['task'].downstream_task_ids),
+ TI.task_id != branch,
+ ).with_for_update().all()
+
+ for ti in tis:
+ logging.info('Skipping task: %s', ti.task_id)
+ ti.state = State.SKIPPED
+ ti.start_date = datetime.now()
+ ti.end_date = datetime.now()
+
+ # this is defensive against dag runs that are not complete
for task in context['task'].downstream_list:
- if task.task_id != branch:
- ti = TaskInstance(
- task, execution_date=context['ti'].execution_date)
- ti.state = State.SKIPPED
- ti.start_date = datetime.now()
- ti.end_date = datetime.now()
- session.merge(ti)
+ if task.task_id in tis:
+ continue
+
+ if task.task_id == branch:
+ continue
+
+ logging.warning("Task {} was not part of a dag run. This should not happen."
+ .format(task))
+ ti = TaskInstance(task, execution_date=context['ti'].execution_date)
+ ti.state = State.SKIPPED
+ ti.start_date = datetime.now()
+ ti.end_date = datetime.now()
+ session.merge(ti)
+
session.commit()
session.close()
logging.info("Done.")
@@ -134,19 +156,39 @@ class ShortCircuitOperator(PythonOperator):
def execute(self, context):
condition = super(ShortCircuitOperator, self).execute(context)
logging.info("Condition result is {}".format(condition))
+
if condition:
logging.info('Proceeding with downstream tasks...')
return
- else:
- logging.info('Skipping downstream tasks...')
- session = settings.Session()
- for task in context['task'].downstream_list:
- ti = TaskInstance(
- task, execution_date=context['ti'].execution_date)
- ti.state = State.SKIPPED
- ti.start_date = datetime.now()
- ti.end_date = datetime.now()
- session.merge(ti)
- session.commit()
- session.close()
- logging.info("Done.")
+
+ logging.info('Skipping downstream tasks...')
+ session = settings.Session()
+
+ TI = TaskInstance
+ tis = session.query(TI).filter(
+ TI.execution_date == context['ti'].execution_date,
+ TI.task_id.in_(context['task'].downstream_task_ids),
+ ).with_for_update().all()
+
+ for ti in tis:
+ logging.info('Skipping task: %s', ti.task_id)
+ ti.state = State.SKIPPED
+ ti.start_date = datetime.now()
+ ti.end_date = datetime.now()
+
+ # this is defensive against dag runs that are not complete
+ for task in context['task'].downstream_list:
+ if task.task_id in tis:
+ continue
+
+ logging.warning("Task {} was not part of a dag run. This should not happen."
+ .format(task))
+ ti = TaskInstance(task, execution_date=context['ti'].execution_date)
+ ti.state = State.SKIPPED
+ ti.start_date = datetime.now()
+ ti.end_date = datetime.now()
+ session.merge(ti)
+
+ session.commit()
+ session.close()
+ logging.info("Done.")
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/airflow/ti_deps/deps/trigger_rule_dep.py
----------------------------------------------------------------------
diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py
index da13bba..281ed51 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -135,7 +135,7 @@ class TriggerRuleDep(BaseTIDep):
if tr == TR.ALL_SUCCESS:
if upstream_failed or failed:
ti.set_state(State.UPSTREAM_FAILED, session)
- elif skipped == upstream:
+ elif skipped:
ti.set_state(State.SKIPPED, session)
elif tr == TR.ALL_FAILED:
if successes or skipped:
@@ -148,7 +148,7 @@ class TriggerRuleDep(BaseTIDep):
ti.set_state(State.SKIPPED, session)
if tr == TR.ONE_SUCCESS:
- if successes <= 0 and skipped <= 0:
+ if successes <= 0:
yield self._failing_status(
reason="Task's trigger rule '{0}' requires one upstream "
"task success, but none were found. "
@@ -162,7 +162,7 @@ class TriggerRuleDep(BaseTIDep):
"upstream_tasks_state={1}, upstream_task_ids={2}"
.format(tr, upstream_tasks_state, task.upstream_task_ids))
elif tr == TR.ALL_SUCCESS:
- num_failures = upstream - (successes + skipped)
+ num_failures = upstream - successes
if num_failures > 0:
yield self._failing_status(
reason="Task's trigger rule '{0}' requires all upstream "
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/scripts/ci/requirements.txt
----------------------------------------------------------------------
diff --git a/scripts/ci/requirements.txt b/scripts/ci/requirements.txt
index a5786f6..9a2bce2 100644
--- a/scripts/ci/requirements.txt
+++ b/scripts/ci/requirements.txt
@@ -20,6 +20,7 @@ flask-cache
flask-login==0.2.11
Flask-WTF
flower
+freezegun
future
gunicorn
hdfs
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/tests/dags/test_dagrun_short_circuit_false.py
----------------------------------------------------------------------
diff --git a/tests/dags/test_dagrun_short_circuit_false.py b/tests/dags/test_dagrun_short_circuit_false.py
deleted file mode 100644
index 805ab67..0000000
--- a/tests/dags/test_dagrun_short_circuit_false.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from datetime import datetime
-
-from airflow.models import DAG
-from airflow.operators.python_operator import ShortCircuitOperator
-from airflow.operators.dummy_operator import DummyOperator
-
-
-# DAG that has its short circuit op fail and skip multiple downstream tasks
-dag = DAG(
- dag_id='test_dagrun_short_circuit_false',
- start_date=datetime(2017, 1, 1)
-)
-dag_task1 = ShortCircuitOperator(
- task_id='test_short_circuit_false',
- dag=dag,
- python_callable=lambda: False)
-dag_task2 = DummyOperator(
- task_id='test_state_skipped1',
- dag=dag)
-dag_task3 = DummyOperator(
- task_id='test_state_skipped2',
- dag=dag)
-dag_task1.set_downstream(dag_task2)
-dag_task2.set_downstream(dag_task3)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/tests/models.py
----------------------------------------------------------------------
diff --git a/tests/models.py b/tests/models.py
index 83183f8..9478088 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -31,11 +31,12 @@ from airflow.models import DagModel
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.bash_operator import BashOperator
from airflow.operators.python_operator import PythonOperator
+from airflow.operators.python_operator import ShortCircuitOperator
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.utils.state import State
from mock import patch
from nose_parameterized import parameterized
-from tests.core import TEST_DAG_FOLDER
+
DEFAULT_DATE = datetime.datetime(2016, 1, 1)
TEST_DAGS_FOLDER = os.path.join(
@@ -235,17 +236,13 @@ class DagTest(unittest.TestCase):
class DagRunTest(unittest.TestCase):
- def setUp(self):
- self.dagbag = models.DagBag(dag_folder=TEST_DAG_FOLDER)
-
- def create_dag_run(self, dag_id, state=State.RUNNING, task_states=None):
+ def create_dag_run(self, dag, state=State.RUNNING, task_states=None):
now = datetime.datetime.now()
- dag = self.dagbag.get_dag(dag_id)
dag_run = dag.create_dagrun(
run_id='manual__' + now.isoformat(),
execution_date=now,
start_date=now,
- state=State.RUNNING,
+ state=state,
external_trigger=False,
)
@@ -298,33 +295,34 @@ class DagRunTest(unittest.TestCase):
self.assertEqual(0, len(models.DagRun.find(dag_id=dag_id2, external_trigger=True)))
self.assertEqual(1, len(models.DagRun.find(dag_id=dag_id2, external_trigger=False)))
- def test_dagrun_running_when_upstream_skipped(self):
- """
- Tests that a DAG run is not failed when an upstream task is skipped
- """
- initial_task_states = {
- 'test_short_circuit_false': State.SUCCESS,
- 'test_state_skipped1': State.SKIPPED,
- 'test_state_skipped2': State.NONE,
- }
- # dags/test_dagrun_short_circuit_false.py
- dag_run = self.create_dag_run('test_dagrun_short_circuit_false',
- state=State.RUNNING,
- task_states=initial_task_states)
- updated_dag_state = dag_run.update_state()
- self.assertEqual(State.RUNNING, updated_dag_state)
-
def test_dagrun_success_when_all_skipped(self):
"""
Tests that a DAG run succeeds when all tasks are skipped
"""
+ dag = DAG(
+ dag_id='test_dagrun_success_when_all_skipped',
+ start_date=datetime.datetime(2017, 1, 1)
+ )
+ dag_task1 = ShortCircuitOperator(
+ task_id='test_short_circuit_false',
+ dag=dag,
+ python_callable=lambda: False)
+ dag_task2 = DummyOperator(
+ task_id='test_state_skipped1',
+ dag=dag)
+ dag_task3 = DummyOperator(
+ task_id='test_state_skipped2',
+ dag=dag)
+ dag_task1.set_downstream(dag_task2)
+ dag_task2.set_downstream(dag_task3)
+
initial_task_states = {
'test_short_circuit_false': State.SUCCESS,
'test_state_skipped1': State.SKIPPED,
'test_state_skipped2': State.SKIPPED,
}
- # dags/test_dagrun_short_circuit_false.py
- dag_run = self.create_dag_run('test_dagrun_short_circuit_false',
+
+ dag_run = self.create_dag_run(dag=dag,
state=State.RUNNING,
task_states=initial_task_states)
updated_dag_state = dag_run.update_state()
@@ -385,10 +383,17 @@ class DagRunTest(unittest.TestCase):
"""
Make sure that a proper value is returned when a dagrun has no task instances
"""
+ dag = DAG(
+ dag_id='test_get_task_instance_on_empty_dagrun',
+ start_date=datetime.datetime(2017, 1, 1)
+ )
+ dag_task1 = ShortCircuitOperator(
+ task_id='test_short_circuit_false',
+ dag=dag,
+ python_callable=lambda: False)
+
session = settings.Session()
- # Any dag will work for this
- dag = self.dagbag.get_dag('test_dagrun_short_circuit_false')
now = datetime.datetime.now()
# Don't use create_dagrun since it will create the task instances too which we
@@ -784,7 +789,7 @@ class TaskInstanceTest(unittest.TestCase):
self.assertEqual(dt, ti.end_date+max_delay)
def test_depends_on_past(self):
- dagbag = models.DagBag(dag_folder=TEST_DAG_FOLDER)
+ dagbag = models.DagBag()
dag = dagbag.get_dag('test_depends_on_past')
dag.clear()
task = dag.tasks[0]
@@ -813,11 +818,10 @@ class TaskInstanceTest(unittest.TestCase):
#
# Tests for all_success
#
- ['all_success', 5, 0, 0, 0, 5, True, None, True],
- ['all_success', 2, 0, 0, 0, 2, True, None, False],
- ['all_success', 2, 0, 1, 0, 3, True, ST.UPSTREAM_FAILED, False],
- ['all_success', 2, 1, 0, 0, 3, True, None, False],
- ['all_success', 0, 5, 0, 0, 5, True, ST.SKIPPED, True],
+ ['all_success', 5, 0, 0, 0, 0, True, None, True],
+ ['all_success', 2, 0, 0, 0, 0, True, None, False],
+ ['all_success', 2, 0, 1, 0, 0, True, ST.UPSTREAM_FAILED, False],
+ ['all_success', 2, 1, 0, 0, 0, True, ST.SKIPPED, False],
#
# Tests for one_success
#
@@ -825,7 +829,6 @@ class TaskInstanceTest(unittest.TestCase):
['one_success', 2, 0, 0, 0, 2, True, None, True],
['one_success', 2, 0, 1, 0, 3, True, None, True],
['one_success', 2, 1, 0, 0, 3, True, None, True],
- ['one_success', 0, 2, 0, 0, 2, True, None, True],
#
# Tests for all_failed
#
@@ -837,9 +840,9 @@ class TaskInstanceTest(unittest.TestCase):
#
# Tests for one_failed
#
- ['one_failed', 5, 0, 0, 0, 5, True, ST.SKIPPED, False],
- ['one_failed', 2, 0, 0, 0, 2, True, None, False],
- ['one_failed', 2, 0, 1, 0, 2, True, None, True],
+ ['one_failed', 5, 0, 0, 0, 0, True, None, False],
+ ['one_failed', 2, 0, 0, 0, 0, True, None, False],
+ ['one_failed', 2, 0, 1, 0, 0, True, None, True],
['one_failed', 2, 1, 0, 0, 3, True, None, False],
['one_failed', 2, 3, 0, 0, 5, True, ST.SKIPPED, False],
#
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/tests/operators/__init__.py
----------------------------------------------------------------------
diff --git a/tests/operators/__init__.py b/tests/operators/__init__.py
index 1fb0e5e..aeb243c 100644
--- a/tests/operators/__init__.py
+++ b/tests/operators/__init__.py
@@ -18,3 +18,5 @@ from .operators import *
from .sensors import *
from .hive_operator import *
from .s3_to_hive_operator import *
+from .python_operator import *
+from .latest_only_operator import *
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/tests/operators/latest_only_operator.py
----------------------------------------------------------------------
diff --git a/tests/operators/latest_only_operator.py b/tests/operators/latest_only_operator.py
index 37aec38..9137491 100644
--- a/tests/operators/latest_only_operator.py
+++ b/tests/operators/latest_only_operator.py
@@ -77,17 +77,17 @@ class LatestOnlyOperatorTest(unittest.TestCase):
latest_instances = get_task_instances('latest')
exec_date_to_latest_state = {
ti.execution_date: ti.state for ti in latest_instances}
- assert exec_date_to_latest_state == {
+ self.assertEqual({
datetime.datetime(2016, 1, 1): 'success',
datetime.datetime(2016, 1, 1, 12): 'success',
- datetime.datetime(2016, 1, 2): 'success',
- }
+ datetime.datetime(2016, 1, 2): 'success', },
+ exec_date_to_latest_state)
downstream_instances = get_task_instances('downstream')
exec_date_to_downstream_state = {
ti.execution_date: ti.state for ti in downstream_instances}
- assert exec_date_to_downstream_state == {
+ self.assertEqual({
datetime.datetime(2016, 1, 1): 'skipped',
datetime.datetime(2016, 1, 1, 12): 'skipped',
- datetime.datetime(2016, 1, 2): 'success',
- }
+ datetime.datetime(2016, 1, 2): 'success',},
+ exec_date_to_downstream_state)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/tests/operators/python_operator.py
----------------------------------------------------------------------
diff --git a/tests/operators/python_operator.py b/tests/operators/python_operator.py
new file mode 100644
index 0000000..3aa8b6c
--- /dev/null
+++ b/tests/operators/python_operator.py
@@ -0,0 +1,244 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function, unicode_literals
+
+import datetime
+import unittest
+
+from airflow import configuration, DAG
+from airflow.models import TaskInstance as TI
+from airflow.operators.python_operator import PythonOperator, BranchPythonOperator
+from airflow.operators.python_operator import ShortCircuitOperator
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.settings import Session
+from airflow.utils.state import State
+
+from airflow.exceptions import AirflowException
+
+DEFAULT_DATE = datetime.datetime(2016, 1, 1)
+END_DATE = datetime.datetime(2016, 1, 2)
+INTERVAL = datetime.timedelta(hours=12)
+FROZEN_NOW = datetime.datetime(2016, 1, 2, 12, 1, 1)
+
+
+class PythonOperatorTest(unittest.TestCase):
+
+ def setUp(self):
+ super(PythonOperatorTest, self).setUp()
+ configuration.load_test_config()
+ self.dag = DAG(
+ 'test_dag',
+ default_args={
+ 'owner': 'airflow',
+ 'start_date': DEFAULT_DATE},
+ schedule_interval=INTERVAL)
+ self.addCleanup(self.dag.clear)
+ self.clear_run()
+ self.addCleanup(self.clear_run)
+
+ def do_run(self):
+ self.run = True
+
+ def clear_run(self):
+ self.run = False
+
+ def is_run(self):
+ return self.run
+
+ def test_python_operator_run(self):
+ """Tests that the python callable is invoked on task run."""
+ task = PythonOperator(
+ python_callable=self.do_run,
+ task_id='python_operator',
+ dag=self.dag)
+ self.assertFalse(self.is_run())
+ task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ self.assertTrue(self.is_run())
+
+ def test_python_operator_python_callable_is_callable(self):
+ """Tests that PythonOperator will only instantiate if
+ the python_callable argument is callable."""
+ not_callable = {}
+ with self.assertRaises(AirflowException):
+ PythonOperator(
+ python_callable=not_callable,
+ task_id='python_operator',
+ dag=self.dag)
+ not_callable = None
+ with self.assertRaises(AirflowException):
+ PythonOperator(
+ python_callable=not_callable,
+ task_id='python_operator',
+ dag=self.dag)
+
+
+class BranchOperatorTest(unittest.TestCase):
+ def setUp(self):
+ self.dag = DAG('branch_operator_test',
+ default_args={
+ 'owner': 'airflow',
+ 'start_date': DEFAULT_DATE},
+ schedule_interval=INTERVAL)
+ self.branch_op = BranchPythonOperator(task_id='make_choice',
+ dag=self.dag,
+ python_callable=lambda: 'branch_1')
+
+ self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
+ self.branch_1.set_upstream(self.branch_op)
+ self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)
+ self.branch_2.set_upstream(self.branch_op)
+ self.dag.clear()
+
+ def test_without_dag_run(self):
+ """This checks the defensive against non existent tasks in a dag run"""
+ self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ session = Session()
+ tis = session.query(TI).filter(
+ TI.dag_id == self.dag.dag_id,
+ TI.execution_date == DEFAULT_DATE
+ )
+ session.close()
+
+ for ti in tis:
+ if ti.task_id == 'make_choice':
+ self.assertEquals(ti.state, State.SUCCESS)
+ elif ti.task_id == 'branch_1':
+ # should not exist
+ raise
+ elif ti.task_id == 'branch_2':
+ self.assertEquals(ti.state, State.SKIPPED)
+ else:
+ raise
+
+ def test_with_dag_run(self):
+ dr = self.dag.create_dagrun(
+ run_id="manual__",
+ start_date=datetime.datetime.now(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING
+ )
+
+ self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ tis = dr.get_task_instances()
+ for ti in tis:
+ if ti.task_id == 'make_choice':
+ self.assertEquals(ti.state, State.SUCCESS)
+ elif ti.task_id == 'branch_1':
+ self.assertEquals(ti.state, State.NONE)
+ elif ti.task_id == 'branch_2':
+ self.assertEquals(ti.state, State.SKIPPED)
+ else:
+ raise
+
+
+class ShortCircuitOperatorTest(unittest.TestCase):
+ def setUp(self):
+ self.dag = DAG('shortcircuit_operator_test',
+ default_args={
+ 'owner': 'airflow',
+ 'start_date': DEFAULT_DATE},
+ schedule_interval=INTERVAL)
+ self.short_op = ShortCircuitOperator(task_id='make_choice',
+ dag=self.dag,
+ python_callable=lambda: self.value)
+
+ self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
+ self.branch_1.set_upstream(self.short_op)
+ self.upstream = DummyOperator(task_id='upstream', dag=self.dag)
+ self.upstream.set_downstream(self.short_op)
+ self.dag.clear()
+
+ self.value = True
+
+ def test_without_dag_run(self):
+ """This checks the defensive against non existent tasks in a dag run"""
+ self.value = False
+ self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ session = Session()
+ tis = session.query(TI).filter(
+ TI.dag_id == self.dag.dag_id,
+ TI.execution_date == DEFAULT_DATE
+ )
+
+ for ti in tis:
+ if ti.task_id == 'make_choice':
+ self.assertEquals(ti.state, State.SUCCESS)
+ elif ti.task_id == 'upstream':
+ # should not exist
+ raise
+ elif ti.task_id == 'branch_1':
+ self.assertEquals(ti.state, State.SKIPPED)
+ else:
+ raise
+
+ self.value = True
+ self.dag.clear()
+
+ self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ for ti in tis:
+ if ti.task_id == 'make_choice':
+ self.assertEquals(ti.state, State.SUCCESS)
+ elif ti.task_id == 'upstream':
+ # should not exist
+ raise
+ elif ti.task_id == 'branch_1':
+ self.assertEquals(ti.state, State.NONE)
+ else:
+ raise
+
+ session.close()
+
+ def test_with_dag_run(self):
+ self.value = False
+ dr = self.dag.create_dagrun(
+ run_id="manual__",
+ start_date=datetime.datetime.now(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING
+ )
+
+ self.upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ tis = dr.get_task_instances()
+ for ti in tis:
+ if ti.task_id == 'make_choice':
+ self.assertEquals(ti.state, State.SUCCESS)
+ elif ti.task_id == 'upstream':
+ self.assertEquals(ti.state, State.SUCCESS)
+ elif ti.task_id == 'branch_1':
+ self.assertEquals(ti.state, State.SKIPPED)
+ else:
+ raise
+
+ self.value = True
+ self.dag.clear()
+
+ self.upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ tis = dr.get_task_instances()
+ for ti in tis:
+ if ti.task_id == 'make_choice':
+ self.assertEquals(ti.state, State.SUCCESS)
+ elif ti.task_id == 'upstream':
+ self.assertEquals(ti.state, State.SUCCESS)
+ elif ti.task_id == 'branch_1':
+ self.assertEquals(ti.state, State.NONE)
+ else:
+ raise