You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/11/20 23:44:19 UTC

[airflow] 01/07: Create UndefinedJinjaVariablesRule (Resolves #11144) (#11241)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 415cda438e1d21fb9ca3a64864218f711913b014
Author: Ashmeet Lamba <as...@gmail.com>
AuthorDate: Thu Nov 19 16:33:06 2020 +0530

    Create UndefinedJinjaVariablesRule (Resolves #11144) (#11241)
    
    Add a rule that checks for undefined Jinja variables when upgrading to Airflow 2.0
    
    (cherry picked from commit 18100a0ec96692bb4d7c9e80f206b66a30c65e0d)
---
 airflow/models/dag.py                              |   4 +-
 airflow/upgrade/rules/undefined_jinja_varaibles.py | 153 ++++++++++++++++
 .../rules/test_undefined_jinja_varaibles.py        | 192 +++++++++++++++++++++
 3 files changed, 347 insertions(+), 2 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 348e19d..a1908e3 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -224,7 +224,7 @@ class DAG(BaseDag, LoggingMixin):
         end_date=None,  # type: Optional[datetime]
         full_filepath=None,  # type: Optional[str]
         template_searchpath=None,  # type: Optional[Union[str, Iterable[str]]]
-        template_undefined=jinja2.Undefined,  # type: Type[jinja2.Undefined]
+        template_undefined=None,  # type: Optional[Type[jinja2.Undefined]]
         user_defined_macros=None,  # type: Optional[Dict]
         user_defined_filters=None,  # type: Optional[Dict]
         default_args=None,  # type: Optional[Dict]
@@ -807,7 +807,7 @@ class DAG(BaseDag, LoggingMixin):
         # Default values (for backward compatibility)
         jinja_env_options = {
             'loader': jinja2.FileSystemLoader(searchpath),
-            'undefined': self.template_undefined,
+            'undefined': self.template_undefined or jinja2.Undefined,
             'extensions': ["jinja2.ext.do"],
             'cache_size': 0
         }
diff --git a/airflow/upgrade/rules/undefined_jinja_varaibles.py b/airflow/upgrade/rules/undefined_jinja_varaibles.py
new file mode 100644
index 0000000..b97cfbc
--- /dev/null
+++ b/airflow/upgrade/rules/undefined_jinja_varaibles.py
@@ -0,0 +1,153 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import absolute_import
+
+import re
+
+import jinja2
+import six
+
+from airflow import conf
+from airflow.models import DagBag, TaskInstance
+from airflow.upgrade.rules.base_rule import BaseRule
+from airflow.utils import timezone
+
+
+class UndefinedJinjaVariablesRule(BaseRule):
+
+    title = "Jinja Template Variables cannot be undefined"
+
+    description = """\
+The default behavior for DAG's Jinja templates has changed. Now, more restrictive validation
+of non-existent variables is applied - `jinja2.StrictUndefined`.
+
+The user should do either of the following to fix this -
+1. Fix the Jinja Templates by defining every variable or providing default values
+2. Explicitly declare `template_undefined=jinja2.Undefined` while defining the DAG
+"""
+
+    def _check_rendered_content(self, rendered_content, seen_oids=None):
+        """Return the ``{{ ... }}`` expressions still present in rendered content.
+
+        Replicates the type dispatch of BaseOperator.render_template().
+        """
+        if seen_oids is None:
+            seen_oids = set()  # one set for the whole traversal, so cycles are detected
+        if isinstance(rendered_content, six.string_types):
+            return set(re.findall(r"{{(.*?)}}", rendered_content))
+
+        elif isinstance(rendered_content, (int, float, bool)):
+            return set()
+
+        elif isinstance(rendered_content, (tuple, list, set)):
+            debug_error_messages = set()
+            for element in rendered_content:
+                debug_error_messages.update(self._check_rendered_content(element, seen_oids))
+            return debug_error_messages
+
+        elif isinstance(rendered_content, dict):
+            debug_error_messages = set()
+            for value in rendered_content.values():
+                debug_error_messages.update(self._check_rendered_content(value, seen_oids))
+            return debug_error_messages
+
+        return self._nested_check_rendered(rendered_content, seen_oids)
+
+    def _nested_check_rendered(self, rendered_content, seen_oids):
+        """Check the template fields of an arbitrary object, visiting each object once."""
+        debug_error_messages = set()
+        if id(rendered_content) in seen_oids:  # already visited: prevents endless recursion
+            return debug_error_messages
+        seen_oids.add(id(rendered_content))
+        # Objects without `template_fields` (None, datetimes, callables, ...) have nothing to check.
+        for attr_name in getattr(rendered_content, "template_fields", ()):
+            nested_rendered_content = getattr(rendered_content, attr_name)
+            if not nested_rendered_content:
+                continue
+            suffix = " NestedTemplateField={}".format(attr_name)
+            debug_error_messages.update(
+                error + suffix
+                for error in self._check_rendered_content(nested_rendered_content, seen_oids)
+            )
+        return debug_error_messages
+
+    def _render_task_content(self, task, content, context):
+        """Render ``content``, stubbing each undefined object with ``{}`` and retrying.
+
+        Returns ``(rendered_content, errors)``; gives up with ``""`` once an error repeats.
+        """
+        errors_while_rendering, stubbed = [], set()
+        while True:
+            try:
+                return task.render_template(content, context), errors_while_rendering
+            except Exception as e:  # e.g. {{ object.element }} where object is not defined
+                undefined_variable = re.sub("'", "", re.sub(" is undefined", "", str(e)))
+                if undefined_variable in stubbed:  # stubbing did not help; avoid looping forever
+                    return "", errors_while_rendering + ["Could not render template: {}".format(e)]
+                stubbed.add(undefined_variable)
+                context[undefined_variable] = dict()
+                errors_while_rendering.append("Could not find the object '{}'".format(undefined_variable))
+
+    def iterate_over_template_fields(self, task):
+        messages = {}  # template field name -> list of error strings
+        task_instance = TaskInstance(task=task, execution_date=timezone.utcnow())
+        context = task_instance.get_template_context()  # context for "now" as execution_date
+        for attr_name in task.template_fields:
+            content = getattr(task, attr_name)
+            if content:  # falsy fields (None, "", empty containers) are not rendered
+                rendered_content, errors_while_rendering = self._render_task_content(
+                    task, content, context
+                )
+                debug_error_messages = list(
+                    self._check_rendered_content(rendered_content, set())
+                )
+                messages[attr_name] = errors_while_rendering + debug_error_messages
+            # fields skipped above get no entry in `messages`
+        return messages
+
+    def iterate_over_dag_tasks(self, dag):
+        dag.template_undefined = jinja2.DebugUndefined  # NOTE: mutates `dag`; never restored here
+        tasks = dag.tasks
+        messages = {}  # task_id -> {template field name -> list of error strings}
+        for task in tasks:
+            error_messages = self.iterate_over_template_fields(task)
+            messages[task.task_id] = error_messages
+        return messages
+
+    def check(self, dagbag=None):
+        if not dagbag:  # no DagBag supplied: load the configured [core] dags_folder
+            dag_folder = conf.get("core", "dags_folder")
+            dagbag = DagBag(dag_folder)
+        dags = dagbag.dags
+        messages = []  # one formatted string per suspected undefined variable
+        for dag_id, dag in dags.items():
+            if dag.template_undefined:  # truthy template_undefined: this DAG is not checked
+                continue
+            dag_messages = self.iterate_over_dag_tasks(dag)
+            # Flatten {task_id: {attr_name: [errors]}} into one formatted line per error.
+            for task_id, task_messages in dag_messages.items():
+                for attr_name, error_messages in task_messages.items():
+                    for error_message in error_messages:
+                        message = (
+                            "Possible UndefinedJinjaVariable -> DAG: {}, Task: {}, "
+                            "Attribute: {}, Error: {}".format(
+                                dag_id, task_id, attr_name, error_message.strip()
+                            )
+                        )
+                        messages.append(message)
+        return messages
diff --git a/tests/upgrade/rules/test_undefined_jinja_varaibles.py b/tests/upgrade/rules/test_undefined_jinja_varaibles.py
new file mode 100644
index 0000000..83f99a3
--- /dev/null
+++ b/tests/upgrade/rules/test_undefined_jinja_varaibles.py
@@ -0,0 +1,192 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from tempfile import mkdtemp
+from unittest import TestCase
+
+import jinja2
+
+from airflow import DAG
+from airflow.models import DagBag
+from airflow.operators.bash_operator import BashOperator
+from airflow.upgrade.rules.undefined_jinja_varaibles import UndefinedJinjaVariablesRule
+from tests.models import DEFAULT_DATE
+
+
+class ClassWithCustomAttributes:
+    """Class for testing purpose: allows to create objects with custom attributes in one single statement."""
+
+    def __init__(self, **kwargs):
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+    def __str__(self):
+        return "{}({})".format(ClassWithCustomAttributes.__name__, str(self.__dict__))
+
+    def __repr__(self):
+        return self.__str__()
+
+    def __eq__(self, other):
+        return self.__dict__ == other.__dict__
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+
+class TestUndefinedJinjaVariablesRule(TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.empty_dir = mkdtemp()
+
+    def setUpValidDag(self):
+        self.valid_dag = DAG(
+            dag_id="test-defined-jinja-variables", start_date=DEFAULT_DATE
+        )
+
+        BashOperator(
+            task_id="templated_string",
+            depends_on_past=False,
+            bash_command="echo",
+            env={
+                "integer": "{{ params.integer }}",
+                "float": "{{ params.float }}",
+                "string": "{{ params.string }}",
+                "boolean": "{{ params.boolean }}",
+            },
+            params={
+                "integer": 1,
+                "float": 1.0,
+                "string": "test_string",
+                "boolean": True,
+            },
+            dag=self.valid_dag,
+        )
+
+    def setUpDagToSkip(self):
+        self.skip_dag = DAG(
+            dag_id="test-defined-jinja-variables",
+            start_date=DEFAULT_DATE,
+            template_undefined=jinja2.Undefined,
+        )
+
+        BashOperator(
+            task_id="templated_string",
+            depends_on_past=False,
+            bash_command="{{ undefined }}",
+            dag=self.skip_dag,
+        )
+
+    def setUpInvalidDag(self):
+        self.invalid_dag = DAG(
+            dag_id="test-undefined-jinja-variables", start_date=DEFAULT_DATE
+        )
+
+        invalid_template_command = """
+            {% for i in range(5) %}
+                echo "{{ params.defined_variable }}"
+                echo "{{ execution_date.today }}"
+                echo "{{ execution_date.invalid_element }}"
+                echo "{{ params.undefined_variable }}"
+                echo "{{ foo }}"
+            {% endfor %}
+            """
+
+        nested_validation = ClassWithCustomAttributes(
+            nested1=ClassWithCustomAttributes(
+                att1="{{ nested.undefined }}", template_fields=["att1"]
+            ),
+            nested2=ClassWithCustomAttributes(
+                att2="{{ bar }}", template_fields=["att2"]
+            ),
+            template_fields=["nested1", "nested2"],
+        )
+
+        BashOperator(
+            task_id="templated_string",
+            depends_on_past=False,
+            bash_command=invalid_template_command,
+            env={
+                "undefined_object": "{{ undefined_object.element }}",
+                "nested_object": nested_validation,
+            },
+            params={"defined_variable": "defined_value"},
+            dag=self.invalid_dag,
+        )
+
+    def setUp(self):
+        self.setUpValidDag()
+        self.setUpDagToSkip()
+        self.setUpInvalidDag()
+
+    def test_description_and_title_is_defined(self):
+        rule = UndefinedJinjaVariablesRule()
+        assert isinstance(rule.description, str)
+        assert isinstance(rule.title, str)
+
+    def test_valid_check(self):
+        dagbag = DagBag(dag_folder=self.empty_dir, include_examples=False)
+        dagbag.dags[self.valid_dag.dag_id] = self.valid_dag
+        rule = UndefinedJinjaVariablesRule()
+
+        messages = rule.check(dagbag)
+
+        assert len(messages) == 0
+
+    def test_skipping_dag_check(self):
+        dagbag = DagBag(dag_folder=self.empty_dir, include_examples=False)
+        dagbag.dags[self.skip_dag.dag_id] = self.skip_dag
+        rule = UndefinedJinjaVariablesRule()
+
+        messages = rule.check(dagbag)
+
+        assert len(messages) == 0
+
+    def test_invalid_check(self):
+        dagbag = DagBag(dag_folder=self.empty_dir, include_examples=False)
+        dagbag.dags[self.invalid_dag.dag_id] = self.invalid_dag
+        rule = UndefinedJinjaVariablesRule()
+
+        messages = rule.check(dagbag)
+
+        expected_messages = [
+            "Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
+            "Task: templated_string, Attribute: bash_command, Error: no such element: "
+            "dict object['undefined_variable']",
+            "Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
+            "Task: templated_string, Attribute: bash_command, Error: no such element: "
+            "pendulum.pendulum.Pendulum object['invalid_element']",
+            "Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
+            "Task: templated_string, Attribute: bash_command, Error: foo",
+            "Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
+            "Task: templated_string, Attribute: env, Error: Could not find the "
+            "object 'undefined_object",
+            "Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
+            "Task: templated_string, Attribute: env, Error: Could not find the object 'nested'",
+            "Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
+            "Task: templated_string, Attribute: env, Error: bar  NestedTemplateField=att2 "
+            "NestedTemplateField=nested2",
+            "Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
+            "Task: templated_string, Attribute: env, Error: no such element: "
+            "dict object['undefined']  NestedTemplateField=att1 NestedTemplateField=nested1",
+            "Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
+            "Task: templated_string, Attribute: env, Error: no such element: dict object['element']",
+        ]
+
+        assert len(messages) == len(expected_messages)
+        assert [m for m in messages if m in expected_messages], len(messages) == len(
+            expected_messages
+        )