You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2021/06/25 20:30:30 UTC

[airflow] branch main updated: Remove duplicated/overlapping tests around render_k8s_pod_yaml (#16642)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c79bbb2  Remove duplicated/overlapping tests around render_k8s_pod_yaml (#16642)
c79bbb2 is described below

commit c79bbb26f383a4c68255aa24fd9426635b2b5ac3
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Fri Jun 25 21:30:15 2021 +0100

    Remove duplicated/overlapping tests around render_k8s_pod_yaml (#16642)
    
    When making another change here, I noticed that we were basically
    testing the same thing twice in test_taskinstance and
    test_renderedtifields, which does no one any good.
    
    I have updated the tests to use mocking to avoid duplication, and
    exercised a few more of the branches in the functions under test.
---
 airflow/models/taskinstance.py        |  5 ++-
 tests/models/test_renderedtifields.py | 69 ++++++++---------------------------
 tests/models/test_taskinstance.py     | 40 ++++++++++++++------
 3 files changed, 46 insertions(+), 68 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index e91321b..3506c07 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -1693,11 +1693,12 @@ class TaskInstance(Base, LoggingMixin):  # pylint: disable=R0902,R0904
                     "rendering of template_fields."
                 ) from e
 
-    def get_rendered_k8s_spec(self):
+    @provide_session
+    def get_rendered_k8s_spec(self, session=None):
         """Fetch rendered template fields from DB"""
         from airflow.models.renderedtifields import RenderedTaskInstanceFields
 
-        rendered_k8s_spec = RenderedTaskInstanceFields.get_k8s_pod_yaml(self)
+        rendered_k8s_spec = RenderedTaskInstanceFields.get_k8s_pod_yaml(self, session=session)
         if not rendered_k8s_spec:
             try:
                 rendered_k8s_spec = self.render_k8s_pod_yaml()
diff --git a/tests/models/test_renderedtifields.py b/tests/models/test_renderedtifields.py
index bd88d5d..f76078c 100644
--- a/tests/models/test_renderedtifields.py
+++ b/tests/models/test_renderedtifields.py
@@ -33,7 +33,6 @@ from airflow.models.taskinstance import TaskInstance as TI
 from airflow.operators.bash import BashOperator
 from airflow.utils.session import create_session
 from airflow.utils.timezone import datetime
-from airflow.version import version
 from tests.test_utils.asserts import assert_queries_count
 from tests.test_utils.db import clear_rendered_ti_fields
 
@@ -237,8 +236,7 @@ class TestRenderedTaskInstanceFields(unittest.TestCase):
 
     @mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"})
     @mock.patch('airflow.utils.log.secrets_masker.redact', autospec=True, side_effect=lambda d, _=None: d)
-    @mock.patch("airflow.settings.pod_mutation_hook")
-    def test_get_k8s_pod_yaml(self, mock_pod_mutation_hook, redact):
+    def test_get_k8s_pod_yaml(self, redact):
         """
         Test that k8s_pod_yaml is rendered correctly, stored in the Database,
         and are correctly fetched using RTIF.get_k8s_pod_yaml
@@ -248,71 +246,34 @@ class TestRenderedTaskInstanceFields(unittest.TestCase):
             task = BashOperator(task_id="test", bash_command="echo hi")
 
         ti = TI(task=task, execution_date=EXECUTION_DATE)
-        rtif = RTIF(ti=ti)
 
-        # Test that pod_mutation_hook is called
-        mock_pod_mutation_hook.assert_called_once_with(mock.ANY)
+        render_k8s_pod_yaml = mock.patch.object(
+            ti, 'render_k8s_pod_yaml', return_value={"I'm a": "pod"}
+        ).start()
+
+        rtif = RTIF(ti=ti)
 
         assert ti.dag_id == rtif.dag_id
         assert ti.task_id == rtif.task_id
         assert ti.execution_date == rtif.execution_date
 
-        expected_pod_yaml = {
-            'metadata': {
-                'annotations': {
-                    'dag_id': 'test_get_k8s_pod_yaml',
-                    'execution_date': '2019-01-01T00:00:00+00:00',
-                    'task_id': 'test',
-                    'try_number': '1',
-                },
-                'labels': {
-                    'airflow-worker': 'worker-config',
-                    'airflow_version': version,
-                    'dag_id': 'test_get_k8s_pod_yaml',
-                    'execution_date': '2019-01-01T00_00_00_plus_00_00',
-                    'kubernetes_executor': 'True',
-                    'task_id': 'test',
-                    'try_number': '1',
-                },
-                'name': mock.ANY,
-                'namespace': 'default',
-            },
-            'spec': {
-                'containers': [
-                    {
-                        'args': [
-                            'airflow',
-                            'tasks',
-                            'run',
-                            'test_get_k8s_pod_yaml',
-                            'test',
-                            '2019-01-01T00:00:00+00:00',
-                        ],
-                        'image': ':',
-                        'name': 'base',
-                        'env': [{'name': 'AIRFLOW_IS_K8S_EXECUTOR_POD', 'value': 'True'}],
-                    }
-                ]
-            },
-        }
+        expected_pod_yaml = {"I'm a": "pod"}
 
-        assert expected_pod_yaml == rtif.k8s_pod_yaml
+        assert rtif.k8s_pod_yaml == render_k8s_pod_yaml.return_value
         # K8s pod spec dict was passed to redact
         redact.assert_any_call(rtif.k8s_pod_yaml)
 
         with create_session() as session:
             session.add(rtif)
+            session.flush()
 
-        assert expected_pod_yaml == RTIF.get_k8s_pod_yaml(ti=ti)
+            assert expected_pod_yaml == RTIF.get_k8s_pod_yaml(ti=ti, session=session)
+            session.rollback()
 
-        # Test the else part of get_k8s_pod_yaml
-        # i.e. for the TIs that are not stored in RTIF table
-        # Fetching them will return None
-        with dag:
-            task_2 = BashOperator(task_id="test2", bash_command="echo hello")
-
-        ti2 = TI(task_2, EXECUTION_DATE)
-        assert RTIF.get_k8s_pod_yaml(ti=ti2) is None
+            # Test the else part of get_k8s_pod_yaml
+            # i.e. for the TIs that are not stored in RTIF table
+            # Fetching them will return None
+            assert RTIF.get_k8s_pod_yaml(ti=ti, session=session) is None
 
     @mock.patch.dict(os.environ, {"AIRFLOW_VAR_API_KEY": "secret"})
     @mock.patch('airflow.utils.log.secrets_masker.redact', autospec=True)
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 6a6c6eb..0d14e95 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -1851,7 +1851,8 @@ class TestTaskInstance(unittest.TestCase):
             session.query(RenderedTaskInstanceFields).delete()
 
     @mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"})
-    def test_get_rendered_k8s_spec(self):
+    @mock.patch("airflow.settings.pod_mutation_hook")
+    def test_render_k8s_pod_yaml(self, pod_mutation_hook):
         with DAG('test_get_rendered_k8s_spec', start_date=DEFAULT_DATE):
             task = BashOperator(task_id='op1', bash_command="{{ task.task_id }}")
 
@@ -1896,23 +1897,38 @@ class TestTaskInstance(unittest.TestCase):
             },
         }
 
-        with create_session() as session:
-            rtif = RenderedTaskInstanceFields(ti)
-            session.add(rtif)
-            assert rtif.k8s_pod_yaml == expected_pod_spec
+        assert ti.render_k8s_pod_yaml() == expected_pod_spec
+        pod_mutation_hook.assert_called_once_with(mock.ANY)
 
+    @mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"})
+    @mock.patch.object(RenderedTaskInstanceFields, 'get_k8s_pod_yaml')
+    def test_get_rendered_k8s_spec(self, rtif_get_k8s_pod_yaml):
         # Create new TI for the same Task
         with DAG('test_get_rendered_k8s_spec', start_date=DEFAULT_DATE):
-            new_task = BashOperator(task_id='op1', bash_command="{{ task.task_id }}")
+            task = BashOperator(task_id='op1', bash_command="{{ task.task_id }}")
 
-        new_ti = TI(task=new_task, execution_date=DEFAULT_DATE)
-        pod_spec = new_ti.get_rendered_k8s_spec()
+        ti = TI(task=task, execution_date=DEFAULT_DATE)
 
-        assert expected_pod_spec == pod_spec
+        patcher = mock.patch.object(ti, 'render_k8s_pod_yaml', autospec=True)
 
-        # CleanUp
-        with create_session() as session:
-            session.query(RenderedTaskInstanceFields).delete()
+        fake_spec = {"ermagawds": "pods"}
+
+        session = mock.Mock()
+
+        with patcher as render_k8s_pod_yaml:
+            rtif_get_k8s_pod_yaml.return_value = fake_spec
+            assert ti.get_rendered_k8s_spec(session) == fake_spec
+
+            rtif_get_k8s_pod_yaml.assert_called_once_with(ti, session=session)
+            render_k8s_pod_yaml.assert_not_called()
+
+            # Now test that when we _don't_ find it in the DB, it calls render_k8s_pod_yaml
+            rtif_get_k8s_pod_yaml.return_value = None
+            render_k8s_pod_yaml.return_value = fake_spec
+
+            assert ti.get_rendered_k8s_spec(session) == fake_spec
+
+            render_k8s_pod_yaml.assert_called_once()
 
     def test_set_state_up_for_retry(self):
         dag = DAG('dag', start_date=DEFAULT_DATE)