Posted to commits@airflow.apache.org by ur...@apache.org on 2023/02/20 03:57:05 UTC

[airflow] branch main updated: Avoid importing executor during conf validation (#29569)

This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7dd19731f2 Avoid importing executor during conf validation (#29569)
7dd19731f2 is described below

commit 7dd19731f232d99d2947128c29e393b8ce013741
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Mon Feb 20 11:56:55 2023 +0800

    Avoid importing executor during conf validation (#29569)
---
 airflow/cli/commands/scheduler_command.py     | 21 +++++++-------
 airflow/configuration.py                      | 39 ++++++++++++--------------
 airflow/executors/executor_loader.py          | 40 ++++++++++++++++++++++++---
 airflow/www/views.py                          | 11 +++++---
 tests/conftest.py                             |  7 ++++-
 tests/executors/test_executor_loader.py       | 30 +++++++++++++++++++-
 tests/providers/github/sensors/test_github.py |  4 ++-
 tests/www/views/test_views_tasks.py           |  2 +-
 8 files changed, 110 insertions(+), 44 deletions(-)
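
For readers skimming the diff: configuration validation no longer imports the executor class; the SQLite/executor compatibility check now runs lazily when ExecutorLoader imports an executor. A hedged sketch of the resulting call flow (the calls exist in the diff below; the failure scenario assumes a SQLite metadata database, which is runtime configuration and not part of this commit):

    from airflow.configuration import conf
    from airflow.exceptions import AirflowConfigException
    from airflow.executors.executor_loader import ExecutorLoader

    conf.validate()  # now only checks enums, deprecated values and the SQLite library version

    try:
        # Importing an executor (core name, plugin, or custom path) also validates it
        # against the configured metadata database.
        executor_cls, source = ExecutorLoader.import_executor_cls("LocalExecutor")
    except AirflowConfigException as exc:
        # Raised when a multi-threaded executor is combined with a SQLite backend.
        print(exc)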

diff --git a/airflow/cli/commands/scheduler_command.py b/airflow/cli/commands/scheduler_command.py
index 3965b552b7..2c7c7c0f1a 100644
--- a/airflow/cli/commands/scheduler_command.py
+++ b/airflow/cli/commands/scheduler_command.py
@@ -34,16 +34,10 @@ from airflow.utils.cli import process_subdir, setup_locations, setup_logging, si
 from airflow.utils.scheduler_health import serve_health_check
 
 
-def _run_scheduler_job(args):
+def _run_scheduler_job(job: SchedulerJob, *, skip_serve_logs: bool) -> None:
     InternalApiConfig.force_database_direct_access()
-
-    job = SchedulerJob(
-        subdir=process_subdir(args.subdir),
-        num_runs=args.num_runs,
-        do_pickle=args.do_pickle,
-    )
     enable_health_check = conf.getboolean("scheduler", "ENABLE_HEALTH_CHECK")
-    with _serve_logs(args.skip_serve_logs), _serve_health_check(enable_health_check):
+    with _serve_logs(skip_serve_logs), _serve_health_check(enable_health_check):
         job.run()
 
 
@@ -52,6 +46,13 @@ def scheduler(args):
     """Starts Airflow Scheduler."""
     print(settings.HEADER)
 
+    job = SchedulerJob(
+        subdir=process_subdir(args.subdir),
+        num_runs=args.num_runs,
+        do_pickle=args.do_pickle,
+    )
+    ExecutorLoader.validate_database_executor_compatibility(job.executor)
+
     if args.daemon:
         pid, stdout, stderr, log_file = setup_locations(
             "scheduler", args.pid, args.stdout, args.stderr, args.log_file
@@ -69,12 +70,12 @@ def scheduler(args):
                 umask=int(settings.DAEMON_UMASK, 8),
             )
             with ctx:
-                _run_scheduler_job(args=args)
+                _run_scheduler_job(job, skip_serve_logs=args.skip_serve_logs)
     else:
         signal.signal(signal.SIGINT, sigint_handler)
         signal.signal(signal.SIGTERM, sigint_handler)
         signal.signal(signal.SIGQUIT, sigquit_handler)
-        _run_scheduler_job(args=args)
+        _run_scheduler_job(job, skip_serve_logs=args.skip_serve_logs)
 
 
 @contextmanager
diff --git a/airflow/configuration.py b/airflow/configuration.py
index c2d611f718..ff29009f25 100644
--- a/airflow/configuration.py
+++ b/airflow/configuration.py
@@ -43,7 +43,6 @@ from typing_extensions import overload
 
 from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowConfigException
-from airflow.executors.executor_loader import ExecutorLoader
 from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH, BaseSecretsBackend
 from airflow.utils import yaml
 from airflow.utils.module_loading import import_string
@@ -339,7 +338,7 @@ class AirflowConfigParser(ConfigParser):
         self._suppress_future_warnings = False
 
     def validate(self):
-        self._validate_config_dependencies()
+        self._validate_sqlite3_version()
         self._validate_enums()
 
         for section, replacement in self.deprecated_values.items():
@@ -428,31 +427,27 @@ class AirflowConfigParser(ConfigParser):
                         f"{value!r}. Possible values: {', '.join(enum_options)}."
                     )
 
-    def _validate_config_dependencies(self):
-        """
-        Validate that config based on condition.
+    def _validate_sqlite3_version(self):
+        """Validate SQLite version.
 
-        Values are considered invalid when they conflict with other config values
-        or system-level limitations and requirements.
+        Some features in storing rendered fields require SQLite >= 3.15.0.
         """
-        executor, _ = ExecutorLoader.import_default_executor_cls()
-        is_sqlite = "sqlite" in self.get("database", "sql_alchemy_conn")
+        if "sqlite" not in self.get("database", "sql_alchemy_conn"):
+            return
 
-        if is_sqlite and not executor.is_single_threaded:
-            raise AirflowConfigException(f"error: cannot use sqlite with the {executor.__name__}")
-        if is_sqlite:
-            import sqlite3
+        import sqlite3
 
-            from airflow.utils.docs import get_docs_url
+        min_sqlite_version = (3, 15, 0)
+        if _parse_sqlite_version(sqlite3.sqlite_version) >= min_sqlite_version:
+            return
 
-            # Some features in storing rendered fields require sqlite version >= 3.15.0
-            min_sqlite_version = (3, 15, 0)
-            if _parse_sqlite_version(sqlite3.sqlite_version) < min_sqlite_version:
-                min_sqlite_version_str = ".".join(str(s) for s in min_sqlite_version)
-                raise AirflowConfigException(
-                    f"error: sqlite C library version too old (< {min_sqlite_version_str}). "
-                    f"See {get_docs_url('howto/set-up-database.html#setting-up-a-sqlite-database')}"
-                )
+        from airflow.utils.docs import get_docs_url
+
+        min_sqlite_version_str = ".".join(str(s) for s in min_sqlite_version)
+        raise AirflowConfigException(
+            f"error: SQLite C library too old (< {min_sqlite_version_str}). "
+            f"See {get_docs_url('howto/set-up-database.html#setting-up-a-sqlite-database')}"
+        )
 
     def _using_old_value(self, old: Pattern, current_value: str) -> bool:
         return old.search(current_value) is not None
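
The refactored _validate_sqlite3_version above compares the runtime SQLite version, parsed into an integer tuple, against (3, 15, 0). A minimal illustrative sketch of that comparison; the parse helper below is a stand-in and not Airflow's actual _parse_sqlite_version:

    from __future__ import annotations

    import sqlite3


    def parse_version(value: str) -> tuple[int, ...]:
        # "3.39.4" -> (3, 39, 4); tuple comparison then expresses the >= (3, 15, 0) requirement.
        return tuple(int(part) for part in value.split("."))


    min_sqlite_version = (3, 15, 0)
    if parse_version(sqlite3.sqlite_version) < min_sqlite_version:
        raise RuntimeError("SQLite C library too old (< 3.15.0)")
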
diff --git a/airflow/executors/executor_loader.py b/airflow/executors/executor_loader.py
index 50ecca8d33..ee7759717f 100644
--- a/airflow/executors/executor_loader.py
+++ b/airflow/executors/executor_loader.py
@@ -17,7 +17,9 @@
 """All executors."""
 from __future__ import annotations
 
+import functools
 import logging
+import os
 from contextlib import suppress
 from enum import Enum, unique
 from typing import TYPE_CHECKING
@@ -122,8 +124,14 @@ class ExecutorLoader:
 
         :return: executor class via executor_name and executor import source
         """
+
+        def _import_and_validate(path: str) -> type[BaseExecutor]:
+            executor = import_string(path)
+            cls.validate_database_executor_compatibility(executor)
+            return executor
+
         if executor_name in cls.executors:
-            return import_string(cls.executors[executor_name]), ConnectorSource.CORE
+            return _import_and_validate(cls.executors[executor_name]), ConnectorSource.CORE
         if executor_name.count(".") == 1:
             log.debug(
                 "The executor name looks like the plugin path (executor_name=%s). Trying to import a "
@@ -136,8 +144,8 @@ class ExecutorLoader:
                 from airflow import plugins_manager
 
                 plugins_manager.integrate_executor_plugins()
-                return import_string(f"airflow.executors.{executor_name}"), ConnectorSource.PLUGIN
-        return import_string(executor_name), ConnectorSource.CUSTOM_PATH
+                return _import_and_validate(f"airflow.executors.{executor_name}"), ConnectorSource.PLUGIN
+        return _import_and_validate(executor_name), ConnectorSource.CUSTOM_PATH
 
     @classmethod
     def import_default_executor_cls(cls) -> tuple[type[BaseExecutor], ConnectorSource]:
@@ -147,8 +155,32 @@ class ExecutorLoader:
         :return: executor class and executor import source
         """
         executor_name = cls.get_default_executor_name()
+        executor, source = cls.import_executor_cls(executor_name)
+        return executor, source
+
+    @classmethod
+    @functools.lru_cache(maxsize=None)
+    def validate_database_executor_compatibility(cls, executor: type[BaseExecutor]) -> None:
+        """Validate database and executor compatibility.
+
+        Most of the databases work universally, but SQLite can only work with
+        single-threaded executors (e.g. Sequential).
+
+        This is NOT done in ``airflow.configuration`` (when configuration is
+        initialized) because loading the executor class is heavy work we want to
+        avoid unless needed.
+        """
+        if executor.is_single_threaded:
+            return
+
+        # This is set in tests when we want to use a non-single-threaded executor with SQLite.
+        if os.environ.get("_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK") == "1":
+            return
+
+        from airflow.settings import engine
 
-        return cls.import_executor_cls(executor_name)
+        if engine.dialect.name == "sqlite":
+            raise AirflowConfigException(f"error: cannot use SQLite with the {executor.__name__}")
 
     @classmethod
     def __load_celery_kubernetes_executor(cls) -> BaseExecutor:
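
Two details of the loader change above are worth noting: validate_database_executor_compatibility is memoized with functools.lru_cache, so a given executor class is only checked against the engine once per process, and the _AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK environment variable short-circuits the check (the test suite relies on this, see tests/conftest.py below). A minimal standalone sketch of the same pattern with hypothetical names, not the Airflow implementation:

    import functools
    import os


    @functools.lru_cache(maxsize=None)
    def check_compat(executor_name: str, dialect: str) -> None:
        # Memoized: each (executor, dialect) pair is validated at most once per process.
        if os.environ.get("_SKIP_COMPAT_CHECK") == "1":  # hypothetical opt-out flag
            return
        if dialect == "sqlite" and executor_name != "SequentialExecutor":
            raise ValueError(f"cannot use SQLite with {executor_name}")


    check_compat("SequentialExecutor", "sqlite")  # passes
    check_compat("SequentialExecutor", "sqlite")  # answered from the cache

Because of the memoization, the opt-out environment variable is effectively read only on the first call per executor class in a process.
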
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 597b53429c..c4fa14ebc5 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1852,7 +1852,9 @@ class Airflow(AirflowBaseView):
         dag_run_id = request.form.get("dag_run_id")
         map_index = request.args.get("map_index", -1, type=int)
         origin = get_safe_url(request.form.get("origin"))
-        dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id)
+        dag = get_airflow_app().dag_bag.get_dag(dag_id)
+        if not dag:
+            return redirect_or_json(origin, "DAG not found", "error", 404)
         task = dag.get_task(task_id)
 
         ignore_all_deps = request.form.get("ignore_all_deps") == "true"
@@ -1864,9 +1866,10 @@ class Airflow(AirflowBaseView):
         if not executor.supports_ad_hoc_ti_run:
             msg = f"{executor.__class__.__name__} does not support ad hoc task runs"
             return redirect_or_json(origin, msg, "error", 400)
-
-        dag_run = dag.get_dagrun(run_id=dag_run_id)
-        ti = dag_run.get_task_instance(task_id=task.task_id, map_index=map_index)
+        dag_run = dag.get_dagrun(run_id=dag_run_id, session=session)
+        if not dag_run:
+            return redirect_or_json(origin, "DAG run not found", "error", 404)
+        ti = dag_run.get_task_instance(task_id=task.task_id, map_index=map_index, session=session)
         if not ti:
             msg = "Could not queue task instance for execution, task instance is missing"
             return redirect_or_json(origin, msg, "error", 400)
diff --git a/tests/conftest.py b/tests/conftest.py
index 232d99d4a8..509fc9b33d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -43,7 +43,7 @@ from airflow import settings  # noqa: E402
 from airflow.models.tasklog import LogTemplate  # noqa: E402
 from tests.test_utils.db import clear_all  # noqa: E402
 
-from tests.test_utils.perf.perf_kit.sqlalchemy import (  # noqa isort:skip
+from tests.test_utils.perf.perf_kit.sqlalchemy import (  # noqa: E402  # isort: skip
     count_queries,
     trace_queries,
 )
@@ -278,6 +278,11 @@ def pytest_configure(config):
     config.addinivalue_line(
         "markers", "credential_file(name): mark tests that require credential file in CREDENTIALS_DIR"
     )
+    os.environ["_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK"] = "1"
+
+
+def pytest_unconfigure(config):
+    del os.environ["_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK"]
 
 
 def skip_if_not_marked_with_integration(selected_integrations, item):
diff --git a/tests/executors/test_executor_loader.py b/tests/executors/test_executor_loader.py
index 96d7030e17..c436c96778 100644
--- a/tests/executors/test_executor_loader.py
+++ b/tests/executors/test_executor_loader.py
@@ -16,11 +16,13 @@
 # under the License.
 from __future__ import annotations
 
+from contextlib import nullcontext
 from unittest import mock
 
 import pytest
 
 from airflow import plugins_manager
+from airflow.exceptions import AirflowConfigException
 from airflow.executors.executor_loader import ConnectorSource, ExecutorLoader
 from tests.test_utils.config import conf_vars
 
@@ -29,7 +31,11 @@ TEST_PLUGIN_NAME = "unique_plugin_name_to_avoid_collision_i_love_kitties"
 
 
 class FakeExecutor:
-    pass
+    is_single_threaded = False
+
+
+class FakeSingleThreadedExecutor:
+    is_single_threaded = True
 
 
 class FakePlugin(plugins_manager.AirflowPlugin):
@@ -103,3 +109,25 @@ class TestExecutorLoader:
             executor, import_source = ExecutorLoader.import_default_executor_cls()
             assert "FakeExecutor" == executor.__name__
             assert import_source == ConnectorSource.CUSTOM_PATH
+
+    @pytest.mark.backend("mssql", "mysql", "postgres")
+    @pytest.mark.parametrize("executor", [FakeExecutor, FakeSingleThreadedExecutor])
+    def test_validate_database_executor_compatibility_general(self, monkeypatch, executor):
+        monkeypatch.delenv("_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK")
+        ExecutorLoader.validate_database_executor_compatibility(executor)
+
+    @pytest.mark.backend("sqlite")
+    @pytest.mark.parametrize(
+        ["executor", "expectation"],
+        [
+            (FakeSingleThreadedExecutor, nullcontext()),
+            (
+                FakeExecutor,
+                pytest.raises(AirflowConfigException, match=r"^error: cannot use SQLite with the .+"),
+            ),
+        ],
+    )
+    def test_validate_database_executor_compatibility_sqlite(self, monkeypatch, executor, expectation):
+        monkeypatch.delenv("_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK")
+        with expectation:
+            ExecutorLoader.validate_database_executor_compatibility(executor)
diff --git a/tests/providers/github/sensors/test_github.py b/tests/providers/github/sensors/test_github.py
index db8999d64c..8ae5956bc3 100644
--- a/tests/providers/github/sensors/test_github.py
+++ b/tests/providers/github/sensors/test_github.py
@@ -43,7 +43,9 @@ class TestGithubSensor:
         )
 
     @patch(
-        "airflow.providers.github.hooks.github.GithubClient", autospec=True, return_value=github_client_mock
+        "airflow.providers.github.hooks.github.GithubClient",
+        autospec=True,
+        return_value=github_client_mock,
     )
     def test_github_tag_created(self, github_mock):
         class MockTag:
diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py
index 68ccff62cd..5add70b261 100644
--- a/tests/www/views/test_views_tasks.py
+++ b/tests/www/views/test_views_tasks.py
@@ -492,7 +492,7 @@ def test_code_from_db_all_example_dags(admin_client):
                 dag_id="example_bash_operator",
                 ignore_all_deps="false",
                 ignore_ti_state="true",
-                execution_date=DEFAULT_DATE,
+                dag_run_id=DEFAULT_DAGRUN,
             ),
             "",
         ),