You are viewing a plain-text version of this content. The hyperlink to the canonical version (originally the word "here") was not preserved in this copy.
Posted to commits@airflow.apache.org by je...@apache.org on 2021/12/14 18:13:48 UTC
[airflow] branch v2-2-test updated: Lazy Jinja2 context (#20217)
This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v2-2-test by this push:
new e66dd0b Lazy Jinja2 context (#20217)
e66dd0b is described below
commit e66dd0bfeaa4cc98281244d2bcaf38d8ec929614
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Tue Dec 14 15:27:59 2021 +0800
Lazy Jinja2 context (#20217)
Co-authored-by: Jed Cunningham <66...@users.noreply.github.com>
(cherry picked from commit 181d60cdd182a9523890bf4822a76ea80666ea92)
---
airflow/models/baseoperator.py | 25 +++++++++-----
airflow/models/param.py | 3 +-
airflow/models/xcom_arg.py | 5 +--
airflow/ti_deps/deps/trigger_rule_dep.py | 15 +++++++--
airflow/utils/context.py | 57 +++++++++++++++++++++++---------
airflow/utils/helpers.py | 51 +++++++++++++++++++++++++---
airflow/utils/log/file_task_handler.py | 21 +++++-------
tests/conftest.py | 1 +
tests/models/test_taskinstance.py | 21 ++++++++++++
9 files changed, 152 insertions(+), 47 deletions(-)
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 04aebee..9e4cedd 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -69,7 +69,7 @@ from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.triggers.base import BaseTrigger
from airflow.utils import timezone
from airflow.utils.edgemodifier import EdgeModifier
-from airflow.utils.helpers import validate_key
+from airflow.utils.helpers import render_template_as_native, render_template_to_string, validate_key
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.operator_resources import Resources
from airflow.utils.session import provide_session
@@ -1042,7 +1042,11 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
self.__dict__ = state
self._log = logging.getLogger("airflow.task.operators")
- def render_template_fields(self, context: Dict, jinja_env: Optional[jinja2.Environment] = None) -> None:
+ def render_template_fields(
+ self,
+ context: Context,
+ jinja_env: Optional[jinja2.Environment] = None,
+ ) -> None:
"""
Template all attributes listed in template_fields. Note this operation is irreversible.
@@ -1060,7 +1064,7 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
self,
parent: Any,
template_fields: Iterable[str],
- context: Dict,
+ context: Context,
jinja_env: jinja2.Environment,
seen_oids: Set,
) -> None:
@@ -1073,7 +1077,7 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
def render_template(
self,
content: Any,
- context: Dict,
+ context: Context,
jinja_env: Optional[jinja2.Environment] = None,
seen_oids: Optional[Set] = None,
) -> Any:
@@ -1100,11 +1104,14 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
from airflow.models.xcom_arg import XComArg
if isinstance(content, str):
- if any(content.endswith(ext) for ext in self.template_ext):
- # Content contains a filepath
- return jinja_env.get_template(content).render(**context)
+ if any(content.endswith(ext) for ext in self.template_ext): # Content contains a filepath.
+ template = jinja_env.get_template(content)
else:
- return jinja_env.from_string(content).render(**context)
+ template = jinja_env.from_string(content)
+ if self.has_dag() and self.dag.render_template_as_native_obj:
+ return render_template_as_native(template, context)
+ return render_template_to_string(template, context)
+
elif isinstance(content, (XComArg, DagParam)):
return content.resolve(context)
@@ -1133,7 +1140,7 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
return content
def _render_nested_template_fields(
- self, content: Any, context: Dict, jinja_env: jinja2.Environment, seen_oids: Set
+ self, content: Any, context: Context, jinja_env: jinja2.Environment, seen_oids: Set
) -> None:
if id(content) not in seen_oids:
seen_oids.add(id(content))
diff --git a/airflow/models/param.py b/airflow/models/param.py
index 53ac79a..6ae6593 100644
--- a/airflow/models/param.py
+++ b/airflow/models/param.py
@@ -21,6 +21,7 @@ from jsonschema import FormatChecker
from jsonschema.exceptions import ValidationError
from airflow.exceptions import AirflowException
+from airflow.utils.context import Context
class NoValueSentinel:
@@ -215,7 +216,7 @@ class DagParam:
self._name = name
self._default = default
- def resolve(self, context: Dict) -> Any:
+ def resolve(self, context: Context) -> Any:
"""Pull DagParam value from DagRun context. This method is run during ``op.execute()``."""
default = self._default
if not self._default:
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index dd08ab3..6503106 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -15,12 +15,13 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, List, Optional, Sequence, Union
from airflow.exceptions import AirflowException
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskmixin import TaskMixin
from airflow.models.xcom import XCOM_RETURN_KEY
+from airflow.utils.context import Context
from airflow.utils.edgemodifier import EdgeModifier
@@ -128,7 +129,7 @@ class XComArg(TaskMixin):
"""Proxy to underlying operator set_downstream method. Required by TaskMixin."""
self.operator.set_downstream(task_or_task_list, edge_modifier)
- def resolve(self, context: Dict) -> Any:
+ def resolve(self, context: Context) -> Any:
"""
Pull XCom value for the existing arg. This method is run during ``op.execute()``
in respectable context.
diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py
index 5d72410..2d02557 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -18,8 +18,10 @@
from collections import Counter
+from sqlalchemy.orm import Session
+
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule as TR
@@ -82,7 +84,16 @@ class TriggerRuleDep(BaseTIDep):
@provide_session
def _evaluate_trigger_rule(
- self, ti, successes, skipped, failed, upstream_failed, done, flag_upstream_failed, session
+ self,
+ ti,
+ successes,
+ skipped,
+ failed,
+ upstream_failed,
+ done,
+ flag_upstream_failed,
+ *,
+ session: Session = NEW_SESSION,
):
"""
Yields a dependency status that indicate whether the given task instance's trigger
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index fca55c1..d1c75bc 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -19,10 +19,22 @@
"""Jinja2 template rendering context helper."""
import contextlib
+import copy
import warnings
-from typing import Any, Container, Dict, Iterable, Iterator, List, MutableMapping, Tuple
-
-_NOT_SET: Any = object()
+from typing import (
+ AbstractSet,
+ Any,
+ Container,
+ Dict,
+ Iterator,
+ List,
+ MutableMapping,
+ Optional,
+ Tuple,
+ ValuesView,
+)
+
+from airflow.utils.types import NOTSET
class VariableAccessor:
@@ -41,10 +53,10 @@ class VariableAccessor:
def __repr__(self) -> str:
return str(self.var)
- def get(self, key, default: Any = _NOT_SET) -> Any:
+ def get(self, key, default: Any = NOTSET) -> Any:
from airflow.models.variable import Variable
- if default is _NOT_SET:
+ if default is NOTSET:
return Variable.get(key, deserialize_json=self._deserialize_json)
return Variable.get(key, default, deserialize_json=self._deserialize_json)
@@ -74,16 +86,20 @@ class ConnectionAccessor:
return default_conn
+class AirflowContextDeprecationWarning(DeprecationWarning):
+ """Warn for usage of deprecated context variables in a task."""
+
+
def _create_deprecation_warning(key: str, replacements: List[str]) -> DeprecationWarning:
message = f"Accessing {key!r} from the template is deprecated and will be removed in a future version."
if not replacements:
- return DeprecationWarning(message)
+ return AirflowContextDeprecationWarning(message)
display_except_last = ", ".join(repr(r) for r in replacements[:-1])
if display_except_last:
message += f" Please use {display_except_last} or {replacements[-1]!r} instead."
else:
message += f" Please use {replacements[-1]!r} instead."
- return DeprecationWarning(message)
+ return AirflowContextDeprecationWarning(message)
class Context(MutableMapping[str, Any]):
@@ -108,8 +124,10 @@ class Context(MutableMapping[str, Any]):
"yesterday_ds_nodash": [],
}
- def __init__(self, context: MutableMapping[str, Any]) -> None:
- self._context = context
+ def __init__(self, context: Optional[MutableMapping[str, Any]] = None, **kwargs: Any) -> None:
+ self._context = context or {}
+ if kwargs:
+ self._context.update(kwargs)
self._deprecation_replacements = self._DEPRECATION_REPLACEMENTS.copy()
def __repr__(self) -> str:
@@ -124,9 +142,14 @@ class Context(MutableMapping[str, Any]):
items = [(key, self[key]) for key in self._context]
return dict, (items,)
+ def __copy__(self) -> "Context":
+ new = type(self)(copy.copy(self._context))
+ new._deprecation_replacements = self._deprecation_replacements.copy()
+ return new
+
def __getitem__(self, key: str) -> Any:
with contextlib.suppress(KeyError):
- warnings.warn(_create_deprecation_warning(key, self._deprecation_replacements[key]), stacklevel=2)
+ warnings.warn(_create_deprecation_warning(key, self._deprecation_replacements[key]))
with contextlib.suppress(KeyError):
return self._context[key]
raise KeyError(key)
@@ -139,7 +162,7 @@ class Context(MutableMapping[str, Any]):
self._deprecation_replacements.pop(key, None)
del self._context[key]
- def __contains__(self, key: str) -> bool:
+ def __contains__(self, key: object) -> bool:
return key in self._context
def __iter__(self) -> Iterator[str]:
@@ -158,14 +181,16 @@ class Context(MutableMapping[str, Any]):
return NotImplemented
return self._context != other._context
- def keys(self) -> Iterable[str]:
+ def keys(self) -> AbstractSet[str]:
return self._context.keys()
- def items(self) -> Iterable[Tuple[str, Any]]:
+ def items(self) -> AbstractSet[Tuple[str, Any]]:
return self._context.items()
- def values(self) -> Iterable[Any]:
+ def values(self) -> ValuesView[Any]:
return self._context.values()
- def copy_only(self, keys: Container[str]) -> "Context[str, Any]":
- return type(self)({k: v for k, v in self._context.items() if k in keys})
+ def copy_only(self, keys: Container[str]) -> "Context":
+ new = type(self)({k: v for k, v in self._context.items() if k in keys})
+ new._deprecation_replacements = self._deprecation_replacements.copy()
+ return new
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index e6ab39a..c5f9f27 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+import copy
import re
import warnings
from datetime import datetime
@@ -24,11 +24,13 @@ from itertools import filterfalse, tee
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, TypeVar
from urllib import parse
-from flask import url_for
-from jinja2 import Template
+import flask
+import jinja2
+import jinja2.nativetypes
from airflow.configuration import conf
from airflow.exceptions import AirflowException
+from airflow.utils.context import Context
from airflow.utils.module_loading import import_string
if TYPE_CHECKING:
@@ -146,7 +148,7 @@ def as_flattened_list(iterable: Iterable[Iterable[T]]) -> List[T]:
def parse_template_string(template_string):
"""Parses Jinja template string."""
if "{{" in template_string: # jinja mode
- return None, Template(template_string)
+ return None, jinja2.Template(template_string)
else:
return template_string, None
@@ -228,5 +230,44 @@ def build_airflow_url_with_query(query: Dict[str, Any]) -> str:
'http://0.0.0.0:8000/base/graph?dag_id=my-task&root=&execution_date=2020-10-27T10%3A59%3A25.615587
"""
view = conf.get('webserver', 'dag_default_view').lower()
- url = url_for(f"Airflow.{view}")
+ url = flask.url_for(f"Airflow.{view}")
return f"{url}?{parse.urlencode(query)}"
+
+
+# The 'template' argument is typed as Any because the jinja2.Template is too
+# dynamic to be effectively type-checked.
+def render_template(template: Any, context: Context, *, native: bool) -> Any:
+ """Render a Jinja2 template with given Airflow context.
+
+ The default implementation of ``jinja2.Template.render()`` converts the
+ input context into dict eagerly many times, which triggers deprecation
+ messages in our custom context class. This takes the implementation apart
+ and retain the context mapping without resolving instead.
+
+ :param template: A Jinja2 template to render.
+ :param context: The Airflow task context to render the template with.
+ :param native: If set to *True*, render the template into a native type. A
+ DAG can enable this with ``render_template_as_native_obj=True``.
+ :returns: The render result.
+ """
+ context = copy.copy(context)
+ env = template.environment
+ if template.globals:
+ context.update((k, v) for k, v in template.globals.items() if k not in context)
+ try:
+ nodes = template.root_render_func(env.context_class(env, context, template.name, template.blocks))
+ except Exception:
+ env.handle_exception() # Rewrite traceback to point to the template.
+ if native:
+ return jinja2.nativetypes.native_concat(nodes)
+ return "".join(nodes)
+
+
+def render_template_to_string(template: jinja2.Template, context: Context) -> str:
+ """Shorthand to ``render_template(native=False)`` with better typing support."""
+ return render_template(template, context, native=False)
+
+
+def render_template_as_native(template: jinja2.Template, context: Context) -> Any:
+ """Shorthand to ``render_template(native=True)`` with better typing support."""
+ return render_template(template, context, native=True)
diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index 6d88c20..6e57c67 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -25,7 +25,8 @@ import httpx
from itsdangerous import TimedJSONWebSignatureSerializer
from airflow.configuration import AirflowConfigException, conf
-from airflow.utils.helpers import parse_template_string
+from airflow.utils.context import Context
+from airflow.utils.helpers import parse_template_string, render_template_to_string
from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler
if TYPE_CHECKING:
@@ -73,23 +74,19 @@ class FileTaskHandler(logging.Handler):
if self.handler:
self.handler.close()
- def _render_filename(self, ti, try_number):
+ def _render_filename(self, ti: "TaskInstance", try_number: int) -> str:
if self.filename_jinja_template:
- if hasattr(ti, 'task'):
- jinja_context = ti.get_template_context()
- jinja_context['try_number'] = try_number
+ if hasattr(ti, "task"):
+ context = ti.get_template_context()
else:
- jinja_context = {
- 'ti': ti,
- 'ts': ti.execution_date.isoformat(),
- 'try_number': try_number,
- }
- return self.filename_jinja_template.render(**jinja_context)
+ context = Context(ti=ti, ts=ti.get_dagrun().logical_date.isoformat())
+ context["try_number"] = try_number
+ return render_template_to_string(self.filename_jinja_template, context)
return self.filename_template.format(
dag_id=ti.dag_id,
task_id=ti.task_id,
- execution_date=ti.execution_date.isoformat(),
+ execution_date=ti.get_dagrun().logical_date.isoformat(),
try_number=try_number,
)
diff --git a/tests/conftest.py b/tests/conftest.py
index 94f915f..f7248d1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -231,6 +231,7 @@ def breeze_test_helper(request):
def pytest_configure(config):
+ config.addinivalue_line("filterwarnings", "error::airflow.utils.context.AirflowContextDeprecationWarning")
config.addinivalue_line("markers", "integration(name): mark test to run with named integration")
config.addinivalue_line("markers", "backend(name): mark test to run with named backend")
config.addinivalue_line("markers", "system(name): mark test to run with named system")
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index f07147f..8458ea9 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -1512,6 +1512,27 @@ class TestTaskInstance:
assert isinstance(template_context["data_interval_start"], pendulum.DateTime)
assert isinstance(template_context["data_interval_end"], pendulum.DateTime)
+ def test_template_render(self, create_task_instance):
+ ti = create_task_instance(
+ dag_id="test_template_render",
+ task_id="test_template_render_task",
+ schedule_interval="0 12 * * *",
+ )
+ template_context = ti.get_template_context()
+ result = ti.task.render_template("Task: {{ dag.dag_id }} -> {{ task.task_id }}", template_context)
+ assert result == "Task: test_template_render -> test_template_render_task"
+
+ def test_template_render_deprecated(self, create_task_instance):
+ ti = create_task_instance(
+ dag_id="test_template_render",
+ task_id="test_template_render_task",
+ schedule_interval="0 12 * * *",
+ )
+ template_context = ti.get_template_context()
+ with pytest.deprecated_call():
+ result = ti.task.render_template("Execution date: {{ execution_date }}", template_context)
+ assert result.startswith("Execution date: ")
+
@pytest.mark.parametrize(
"content, expected_output",
[