You are viewing a plain-text version of this content. The canonical link for it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by je...@apache.org on 2021/12/14 19:50:59 UTC

[airflow] 01/01: Lazy Jinja2 context (#20217)

This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 05a91b3e89842e9f8af32f2ecf0cae0a05ffeb7c
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Tue Dec 14 15:27:59 2021 +0800

    Lazy Jinja2 context (#20217)
    
    Co-authored-by: Jed Cunningham <66...@users.noreply.github.com>
    (cherry picked from commit 181d60cdd182a9523890bf4822a76ea80666ea92)
---
 airflow/models/baseoperator.py           | 25 ++++++++++------
 airflow/models/param.py                  |  3 +-
 airflow/models/xcom_arg.py               |  5 ++--
 airflow/ti_deps/deps/trigger_rule_dep.py | 15 ++++++++--
 airflow/utils/context.py                 | 49 ++++++++++++++++++++++--------
 airflow/utils/helpers.py                 | 51 ++++++++++++++++++++++++++++----
 airflow/utils/log/file_task_handler.py   | 21 ++++++-------
 tests/conftest.py                        |  1 +
 tests/models/test_taskinstance.py        | 21 +++++++++++++
 9 files changed, 148 insertions(+), 43 deletions(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 04aebee..9e4cedd 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -69,7 +69,7 @@ from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
 from airflow.triggers.base import BaseTrigger
 from airflow.utils import timezone
 from airflow.utils.edgemodifier import EdgeModifier
-from airflow.utils.helpers import validate_key
+from airflow.utils.helpers import render_template_as_native, render_template_to_string, validate_key
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.operator_resources import Resources
 from airflow.utils.session import provide_session
@@ -1042,7 +1042,11 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
         self.__dict__ = state
         self._log = logging.getLogger("airflow.task.operators")
 
-    def render_template_fields(self, context: Dict, jinja_env: Optional[jinja2.Environment] = None) -> None:
+    def render_template_fields(
+        self,
+        context: Context,
+        jinja_env: Optional[jinja2.Environment] = None,
+    ) -> None:
         """
         Template all attributes listed in template_fields. Note this operation is irreversible.
 
@@ -1060,7 +1064,7 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
         self,
         parent: Any,
         template_fields: Iterable[str],
-        context: Dict,
+        context: Context,
         jinja_env: jinja2.Environment,
         seen_oids: Set,
     ) -> None:
@@ -1073,7 +1077,7 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
     def render_template(
         self,
         content: Any,
-        context: Dict,
+        context: Context,
         jinja_env: Optional[jinja2.Environment] = None,
         seen_oids: Optional[Set] = None,
     ) -> Any:
@@ -1100,11 +1104,14 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
         from airflow.models.xcom_arg import XComArg
 
         if isinstance(content, str):
-            if any(content.endswith(ext) for ext in self.template_ext):
-                # Content contains a filepath
-                return jinja_env.get_template(content).render(**context)
+            if any(content.endswith(ext) for ext in self.template_ext):  # Content contains a filepath.
+                template = jinja_env.get_template(content)
             else:
-                return jinja_env.from_string(content).render(**context)
+                template = jinja_env.from_string(content)
+            if self.has_dag() and self.dag.render_template_as_native_obj:
+                return render_template_as_native(template, context)
+            return render_template_to_string(template, context)
+
         elif isinstance(content, (XComArg, DagParam)):
             return content.resolve(context)
 
@@ -1133,7 +1140,7 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
             return content
 
     def _render_nested_template_fields(
-        self, content: Any, context: Dict, jinja_env: jinja2.Environment, seen_oids: Set
+        self, content: Any, context: Context, jinja_env: jinja2.Environment, seen_oids: Set
     ) -> None:
         if id(content) not in seen_oids:
             seen_oids.add(id(content))
diff --git a/airflow/models/param.py b/airflow/models/param.py
index 53ac79a..6ae6593 100644
--- a/airflow/models/param.py
+++ b/airflow/models/param.py
@@ -21,6 +21,7 @@ from jsonschema import FormatChecker
 from jsonschema.exceptions import ValidationError
 
 from airflow.exceptions import AirflowException
+from airflow.utils.context import Context
 
 
 class NoValueSentinel:
@@ -215,7 +216,7 @@ class DagParam:
         self._name = name
         self._default = default
 
-    def resolve(self, context: Dict) -> Any:
+    def resolve(self, context: Context) -> Any:
         """Pull DagParam value from DagRun context. This method is run during ``op.execute()``."""
         default = self._default
         if not self._default:
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index dd08ab3..6503106 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -15,12 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, List, Optional, Sequence, Union
 
 from airflow.exceptions import AirflowException
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.taskmixin import TaskMixin
 from airflow.models.xcom import XCOM_RETURN_KEY
+from airflow.utils.context import Context
 from airflow.utils.edgemodifier import EdgeModifier
 
 
@@ -128,7 +129,7 @@ class XComArg(TaskMixin):
         """Proxy to underlying operator set_downstream method. Required by TaskMixin."""
         self.operator.set_downstream(task_or_task_list, edge_modifier)
 
-    def resolve(self, context: Dict) -> Any:
+    def resolve(self, context: Context) -> Any:
         """
         Pull XCom value for the existing arg. This method is run during ``op.execute()``
         in respectable context.
diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py
index 5d72410..2d02557 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -18,8 +18,10 @@
 
 from collections import Counter
 
+from sqlalchemy.orm import Session
+
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.state import State
 from airflow.utils.trigger_rule import TriggerRule as TR
 
@@ -82,7 +84,16 @@ class TriggerRuleDep(BaseTIDep):
 
     @provide_session
     def _evaluate_trigger_rule(
-        self, ti, successes, skipped, failed, upstream_failed, done, flag_upstream_failed, session
+        self,
+        ti,
+        successes,
+        skipped,
+        failed,
+        upstream_failed,
+        done,
+        flag_upstream_failed,
+        *,
+        session: Session = NEW_SESSION,
     ):
         """
         Yields a dependency status that indicate whether the given task instance's trigger
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index fca55c1..61f9319 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -19,8 +19,20 @@
 """Jinja2 template rendering context helper."""
 
 import contextlib
+import copy
 import warnings
-from typing import Any, Container, Dict, Iterable, Iterator, List, MutableMapping, Tuple
+from typing import (
+    AbstractSet,
+    Any,
+    Container,
+    Dict,
+    Iterator,
+    List,
+    MutableMapping,
+    Optional,
+    Tuple,
+    ValuesView,
+)
 
 _NOT_SET: Any = object()
 
@@ -74,16 +86,20 @@ class ConnectionAccessor:
             return default_conn
 
 
+class AirflowContextDeprecationWarning(DeprecationWarning):
+    """Warn for usage of deprecated context variables in a task."""
+
+
 def _create_deprecation_warning(key: str, replacements: List[str]) -> DeprecationWarning:
     message = f"Accessing {key!r} from the template is deprecated and will be removed in a future version."
     if not replacements:
-        return DeprecationWarning(message)
+        return AirflowContextDeprecationWarning(message)
     display_except_last = ", ".join(repr(r) for r in replacements[:-1])
     if display_except_last:
         message += f" Please use {display_except_last} or {replacements[-1]!r} instead."
     else:
         message += f" Please use {replacements[-1]!r} instead."
-    return DeprecationWarning(message)
+    return AirflowContextDeprecationWarning(message)
 
 
 class Context(MutableMapping[str, Any]):
@@ -108,8 +124,10 @@ class Context(MutableMapping[str, Any]):
         "yesterday_ds_nodash": [],
     }
 
-    def __init__(self, context: MutableMapping[str, Any]) -> None:
-        self._context = context
+    def __init__(self, context: Optional[MutableMapping[str, Any]] = None, **kwargs: Any) -> None:
+        self._context = context or {}
+        if kwargs:
+            self._context.update(kwargs)
         self._deprecation_replacements = self._DEPRECATION_REPLACEMENTS.copy()
 
     def __repr__(self) -> str:
@@ -124,9 +142,14 @@ class Context(MutableMapping[str, Any]):
         items = [(key, self[key]) for key in self._context]
         return dict, (items,)
 
+    def __copy__(self) -> "Context":
+        new = type(self)(copy.copy(self._context))
+        new._deprecation_replacements = self._deprecation_replacements.copy()
+        return new
+
     def __getitem__(self, key: str) -> Any:
         with contextlib.suppress(KeyError):
-            warnings.warn(_create_deprecation_warning(key, self._deprecation_replacements[key]), stacklevel=2)
+            warnings.warn(_create_deprecation_warning(key, self._deprecation_replacements[key]))
         with contextlib.suppress(KeyError):
             return self._context[key]
         raise KeyError(key)
@@ -139,7 +162,7 @@ class Context(MutableMapping[str, Any]):
         self._deprecation_replacements.pop(key, None)
         del self._context[key]
 
-    def __contains__(self, key: str) -> bool:
+    def __contains__(self, key: object) -> bool:
         return key in self._context
 
     def __iter__(self) -> Iterator[str]:
@@ -158,14 +181,16 @@ class Context(MutableMapping[str, Any]):
             return NotImplemented
         return self._context != other._context
 
-    def keys(self) -> Iterable[str]:
+    def keys(self) -> AbstractSet[str]:
         return self._context.keys()
 
-    def items(self) -> Iterable[Tuple[str, Any]]:
+    def items(self) -> AbstractSet[Tuple[str, Any]]:
         return self._context.items()
 
-    def values(self) -> Iterable[Any]:
+    def values(self) -> ValuesView[Any]:
         return self._context.values()
 
-    def copy_only(self, keys: Container[str]) -> "Context[str, Any]":
-        return type(self)({k: v for k, v in self._context.items() if k in keys})
+    def copy_only(self, keys: Container[str]) -> "Context":
+        new = type(self)({k: v for k, v in self._context.items() if k in keys})
+        new._deprecation_replacements = self._deprecation_replacements.copy()
+        return new
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index e6ab39a..c5f9f27 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -15,7 +15,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+import copy
 import re
 import warnings
 from datetime import datetime
@@ -24,11 +24,13 @@ from itertools import filterfalse, tee
 from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, TypeVar
 from urllib import parse
 
-from flask import url_for
-from jinja2 import Template
+import flask
+import jinja2
+import jinja2.nativetypes
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
+from airflow.utils.context import Context
 from airflow.utils.module_loading import import_string
 
 if TYPE_CHECKING:
@@ -146,7 +148,7 @@ def as_flattened_list(iterable: Iterable[Iterable[T]]) -> List[T]:
 def parse_template_string(template_string):
     """Parses Jinja template string."""
     if "{{" in template_string:  # jinja mode
-        return None, Template(template_string)
+        return None, jinja2.Template(template_string)
     else:
         return template_string, None
 
@@ -228,5 +230,44 @@ def build_airflow_url_with_query(query: Dict[str, Any]) -> str:
     'http://0.0.0.0:8000/base/graph?dag_id=my-task&root=&execution_date=2020-10-27T10%3A59%3A25.615587
     """
     view = conf.get('webserver', 'dag_default_view').lower()
-    url = url_for(f"Airflow.{view}")
+    url = flask.url_for(f"Airflow.{view}")
     return f"{url}?{parse.urlencode(query)}"
+
+
+# The 'template' argument is typed as Any because the jinja2.Template is too
+# dynamic to be effectively type-checked.
+def render_template(template: Any, context: Context, *, native: bool) -> Any:
+    """Render a Jinja2 template with given Airflow context.
+
+    The default implementation of ``jinja2.Template.render()`` converts the
+    input context into dict eagerly many times, which triggers deprecation
+    messages in our custom context class. This takes the implementation apart
+    and retain the context mapping without resolving instead.
+
+    :param template: A Jinja2 template to render.
+    :param context: The Airflow task context to render the template with.
+    :param native: If set to *True*, render the template into a native type. A
+        DAG can enable this with ``render_template_as_native_obj=True``.
+    :returns: The render result.
+    """
+    context = copy.copy(context)
+    env = template.environment
+    if template.globals:
+        context.update((k, v) for k, v in template.globals.items() if k not in context)
+    try:
+        nodes = template.root_render_func(env.context_class(env, context, template.name, template.blocks))
+    except Exception:
+        env.handle_exception()  # Rewrite traceback to point to the template.
+    if native:
+        return jinja2.nativetypes.native_concat(nodes)
+    return "".join(nodes)
+
+
+def render_template_to_string(template: jinja2.Template, context: Context) -> str:
+    """Shorthand to ``render_template(native=False)`` with better typing support."""
+    return render_template(template, context, native=False)
+
+
+def render_template_as_native(template: jinja2.Template, context: Context) -> Any:
+    """Shorthand to ``render_template(native=True)`` with better typing support."""
+    return render_template(template, context, native=True)
diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index 6d88c20..6e57c67 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -25,7 +25,8 @@ import httpx
 from itsdangerous import TimedJSONWebSignatureSerializer
 
 from airflow.configuration import AirflowConfigException, conf
-from airflow.utils.helpers import parse_template_string
+from airflow.utils.context import Context
+from airflow.utils.helpers import parse_template_string, render_template_to_string
 from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler
 
 if TYPE_CHECKING:
@@ -73,23 +74,19 @@ class FileTaskHandler(logging.Handler):
         if self.handler:
             self.handler.close()
 
-    def _render_filename(self, ti, try_number):
+    def _render_filename(self, ti: "TaskInstance", try_number: int) -> str:
         if self.filename_jinja_template:
-            if hasattr(ti, 'task'):
-                jinja_context = ti.get_template_context()
-                jinja_context['try_number'] = try_number
+            if hasattr(ti, "task"):
+                context = ti.get_template_context()
             else:
-                jinja_context = {
-                    'ti': ti,
-                    'ts': ti.execution_date.isoformat(),
-                    'try_number': try_number,
-                }
-            return self.filename_jinja_template.render(**jinja_context)
+                context = Context(ti=ti, ts=ti.get_dagrun().logical_date.isoformat())
+            context["try_number"] = try_number
+            return render_template_to_string(self.filename_jinja_template, context)
 
         return self.filename_template.format(
             dag_id=ti.dag_id,
             task_id=ti.task_id,
-            execution_date=ti.execution_date.isoformat(),
+            execution_date=ti.get_dagrun().logical_date.isoformat(),
             try_number=try_number,
         )
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 94f915f..f7248d1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -231,6 +231,7 @@ def breeze_test_helper(request):
 
 
 def pytest_configure(config):
+    config.addinivalue_line("filterwarnings", "error::airflow.utils.context.AirflowContextDeprecationWarning")
     config.addinivalue_line("markers", "integration(name): mark test to run with named integration")
     config.addinivalue_line("markers", "backend(name): mark test to run with named backend")
     config.addinivalue_line("markers", "system(name): mark test to run with named system")
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index f07147f..8458ea9 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -1512,6 +1512,27 @@ class TestTaskInstance:
         assert isinstance(template_context["data_interval_start"], pendulum.DateTime)
         assert isinstance(template_context["data_interval_end"], pendulum.DateTime)
 
+    def test_template_render(self, create_task_instance):
+        ti = create_task_instance(
+            dag_id="test_template_render",
+            task_id="test_template_render_task",
+            schedule_interval="0 12 * * *",
+        )
+        template_context = ti.get_template_context()
+        result = ti.task.render_template("Task: {{ dag.dag_id }} -> {{ task.task_id }}", template_context)
+        assert result == "Task: test_template_render -> test_template_render_task"
+
+    def test_template_render_deprecated(self, create_task_instance):
+        ti = create_task_instance(
+            dag_id="test_template_render",
+            task_id="test_template_render_task",
+            schedule_interval="0 12 * * *",
+        )
+        template_context = ti.get_template_context()
+        with pytest.deprecated_call():
+            result = ti.task.render_template("Execution date: {{ execution_date }}", template_context)
+        assert result.startswith("Execution date: ")
+
     @pytest.mark.parametrize(
         "content, expected_output",
         [