You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2021/03/19 22:03:59 UTC

[airflow] 02/03: Don't create unittest.cfg when not running in unit test mode (#14420)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit cbd181cf65ac4888c0d0ee37703976418a9895b1
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Fri Feb 26 10:20:38 2021 +0000

    Don't create unittest.cfg when not running in unit test mode (#14420)
    
    Until now, Airflow would _always_ create a unittest.cfg on startup,
    even if unit test mode was not enabled.
    
    This PR changes it so that this file is only created when unit test
    mode is enabled (via the environment variable or in airflow.cfg).
    
    I have also refactored the mess of top-level code that was interspersed
    between functions so that it is all in one place -- at the end of the file.
    
    The bulk of the config loading code now lives in the `initialize_config`
    function.
    
    Adding lazy module attributes to airflow.configuration via the Pep562
    class caused some weird behaviour where AirflowConfigParser became
    unpicklable because of some "leaking" closed-over scope. (Pickling this
    class is used almost exclusively by PythonVirtualenvOperator.) The fix
    here is to add custom `__getstate__`/`__setstate__` methods so that we
    do not store more than we need.
    
    (cherry picked from commit b16b9ee6894711a8af7143286189c4a3cc31d1c4)
---
 airflow/configuration.py         | 293 ++++++++++++++++++++++++++-------------
 tests/core/test_configuration.py |  13 +-
 tests/test_utils/config.py       |   2 +
 3 files changed, 207 insertions(+), 101 deletions(-)

diff --git a/airflow/configuration.py b/airflow/configuration.py
index 4155d64..f1c48ed 100644
--- a/airflow/configuration.py
+++ b/airflow/configuration.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import functools
 import json
 import logging
 import multiprocessing
@@ -30,16 +31,11 @@ from collections import OrderedDict
 
 # Ignored Mypy on configparser because it thinks the configparser module has no _UNSET attribute
 from configparser import _UNSET, ConfigParser, NoOptionError, NoSectionError  # type: ignore
-from distutils.version import StrictVersion
 from json.decoder import JSONDecodeError
-from typing import Dict, List, Optional, Tuple, Union
-
-import yaml
-from cryptography.fernet import Fernet
+from typing import Dict, List, Optional, Union
 
 from airflow.exceptions import AirflowConfigException
 from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH, BaseSecretsBackend
-from airflow.utils.docs import get_docs_url
 from airflow.utils.module_loading import import_string
 
 log = logging.getLogger(__name__)
@@ -90,15 +86,9 @@ def _get_config_value_from_secret_backend(config_key):
     return secrets_client.get_config(config_key)
 
 
-def _read_default_config_file(file_name: str) -> Tuple[str, str]:
+def _default_config_file_path(file_name: str):
     templates_dir = os.path.join(os.path.dirname(__file__), 'config_templates')
-    file_path = os.path.join(templates_dir, file_name)
-    with open(file_path, encoding='utf-8') as config_file:
-        return config_file.read(), file_path
-
-
-DEFAULT_CONFIG, DEFAULT_CONFIG_FILE_PATH = _read_default_config_file('default_airflow.cfg')
-TEST_CONFIG, TEST_CONFIG_FILE_PATH = _read_default_config_file('default_test.cfg')
+    return os.path.join(templates_dir, file_name)
 
 
 def default_config_yaml() -> dict:
@@ -107,10 +97,9 @@ def default_config_yaml() -> dict:
 
     :return: Python dictionary containing configs & their info
     """
-    templates_dir = os.path.join(os.path.dirname(__file__), 'config_templates')
-    file_path = os.path.join(templates_dir, "config.yml")
+    import yaml
 
-    with open(file_path) as config_file:
+    with open(_default_config_file_path('config.yml')) as config_file:
         return yaml.safe_load(config_file)
 
 
@@ -240,6 +229,9 @@ class AirflowConfigParser(ConfigParser):  # pylint: disable=too-many-ancestors
             raise AirflowConfigException(f"error: cannot use sqlite with the {self.get('core', 'executor')}")
         if is_sqlite:
             import sqlite3
+            from distutils.version import StrictVersion
+
+            from airflow.utils.docs import get_docs_url
 
             # Some of the features in storing rendered fields require sqlite version >= 3.15.0
             min_sqlite_version = '3.15.0'
@@ -683,12 +675,15 @@ class AirflowConfigParser(ConfigParser):  # pylint: disable=too-many-ancestors
 
         Note: this is not reversible.
         """
-        # override any custom settings with defaults
-        log.info("Overriding settings with defaults from %s", DEFAULT_CONFIG_FILE_PATH)
-        self.read_string(parameterized_config(DEFAULT_CONFIG))
+        # remove all sections, falling back to defaults
+        for section in self.sections():
+            self.remove_section(section)
+
         # then read test config
-        log.info("Reading default test configuration from %s", TEST_CONFIG_FILE_PATH)
-        self.read_string(parameterized_config(TEST_CONFIG))
+
+        path = _default_config_file_path('default_test.cfg')
+        log.info("Reading default test configuration from %s", path)
+        self.read_string(_parameterized_config_from_template('default_test.cfg'))
         # then read any "custom" test settings
         log.info("Reading test configuration from %s", TEST_CONFIG_FILE)
         self.read(TEST_CONFIG_FILE)
@@ -719,6 +714,22 @@ class AirflowConfigParser(ConfigParser):  # pylint: disable=too-many-ancestors
                 stacklevel=3,
             )
 
+    def __getstate__(self):
+        return {
+            name: getattr(self, name)
+            for name in [
+                '_sections',
+                'is_validated',
+                'airflow_defaults',
+            ]
+        }
+
+    def __setstate__(self, state):
+        self.__init__()
+        config = state.pop('_sections')
+        self.read_dict(config)
+        self.__dict__.update(state)
+
 
 def get_airflow_home():
     """Get path to Airflow Home"""
@@ -732,32 +743,16 @@ def get_airflow_config(airflow_home):
     return expand_env_var(os.environ['AIRFLOW_CONFIG'])
 
 
-# Setting AIRFLOW_HOME and AIRFLOW_CONFIG from environment variables, using
-# "~/airflow" and "$AIRFLOW_HOME/airflow.cfg" respectively as defaults.
-
-AIRFLOW_HOME = get_airflow_home()
-AIRFLOW_CONFIG = get_airflow_config(AIRFLOW_HOME)
-pathlib.Path(AIRFLOW_HOME).mkdir(parents=True, exist_ok=True)
-
+def _parameterized_config_from_template(filename) -> str:
+    TEMPLATE_START = '# ----------------------- TEMPLATE BEGINS HERE -----------------------\n'
 
-# Set up dags folder for unit tests
-# this directory won't exist if users install via pip
-_TEST_DAGS_FOLDER = os.path.join(
-    os.path.dirname(os.path.dirname(os.path.realpath(__file__))), 'tests', 'dags'
-)
-if os.path.exists(_TEST_DAGS_FOLDER):
-    TEST_DAGS_FOLDER = _TEST_DAGS_FOLDER
-else:
-    TEST_DAGS_FOLDER = os.path.join(AIRFLOW_HOME, 'dags')
-
-# Set up plugins folder for unit tests
-_TEST_PLUGINS_FOLDER = os.path.join(
-    os.path.dirname(os.path.dirname(os.path.realpath(__file__))), 'tests', 'plugins'
-)
-if os.path.exists(_TEST_PLUGINS_FOLDER):
-    TEST_PLUGINS_FOLDER = _TEST_PLUGINS_FOLDER
-else:
-    TEST_PLUGINS_FOLDER = os.path.join(AIRFLOW_HOME, 'plugins')
+    path = _default_config_file_path(filename)
+    with open(path) as fh:
+        for line in fh:
+            if line != TEMPLATE_START:
+                continue
+            return parameterized_config(fh.read().strip())
+    raise RuntimeError(f"Template marker not found in {path!r}")
 
 
 def parameterized_config(template):
@@ -778,65 +773,93 @@ def get_airflow_test_config(airflow_home):
     return expand_env_var(os.environ['AIRFLOW_TEST_CONFIG'])
 
 
-TEST_CONFIG_FILE = get_airflow_test_config(AIRFLOW_HOME)
+def _generate_fernet_key():
+    from cryptography.fernet import Fernet
 
-# only generate a Fernet key if we need to create a new config file
-if not os.path.isfile(TEST_CONFIG_FILE) or not os.path.isfile(AIRFLOW_CONFIG):
-    FERNET_KEY = Fernet.generate_key().decode()
-else:
-    FERNET_KEY = ''
+    return Fernet.generate_key().decode()
 
-SECRET_KEY = b64encode(os.urandom(16)).decode('utf-8')
 
-TEMPLATE_START = '# ----------------------- TEMPLATE BEGINS HERE -----------------------'
-if not os.path.isfile(TEST_CONFIG_FILE):
-    log.info('Creating new Airflow config file for unit tests in: %s', TEST_CONFIG_FILE)
-    with open(TEST_CONFIG_FILE, 'w') as file:
-        cfg = parameterized_config(TEST_CONFIG)
-        file.write(cfg.split(TEMPLATE_START)[-1].strip())
-if not os.path.isfile(AIRFLOW_CONFIG):
-    log.info('Creating new Airflow config file in: %s', AIRFLOW_CONFIG)
-    with open(AIRFLOW_CONFIG, 'w') as file:
-        cfg = parameterized_config(DEFAULT_CONFIG)
-        cfg = cfg.split(TEMPLATE_START)[-1].strip()
-        file.write(cfg)
-
-log.info("Reading the config from %s", AIRFLOW_CONFIG)
-
-conf = AirflowConfigParser(default_config=parameterized_config(DEFAULT_CONFIG))
-
-conf.read(AIRFLOW_CONFIG)
-
-if conf.has_option('core', 'AIRFLOW_HOME'):
-    msg = (
-        'Specifying both AIRFLOW_HOME environment variable and airflow_home '
-        'in the config file is deprecated. Please use only the AIRFLOW_HOME '
-        'environment variable and remove the config file entry.'
-    )
-    if 'AIRFLOW_HOME' in os.environ:
-        warnings.warn(msg, category=DeprecationWarning)
-    elif conf.get('core', 'airflow_home') == AIRFLOW_HOME:
-        warnings.warn(
-            'Specifying airflow_home in the config file is deprecated. As you '
-            'have left it at the default value you should remove the setting '
-            'from your airflow.cfg and suffer no change in behaviour.',
-            category=DeprecationWarning,
-        )
+def initialize_config():
+    """
+    Load the Airflow config files.
+
+    Called for you automatically as part of the Airflow boot process.
+    """
+    global FERNET_KEY, AIRFLOW_HOME
+
+    default_config = _parameterized_config_from_template('default_airflow.cfg')
+
+    conf = AirflowConfigParser(default_config=default_config)
+
+    if conf.getboolean('core', 'unit_test_mode'):
+        # Load test config only
+        if not os.path.isfile(TEST_CONFIG_FILE):
+            from cryptography.fernet import Fernet
+
+            log.info('Creating new Airflow config file for unit tests in: %s', TEST_CONFIG_FILE)
+            pathlib.Path(AIRFLOW_HOME).mkdir(parents=True, exist_ok=True)
+
+            FERNET_KEY = Fernet.generate_key().decode()
+
+            with open(TEST_CONFIG_FILE, 'w') as file:
+                cfg = _parameterized_config_from_template('default_test.cfg')
+                file.write(cfg)
+
+        conf.load_test_config()
     else:
-        AIRFLOW_HOME = conf.get('core', 'airflow_home')
-        warnings.warn(msg, category=DeprecationWarning)
+        # Load normal config
+        if not os.path.isfile(AIRFLOW_CONFIG):
+            from cryptography.fernet import Fernet
 
+            log.info('Creating new Airflow config file in: %s', AIRFLOW_CONFIG)
+            pathlib.Path(AIRFLOW_HOME).mkdir(parents=True, exist_ok=True)
 
-WEBSERVER_CONFIG = AIRFLOW_HOME + '/webserver_config.py'
+            FERNET_KEY = Fernet.generate_key().decode()
 
-if not os.path.isfile(WEBSERVER_CONFIG):
-    log.info('Creating new FAB webserver config file in: %s', WEBSERVER_CONFIG)
-    DEFAULT_WEBSERVER_CONFIG, _ = _read_default_config_file('default_webserver_config.py')
-    with open(WEBSERVER_CONFIG, 'w') as file:
-        file.write(DEFAULT_WEBSERVER_CONFIG)
+            with open(AIRFLOW_CONFIG, 'w') as file:
+                file.write(default_config)
 
-if conf.getboolean('core', 'unit_test_mode'):
-    conf.load_test_config()
+        log.info("Reading the config from %s", AIRFLOW_CONFIG)
+
+        conf.read(AIRFLOW_CONFIG)
+
+        if conf.has_option('core', 'AIRFLOW_HOME'):
+            msg = (
+                'Specifying both AIRFLOW_HOME environment variable and airflow_home '
+                'in the config file is deprecated. Please use only the AIRFLOW_HOME '
+                'environment variable and remove the config file entry.'
+            )
+            if 'AIRFLOW_HOME' in os.environ:
+                warnings.warn(msg, category=DeprecationWarning)
+            elif conf.get('core', 'airflow_home') == AIRFLOW_HOME:
+                warnings.warn(
+                    'Specifying airflow_home in the config file is deprecated. As you '
+                    'have left it at the default value you should remove the setting '
+                    'from your airflow.cfg and suffer no change in behaviour.',
+                    category=DeprecationWarning,
+                )
+            else:
+                AIRFLOW_HOME = conf.get('core', 'airflow_home')
+                warnings.warn(msg, category=DeprecationWarning)
+
+        # They _might_ have set unit_test_mode in the airflow.cfg, we still
+        # want to respect that and then load the unittests.cfg
+        if conf.getboolean('core', 'unit_test_mode'):
+            conf.load_test_config()
+
+    # Make it no longer a proxy variable, just set it to an actual string
+    global WEBSERVER_CONFIG
+    WEBSERVER_CONFIG = AIRFLOW_HOME + '/webserver_config.py'
+
+    if not os.path.isfile(WEBSERVER_CONFIG):
+        import shutil
+
+        log.info('Creating new FAB webserver config file in: %s', WEBSERVER_CONFIG)
+        shutil.copy(_default_config_file_path('default_webserver_config.py'), WEBSERVER_CONFIG)
+
+    conf.validate()
+
+    return conf
 
 
 # Historical convenience functions to access config entries
@@ -1006,6 +1029,78 @@ def initialize_secrets_backends() -> List[BaseSecretsBackend]:
     return backend_list
 
 
+@functools.lru_cache(maxsize=None)
+def _DEFAULT_CONFIG():
+    path = _default_config_file_path('default_airflow.cfg')
+    with open(path) as fh:
+        return fh.read()
+
+
+@functools.lru_cache(maxsize=None)
+def _TEST_CONFIG():
+    path = _default_config_file_path('default_test.cfg')
+    with open(path) as fh:
+        return fh.read()
+
+
+_deprecated = {
+    'DEFAULT_CONFIG': _DEFAULT_CONFIG,
+    'TEST_CONFIG': _TEST_CONFIG,
+    'TEST_CONFIG_FILE_PATH': functools.partial(_default_config_file_path, ('default_test.cfg')),
+    'DEFAULT_CONFIG_FILE_PATH': functools.partial(_default_config_file_path, ('default_airflow.cfg')),
+}
+
+
+def __getattr__(name):
+    if name in _deprecated:
+        warnings.warn(
+            f"{__name__}.{name} is deprecated and will be removed in future",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return _deprecated[name]()
+    raise AttributeError(f"module {__name__} has no attribute {name}")
+
+
+# Setting AIRFLOW_HOME and AIRFLOW_CONFIG from environment variables, using
+# "~/airflow" and "$AIRFLOW_HOME/airflow.cfg" respectively as defaults.
+
+AIRFLOW_HOME = get_airflow_home()
+AIRFLOW_CONFIG = get_airflow_config(AIRFLOW_HOME)
+
+
+# Set up dags folder for unit tests
+# this directory won't exist if users install via pip
+_TEST_DAGS_FOLDER = os.path.join(
+    os.path.dirname(os.path.dirname(os.path.realpath(__file__))), 'tests', 'dags'
+)
+if os.path.exists(_TEST_DAGS_FOLDER):
+    TEST_DAGS_FOLDER = _TEST_DAGS_FOLDER
+else:
+    TEST_DAGS_FOLDER = os.path.join(AIRFLOW_HOME, 'dags')
+
+# Set up plugins folder for unit tests
+_TEST_PLUGINS_FOLDER = os.path.join(
+    os.path.dirname(os.path.dirname(os.path.realpath(__file__))), 'tests', 'plugins'
+)
+if os.path.exists(_TEST_PLUGINS_FOLDER):
+    TEST_PLUGINS_FOLDER = _TEST_PLUGINS_FOLDER
+else:
+    TEST_PLUGINS_FOLDER = os.path.join(AIRFLOW_HOME, 'plugins')
+
+
+TEST_CONFIG_FILE = get_airflow_test_config(AIRFLOW_HOME)
+
+SECRET_KEY = b64encode(os.urandom(16)).decode('utf-8')
+FERNET_KEY = ''  # Set only if needed when generating a new file
+WEBSERVER_CONFIG = ''  # Set by initialize_config
+
+conf = initialize_config()
 secrets_backend_list = initialize_secrets_backends()
 
-conf.validate()
+
+PY37 = sys.version_info >= (3, 7)
+if not PY37:
+    from pep562 import Pep562
+
+    Pep562(__name__)
diff --git a/tests/core/test_configuration.py b/tests/core/test_configuration.py
index 4e32c5e..a64407e 100644
--- a/tests/core/test_configuration.py
+++ b/tests/core/test_configuration.py
@@ -19,6 +19,7 @@ import io
 import os
 import re
 import tempfile
+import textwrap
 import unittest
 import warnings
 from collections import OrderedDict
@@ -28,7 +29,6 @@ import pytest
 
 from airflow import configuration
 from airflow.configuration import (
-    DEFAULT_CONFIG,
     AirflowConfigException,
     AirflowConfigParser,
     conf,
@@ -561,8 +561,17 @@ notacommand = OK
             assert test_cmdenv_conf.get('testcmdenv', 'notacommand') == 'OK'
 
     def test_parameterized_config_gen(self):
+        config = textwrap.dedent(
+            """
+            [core]
+            dags_folder = {AIRFLOW_HOME}/dags
+            sql_alchemy_conn = sqlite:///{AIRFLOW_HOME}/airflow.db
+            parallelism = 32
+            fernet_key = {FERNET_KEY}
+        """
+        )
 
-        cfg = parameterized_config(DEFAULT_CONFIG)
+        cfg = parameterized_config(config)
 
         # making sure some basic building blocks are present:
         assert "[core]" in cfg
diff --git a/tests/test_utils/config.py b/tests/test_utils/config.py
index c55a2b5..7bffb47 100644
--- a/tests/test_utils/config.py
+++ b/tests/test_utils/config.py
@@ -38,6 +38,8 @@ def conf_vars(overrides):
         else:
             original[(section, key)] = None
         if value is not None:
+            if not conf.has_section(section):
+                conf.add_section(section)
             conf.set(section, key, value)
         else:
             conf.remove_option(section, key)