You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2020/06/01 10:53:40 UTC

[GitHub] [airflow] ashb commented on a change in pull request #8805: Resolve upstream tasks when template field is XComArg

ashb commented on a change in pull request #8805:
URL: https://github.com/apache/airflow/pull/8805#discussion_r433166090



##########
File path: tests/test_utils/__init__.py
##########
@@ -20,3 +20,11 @@
 AIRFLOW_MAIN_FOLDER = os.path.realpath(
     os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir)
 )
+
+EXAMPLE_DAGS_FOLDER = os.path.join(
+    AIRFLOW_MAIN_FOLDER, "airflow", "example_dags"
+)
+
+TEST_DAGS_FOLDER = os.path.realpath(

Review comment:
       This is already available as `airflow.settings.TEST_DAGS_FOLDER`, so we don't need another constant.

##########
File path: tests/models/test_baseoperator.py
##########
@@ -347,3 +350,87 @@ def test_lineage_composition(self):
         task4 = DummyOperator(task_id="op4", dag=dag)
         task4 > [inlet, outlet, extra]
         self.assertEqual(task4.get_outlet_defs(), [inlet, outlet, extra])
+
+
+class CustomOp(DummyOperator):
+    template_fields = ("field", "field2")
+
+    @apply_defaults
+    def __init__(self, field=None, field2=None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.field = field
+        self.field2 = field2
+
+    def execute(self, context):
+        self.field = None
+
+
+class TestXComArgsRelationsAreResolved:
+    def test_setattr_performs_no_custom_action_at_execute_time(self):
+        op = CustomOp(task_id="test_task")
+        op.lock_for_execution()
+
+        with mock.patch(
+            "airflow.models.baseoperator.BaseOperator.set_xcomargs_dependencies"
+        ) as method_mock:
+            op.execute({})
+        assert method_mock.call_count == 0
+
+    def test_upstream_is_set_when_template_field_is_xcomarg(self):
+        with DAG("xcomargs_test", default_args={"start_date": datetime.today()}):
+            op1 = DummyOperator(task_id="op1")
+            op2 = CustomOp(task_id="op2", field=op1.output)
+
+        assert op1 in op2.upstream_list
+        assert op2 in op1.downstream_list
+
+    def test_set_xcomargs_dependencies_works_recursively(self):
+        with DAG("xcomargs_test", default_args={"start_date": datetime.today()}):
+            op1 = DummyOperator(task_id="op1")
+            op2 = DummyOperator(task_id="op2")
+            op3 = CustomOp(task_id="op3", field=[op1.output, op2.output])
+            op4 = CustomOp(task_id="op4", field={"op1": op1.output, "op2": op2.output})
+
+        assert op1 in op3.upstream_list
+        assert op2 in op3.upstream_list
+        assert op1 in op4.upstream_list
+        assert op2 in op4.upstream_list
+
+    def test_set_xcomargs_dependencies_works_when_set_after_init(self):
+        with DAG(dag_id='xcomargs_test', default_args={"start_date": datetime.today()}):
+            op1 = DummyOperator(task_id="op1")
+            op2 = CustomOp(task_id="op2")
+            op2.field = op1.output  # value is set after init
+
+        assert op1 in op2.upstream_list
+
+    def test_set_xcomargs_dependencies_error_when_outside_dag(self):
+        with pytest.raises(AirflowException):
+            op1 = DummyOperator(task_id="op1")
+            CustomOp(task_id="op2", field=op1.output)
+
+    def test_set_xcomargs_dependencies_when_creating_dagbag(self):
+        dag_bag = DagBag(dag_folder=TEST_DAGS_FOLDER, include_examples=False)
+        dag_id = "xcomargs_test_1"
+        dag: DAG = dag_bag.get_dag(dag_id)

Review comment:
       That is, just create this DAG here in the test, as it isn't used anywhere else.
   
   The same applies to xcomargs_test_2.

##########
File path: airflow/models/baseoperator.py
##########
@@ -292,6 +311,12 @@ class derived from this one results in the creation of a task object,
     # Defines if the operator supports lineage without manual definitions
     supports_lineage = False
 
+    # If True then the class constructor was called
+    _instantiated = False

Review comment:
       ```suggestion
       __instantiated = False
   ```
   
   I think we can do this, which makes it much less likely to clash with any fields in a subclass.
   
   Internally, Python name-mangles `__x` to `_BaseOperator__x` when it is used inside the class body: https://docs.python.org/3/tutorial/classes.html#tut-private

##########
File path: tests/dags/test_xcomargs_dag.py
##########
@@ -0,0 +1,59 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Example DAG demonstrating the usage of the XComArgs."""
+from datetime import datetime
+
+from airflow import DAG
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.utils.dates import days_ago
+from airflow.utils.decorators import apply_defaults
+
+
+class CustomOp(DummyOperator):
+    template_fields = ("field",)
+
+    @apply_defaults
+    def __init__(self, field=None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.field = field
+
+
+args = {
+    'owner': 'airflow',
+    'start_date': days_ago(2),
+}
+
+
+def dummy(*args, **kwargs):
+    """Dummy function"""
+    return "pass"
+
+
+with DAG(dag_id='xcomargs_test_1', default_args={"start_date": datetime.today()}) as dag1:

Review comment:
       Every DAG we add to tests/dags slows down many of our tests. Unless we _need_ the DAG to exist in a normal DAGs folder (i.e. when we run it via a full scheduler/backfill), please just create DAGs in the unit tests.

##########
File path: airflow/models/baseoperator.py
##########
@@ -633,6 +670,51 @@ def deps(self) -> Set[BaseTIDep]:
             NotPreviouslySkippedDep(),
         }
 
+    def lock_for_execution(self) -> None:
+        """Sets _lock_for_execution to True"""

Review comment:
       This description is somewhat self-referential: I could have guessed what it does from the name of the method.
   
   Either say what this setting actually does, or remove the doc comment, please — as written it doesn't add any value.

##########
File path: airflow/models/taskinstance.py
##########
@@ -968,6 +968,7 @@ def _run_raw_task(
                 context = self.get_template_context()
 
                 task_copy = copy.copy(task)
+                task_copy.lock_for_execution()

Review comment:
       How about we combine both of these into a single function, `task.prepare_for_execution()`, that returns a new copied instance with the "lock" set?
   
   We also do `task_copy = copy.copy(task)` in the `dry_run` function later on — we should use the new function there too.

##########
File path: tests/models/test_baseoperator.py
##########
@@ -347,3 +350,87 @@ def test_lineage_composition(self):
         task4 = DummyOperator(task_id="op4", dag=dag)
         task4 > [inlet, outlet, extra]
         self.assertEqual(task4.get_outlet_defs(), [inlet, outlet, extra])
+
+
+class CustomOp(DummyOperator):
+    template_fields = ("field", "field2")
+
+    @apply_defaults
+    def __init__(self, field=None, field2=None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.field = field
+        self.field2 = field2
+
+    def execute(self, context):
+        self.field = None
+
+
+class TestXComArgsRelationsAreResolved:
+    def test_setattr_performs_no_custom_action_at_execute_time(self):
+        op = CustomOp(task_id="test_task")
+        op.lock_for_execution()
+
+        with mock.patch(
+            "airflow.models.baseoperator.BaseOperator.set_xcomargs_dependencies"
+        ) as method_mock:
+            op.execute({})
+        assert method_mock.call_count == 0
+
+    def test_upstream_is_set_when_template_field_is_xcomarg(self):
+        with DAG("xcomargs_test", default_args={"start_date": datetime.today()}):
+            op1 = DummyOperator(task_id="op1")
+            op2 = CustomOp(task_id="op2", field=op1.output)
+
+        assert op1 in op2.upstream_list
+        assert op2 in op1.downstream_list
+
+    def test_set_xcomargs_dependencies_works_recursively(self):
+        with DAG("xcomargs_test", default_args={"start_date": datetime.today()}):
+            op1 = DummyOperator(task_id="op1")
+            op2 = DummyOperator(task_id="op2")
+            op3 = CustomOp(task_id="op3", field=[op1.output, op2.output])
+            op4 = CustomOp(task_id="op4", field={"op1": op1.output, "op2": op2.output})
+
+        assert op1 in op3.upstream_list
+        assert op2 in op3.upstream_list
+        assert op1 in op4.upstream_list
+        assert op2 in op4.upstream_list
+
+    def test_set_xcomargs_dependencies_works_when_set_after_init(self):
+        with DAG(dag_id='xcomargs_test', default_args={"start_date": datetime.today()}):
+            op1 = DummyOperator(task_id="op1")
+            op2 = CustomOp(task_id="op2")
+            op2.field = op1.output  # value is set after init
+
+        assert op1 in op2.upstream_list
+
+    def test_set_xcomargs_dependencies_error_when_outside_dag(self):
+        with pytest.raises(AirflowException):
+            op1 = DummyOperator(task_id="op1")
+            CustomOp(task_id="op2", field=op1.output)
+
+    def test_set_xcomargs_dependencies_when_creating_dagbag(self):
+        dag_bag = DagBag(dag_folder=TEST_DAGS_FOLDER, include_examples=False)
+        dag_id = "xcomargs_test_1"
+        dag: DAG = dag_bag.get_dag(dag_id)

Review comment:
       (Not having to load the full test DagBag makes this test much quicker as well.)




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org