You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@airflow.apache.org by je...@apache.org on 2022/09/07 22:02:07 UTC

[airflow] branch main updated: Make `execution_date_or_run_id` optional in `tasks test` command (#26114)

This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 243b3d7158 Make `execution_date_or_run_id` optional in `tasks test` command (#26114)
243b3d7158 is described below

commit 243b3d71580b63fe3f8daf200215cc95bdac253a
Author: Daniel Standish <15...@users.noreply.github.com>
AuthorDate: Wed Sep 7 15:01:59 2022 -0700

    Make `execution_date_or_run_id` optional in `tasks test` command (#26114)
---
 airflow/cli/cli_parser.py                     |  7 ++-
 airflow/cli/commands/task_command.py          | 68 +++++++++++++++------------
 docs/apache-airflow/tutorial/fundamentals.rst |  4 +-
 tests/cli/commands/test_task_command.py       | 23 +++++++--
 4 files changed, 66 insertions(+), 36 deletions(-)

diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py
index 231bbed490..21031bd1c2 100644
--- a/airflow/cli/cli_parser.py
+++ b/airflow/cli/cli_parser.py
@@ -183,6 +183,11 @@ ARG_EXECUTION_DATE_OPTIONAL = Arg(
 ARG_EXECUTION_DATE_OR_RUN_ID = Arg(
     ('execution_date_or_run_id',), help="The execution_date of the DAG or run_id of the DAGRun"
 )
+ARG_EXECUTION_DATE_OR_RUN_ID_OPTIONAL = Arg(
+    ('execution_date_or_run_id',),
+    nargs='?',
+    help="The execution_date of the DAG or run_id of the DAGRun (optional)",
+)
 ARG_TASK_REGEX = Arg(
     ("-t", "--task-regex"), help="The regex to filter specific task_ids to backfill (optional)"
 )
@@ -1296,7 +1301,7 @@ TASKS_COMMANDS = (
         args=(
             ARG_DAG_ID,
             ARG_TASK_ID,
-            ARG_EXECUTION_DATE_OR_RUN_ID,
+            ARG_EXECUTION_DATE_OR_RUN_ID_OPTIONAL,
             ARG_SUBDIR,
             ARG_DRY_RUN,
             ARG_TASK_PARAMS,
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index 694594d68d..f8916f0466 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -22,7 +22,7 @@ import json
 import logging
 import os
 import textwrap
-from contextlib import contextmanager, redirect_stderr, redirect_stdout
+from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress
 from typing import Dict, Generator, List, Optional, Tuple, Union
 
 from pendulum.parsing.exceptions import ParserError
@@ -75,8 +75,8 @@ def _generate_temporary_run_id() -> str:
 def _get_dag_run(
     *,
     dag: DAG,
-    exec_date_or_run_id: str,
     create_if_necessary: CreateIfNecessary,
+    exec_date_or_run_id: Optional[str] = None,
     session: Session,
 ) -> Tuple[DagRun, bool]:
     """Try to retrieve a DAG run from a string representing either a run ID or logical date.
@@ -92,33 +92,35 @@ def _get_dag_run(
        the logical date; otherwise use it as a run ID and set the logical date
        to the current time.
     """
-    dag_run = dag.get_dagrun(run_id=exec_date_or_run_id, session=session)
-    if dag_run:
-        return dag_run, False
-
-    try:
-        execution_date: Optional[datetime.datetime] = timezone.parse(exec_date_or_run_id)
-    except (ParserError, TypeError):
-        execution_date = None
-
-    try:
-        dag_run = (
-            session.query(DagRun)
-            .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
-            .one()
-        )
-    except NoResultFound:
-        if not create_if_necessary:
-            raise DagRunNotFound(
-                f"DagRun for {dag.dag_id} with run_id or execution_date of {exec_date_or_run_id!r} not found"
-            ) from None
-    else:
-        return dag_run, False
+    if not exec_date_or_run_id and not create_if_necessary:
+        raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.")
+    execution_date: Optional[datetime.datetime] = None
+    if exec_date_or_run_id:
+        dag_run = dag.get_dagrun(run_id=exec_date_or_run_id, session=session)
+        if dag_run:
+            return dag_run, False
+        with suppress(ParserError, TypeError):
+            execution_date = timezone.parse(exec_date_or_run_id)
+        try:
+            dag_run = (
+                session.query(DagRun)
+                .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
+                .one()
+            )
+        except NoResultFound:
+            if not create_if_necessary:
+                raise DagRunNotFound(
+                    f"DagRun for {dag.dag_id} with run_id or execution_date "
+                    f"of {exec_date_or_run_id!r} not found"
+                ) from None
+        else:
+            return dag_run, False
 
     if execution_date is not None:
         dag_run_execution_date = execution_date
     else:
         dag_run_execution_date = timezone.utcnow()
+
     if create_if_necessary == "memory":
         dag_run = DagRun(dag.dag_id, run_id=exec_date_or_run_id, execution_date=dag_run_execution_date)
         return dag_run, True
@@ -136,14 +138,16 @@ def _get_dag_run(
 @provide_session
 def _get_ti(
     task: BaseOperator,
-    exec_date_or_run_id: str,
     map_index: int,
     *,
+    exec_date_or_run_id: Optional[str] = None,
     pool: Optional[str] = None,
     create_if_necessary: CreateIfNecessary = False,
     session: Session = NEW_SESSION,
 ) -> Tuple[TaskInstance, bool]:
     """Get the task instance through DagRun.run_id, if that fails, get the TI the old way"""
+    if not exec_date_or_run_id and not create_if_necessary:
+        raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.")
     if task.is_mapped:
         if map_index < 0:
             raise RuntimeError("No map_index passed to mapped task")
@@ -370,7 +374,7 @@ def task_run(args, dag=None):
         # Use DAG from parameter
         pass
     task = dag.get_task(task_id=args.task_id)
-    ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index, pool=args.pool)
+    ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, pool=args.pool)
     ti.init_run_context(raw=args.raw)
 
     hostname = get_hostname()
@@ -398,7 +402,7 @@ def task_failed_deps(args):
     """
     dag = get_dag(args.subdir, args.dag_id)
     task = dag.get_task(task_id=args.task_id)
-    ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index)
+    ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id)
 
     dep_context = DepContext(deps=SCHEDULER_QUEUED_DEPS)
     failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context))
@@ -421,7 +425,7 @@ def task_state(args):
     """
     dag = get_dag(args.subdir, args.dag_id)
     task = dag.get_task(task_id=args.task_id)
-    ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index)
+    ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id)
     print(ti.current_state())
 
 
@@ -544,7 +548,9 @@ def task_test(args, dag=None):
     if task.params:
         task.params.validate()
 
-    ti, dr_created = _get_ti(task, args.execution_date_or_run_id, args.map_index, create_if_necessary="db")
+    ti, dr_created = _get_ti(
+        task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, create_if_necessary="db"
+    )
 
     try:
         with redirect_stdout(RedactedIO()):
@@ -574,7 +580,9 @@ def task_render(args):
     """Renders and displays templated fields for a given task"""
     dag = get_dag(args.subdir, args.dag_id)
     task = dag.get_task(task_id=args.task_id)
-    ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index, create_if_necessary="memory")
+    ti, _ = _get_ti(
+        task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, create_if_necessary="memory"
+    )
     ti.render_templates()
     for attr in task.__class__.template_fields:
         print(
diff --git a/docs/apache-airflow/tutorial/fundamentals.rst b/docs/apache-airflow/tutorial/fundamentals.rst
index 351c215912..d2071a0682 100644
--- a/docs/apache-airflow/tutorial/fundamentals.rst
+++ b/docs/apache-airflow/tutorial/fundamentals.rst
@@ -326,7 +326,7 @@ its data interval.
 
 .. code-block:: bash
 
-    # command layout: command subcommand dag_id task_id date
+    # command layout: command subcommand [dag_id] [task_id] [(optional) date]
 
     # testing print_date
     airflow tasks test tutorial print_date 2015-06-01
@@ -350,7 +350,7 @@ their log to stdout (on screen), does not bother with dependencies, and
 does not communicate state (running, success, failed, ...) to the database.
 It simply allows testing a single task instance.
 
-The same applies to ``airflow dags test [dag_id] [logical_date]``, but on a DAG
+The same applies to ``airflow dags test``, but on a DAG
 level. It performs a single DAG run of the given DAG id. While it does take task
 dependencies into account, no state is registered in the database. It is
 convenient for locally testing a full run of your DAG, given that e.g. if one of
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index 8476d7f3e9..f3defcaeb8 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -15,6 +15,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 import io
 import json
 import logging
@@ -26,6 +27,7 @@ from contextlib import redirect_stdout
 from pathlib import Path
 from unittest import mock
 
+import pendulum
 import pytest
 from parameterized import parameterized
 
@@ -103,6 +105,21 @@ class TestCliTasks:
         # Check that prints, and log messages, are shown
         assert "'example_python_operator__print_the_context__20180101'" in stdout.getvalue()
 
+    @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning")
+    @mock.patch('airflow.utils.timezone.utcnow')
+    def test_test_no_execution_date(self, mock_utcnow):
+        """Test the `airflow test` command"""
+        now = pendulum.now('UTC')
+        mock_utcnow.return_value = now
+        ds = now.strftime("%Y%m%d")
+        args = self.parser.parse_args(["tasks", "test", "example_python_operator", 'print_the_context'])
+
+        with redirect_stdout(io.StringIO()) as stdout:
+            task_command.task_test(args)
+
+        # Check that prints, and log messages, are shown
+        assert f"'example_python_operator__print_the_context__{ds}'" in stdout.getvalue()
+
     @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning")
     def test_test_with_existing_dag_run(self, caplog):
         """Test the `airflow test` command"""
@@ -255,9 +272,9 @@ class TestCliTasks:
                     'test',
                     'example_passing_params_via_test_command',
                     'run_this',
+                    DEFAULT_DATE.isoformat(),
                     '--task-params',
                     '{"foo":"bar"}',
-                    DEFAULT_DATE.isoformat(),
                 ]
             )
         )
@@ -268,9 +285,9 @@ class TestCliTasks:
                     'test',
                     'example_passing_params_via_test_command',
                     'also_run_this',
+                    DEFAULT_DATE.isoformat(),
                     '--task-params',
                     '{"foo":"bar"}',
-                    DEFAULT_DATE.isoformat(),
                 ]
             )
         )
@@ -284,9 +301,9 @@ class TestCliTasks:
                         'test',
                         'example_passing_params_via_test_command',
                         'env_var_test_task',
+                        DEFAULT_DATE.isoformat(),
                         '--env-vars',
                         '{"foo":"bar"}',
-                        DEFAULT_DATE.isoformat(),
                     ]
                 )
             )