You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by je...@apache.org on 2022/09/07 22:02:07 UTC
[airflow] branch main updated: Make `execution_date_or_run_id` optional in `tasks test` command (#26114)
This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 243b3d7158 Make `execution_date_or_run_id` optional in `tasks test` command (#26114)
243b3d7158 is described below
commit 243b3d71580b63fe3f8daf200215cc95bdac253a
Author: Daniel Standish <15...@users.noreply.github.com>
AuthorDate: Wed Sep 7 15:01:59 2022 -0700
Make `execution_date_or_run_id` optional in `tasks test` command (#26114)
---
airflow/cli/cli_parser.py | 7 ++-
airflow/cli/commands/task_command.py | 68 +++++++++++++++------------
docs/apache-airflow/tutorial/fundamentals.rst | 4 +-
tests/cli/commands/test_task_command.py | 23 +++++++--
4 files changed, 66 insertions(+), 36 deletions(-)
diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py
index 231bbed490..21031bd1c2 100644
--- a/airflow/cli/cli_parser.py
+++ b/airflow/cli/cli_parser.py
@@ -183,6 +183,11 @@ ARG_EXECUTION_DATE_OPTIONAL = Arg(
ARG_EXECUTION_DATE_OR_RUN_ID = Arg(
('execution_date_or_run_id',), help="The execution_date of the DAG or run_id of the DAGRun"
)
+ARG_EXECUTION_DATE_OR_RUN_ID_OPTIONAL = Arg(
+ ('execution_date_or_run_id',),
+ nargs='?',
+ help="The execution_date of the DAG or run_id of the DAGRun (optional)",
+)
ARG_TASK_REGEX = Arg(
("-t", "--task-regex"), help="The regex to filter specific task_ids to backfill (optional)"
)
@@ -1296,7 +1301,7 @@ TASKS_COMMANDS = (
args=(
ARG_DAG_ID,
ARG_TASK_ID,
- ARG_EXECUTION_DATE_OR_RUN_ID,
+ ARG_EXECUTION_DATE_OR_RUN_ID_OPTIONAL,
ARG_SUBDIR,
ARG_DRY_RUN,
ARG_TASK_PARAMS,
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index 694594d68d..f8916f0466 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -22,7 +22,7 @@ import json
import logging
import os
import textwrap
-from contextlib import contextmanager, redirect_stderr, redirect_stdout
+from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress
from typing import Dict, Generator, List, Optional, Tuple, Union
from pendulum.parsing.exceptions import ParserError
@@ -75,8 +75,8 @@ def _generate_temporary_run_id() -> str:
def _get_dag_run(
*,
dag: DAG,
- exec_date_or_run_id: str,
create_if_necessary: CreateIfNecessary,
+ exec_date_or_run_id: Optional[str] = None,
session: Session,
) -> Tuple[DagRun, bool]:
"""Try to retrieve a DAG run from a string representing either a run ID or logical date.
@@ -92,33 +92,35 @@ def _get_dag_run(
the logical date; otherwise use it as a run ID and set the logical date
to the current time.
"""
- dag_run = dag.get_dagrun(run_id=exec_date_or_run_id, session=session)
- if dag_run:
- return dag_run, False
-
- try:
- execution_date: Optional[datetime.datetime] = timezone.parse(exec_date_or_run_id)
- except (ParserError, TypeError):
- execution_date = None
-
- try:
- dag_run = (
- session.query(DagRun)
- .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
- .one()
- )
- except NoResultFound:
- if not create_if_necessary:
- raise DagRunNotFound(
- f"DagRun for {dag.dag_id} with run_id or execution_date of {exec_date_or_run_id!r} not found"
- ) from None
- else:
- return dag_run, False
+ if not exec_date_or_run_id and not create_if_necessary:
+ raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.")
+ execution_date: Optional[datetime.datetime] = None
+ if exec_date_or_run_id:
+ dag_run = dag.get_dagrun(run_id=exec_date_or_run_id, session=session)
+ if dag_run:
+ return dag_run, False
+ with suppress(ParserError, TypeError):
+ execution_date = timezone.parse(exec_date_or_run_id)
+ try:
+ dag_run = (
+ session.query(DagRun)
+ .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
+ .one()
+ )
+ except NoResultFound:
+ if not create_if_necessary:
+ raise DagRunNotFound(
+ f"DagRun for {dag.dag_id} with run_id or execution_date "
+ f"of {exec_date_or_run_id!r} not found"
+ ) from None
+ else:
+ return dag_run, False
if execution_date is not None:
dag_run_execution_date = execution_date
else:
dag_run_execution_date = timezone.utcnow()
+
if create_if_necessary == "memory":
dag_run = DagRun(dag.dag_id, run_id=exec_date_or_run_id, execution_date=dag_run_execution_date)
return dag_run, True
@@ -136,14 +138,16 @@ def _get_dag_run(
@provide_session
def _get_ti(
task: BaseOperator,
- exec_date_or_run_id: str,
map_index: int,
*,
+ exec_date_or_run_id: Optional[str] = None,
pool: Optional[str] = None,
create_if_necessary: CreateIfNecessary = False,
session: Session = NEW_SESSION,
) -> Tuple[TaskInstance, bool]:
"""Get the task instance through DagRun.run_id, if that fails, get the TI the old way"""
+ if not exec_date_or_run_id and not create_if_necessary:
+ raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.")
if task.is_mapped:
if map_index < 0:
raise RuntimeError("No map_index passed to mapped task")
@@ -370,7 +374,7 @@ def task_run(args, dag=None):
# Use DAG from parameter
pass
task = dag.get_task(task_id=args.task_id)
- ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index, pool=args.pool)
+ ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, pool=args.pool)
ti.init_run_context(raw=args.raw)
hostname = get_hostname()
@@ -398,7 +402,7 @@ def task_failed_deps(args):
"""
dag = get_dag(args.subdir, args.dag_id)
task = dag.get_task(task_id=args.task_id)
- ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index)
+ ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id)
dep_context = DepContext(deps=SCHEDULER_QUEUED_DEPS)
failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context))
@@ -421,7 +425,7 @@ def task_state(args):
"""
dag = get_dag(args.subdir, args.dag_id)
task = dag.get_task(task_id=args.task_id)
- ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index)
+ ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id)
print(ti.current_state())
@@ -544,7 +548,9 @@ def task_test(args, dag=None):
if task.params:
task.params.validate()
- ti, dr_created = _get_ti(task, args.execution_date_or_run_id, args.map_index, create_if_necessary="db")
+ ti, dr_created = _get_ti(
+ task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, create_if_necessary="db"
+ )
try:
with redirect_stdout(RedactedIO()):
@@ -574,7 +580,9 @@ def task_render(args):
"""Renders and displays templated fields for a given task"""
dag = get_dag(args.subdir, args.dag_id)
task = dag.get_task(task_id=args.task_id)
- ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index, create_if_necessary="memory")
+ ti, _ = _get_ti(
+ task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, create_if_necessary="memory"
+ )
ti.render_templates()
for attr in task.__class__.template_fields:
print(
diff --git a/docs/apache-airflow/tutorial/fundamentals.rst b/docs/apache-airflow/tutorial/fundamentals.rst
index 351c215912..d2071a0682 100644
--- a/docs/apache-airflow/tutorial/fundamentals.rst
+++ b/docs/apache-airflow/tutorial/fundamentals.rst
@@ -326,7 +326,7 @@ its data interval.
.. code-block:: bash
- # command layout: command subcommand dag_id task_id date
+ # command layout: command subcommand dag_id task_id [date] (the date is optional)
# testing print_date
airflow tasks test tutorial print_date 2015-06-01
@@ -350,7 +350,7 @@ their log to stdout (on screen), does not bother with dependencies, and
does not communicate state (running, success, failed, ...) to the database.
It simply allows testing a single task instance.
-The same applies to ``airflow dags test [dag_id] [logical_date]``, but on a DAG
+The same applies to ``airflow dags test``, but on a DAG
level. It performs a single DAG run of the given DAG id. While it does take task
dependencies into account, no state is registered in the database. It is
convenient for locally testing a full run of your DAG, given that e.g. if one of
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index 8476d7f3e9..f3defcaeb8 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
import io
import json
import logging
@@ -26,6 +27,7 @@ from contextlib import redirect_stdout
from pathlib import Path
from unittest import mock
+import pendulum
import pytest
from parameterized import parameterized
@@ -103,6 +105,21 @@ class TestCliTasks:
# Check that prints, and log messages, are shown
assert "'example_python_operator__print_the_context__20180101'" in stdout.getvalue()
+ @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning")
+ @mock.patch('airflow.utils.timezone.utcnow')
+ def test_test_no_execution_date(self, mock_utcnow):
+ """Test the `airflow tasks test` command when no execution date is given"""
+ now = pendulum.now('UTC')
+ mock_utcnow.return_value = now
+ ds = now.strftime("%Y%m%d")
+ args = self.parser.parse_args(["tasks", "test", "example_python_operator", 'print_the_context'])
+
+ with redirect_stdout(io.StringIO()) as stdout:
+ task_command.task_test(args)
+
+ # Check that prints, and log messages, are shown
+ assert f"'example_python_operator__print_the_context__{ds}'" in stdout.getvalue()
+
@pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning")
def test_test_with_existing_dag_run(self, caplog):
"""Test the `airflow test` command"""
@@ -255,9 +272,9 @@ class TestCliTasks:
'test',
'example_passing_params_via_test_command',
'run_this',
+ DEFAULT_DATE.isoformat(),
'--task-params',
'{"foo":"bar"}',
- DEFAULT_DATE.isoformat(),
]
)
)
@@ -268,9 +285,9 @@ class TestCliTasks:
'test',
'example_passing_params_via_test_command',
'also_run_this',
+ DEFAULT_DATE.isoformat(),
'--task-params',
'{"foo":"bar"}',
- DEFAULT_DATE.isoformat(),
]
)
)
@@ -284,9 +301,9 @@ class TestCliTasks:
'test',
'example_passing_params_via_test_command',
'env_var_test_task',
+ DEFAULT_DATE.isoformat(),
'--env-vars',
'{"foo":"bar"}',
- DEFAULT_DATE.isoformat(),
]
)
)