You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2021/03/19 22:03:57 UTC

[airflow] branch v2-0-test updated (9854c43 -> a857c90)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a change to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git.


    from 9854c43  Webserver: Sanitize string passed to origin param (#14738)
     new 14f978e  Fix error when running tasks with Sentry integration enabled. (#13929)
     new cbd181c  Don't create unittest.cfg when not running in unit test mode (#14420)
     new a857c90  Use libyaml C library when available. (#14577)

The 3 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 airflow/cli/commands/connection_command.py         |   2 +-
 airflow/cli/commands/kubernetes_command.py         |   2 +-
 airflow/cli/simple_table.py                        |   2 +-
 airflow/configuration.py                           | 293 ++++++++++++++-------
 airflow/kubernetes/pod_generator.py                |   2 +-
 airflow/kubernetes/refresh_config.py               |   3 +-
 .../providers/cncf/kubernetes/hooks/kubernetes.py  |   6 +-
 .../cncf/kubernetes/operators/kubernetes_pod.py    |   6 +-
 .../google/cloud/operators/cloud_build.py          |   5 +-
 airflow/providers_manager.py                       |   2 +-
 airflow/secrets/local_filesystem.py                |   3 +-
 airflow/sentry.py                                  |  13 +-
 airflow/utils/session.py                           |  21 +-
 airflow/utils/yaml.py                              |  76 ++++++
 airflow/www/views.py                               |   2 +-
 .../copy_provider_package_sources.py               |   1 +
 dev/provider_packages/prepare_provider_packages.py |   7 +-
 docs/conf.py                                       |   7 +-
 docs/exts/docs_build/lint_checks.py                |   8 +-
 docs/exts/provider_yaml_utils.py                   |   8 +-
 .../pre_commit_check_pre_commit_hook_names.py      |   7 +-
 .../pre_commit_check_provider_yaml_files.py        |   7 +-
 tests/core/test_configuration.py                   |  13 +-
 tests/test_utils/config.py                         |   2 +
 .../utils/test_session.py                          |  50 ++--
 25 files changed, 396 insertions(+), 152 deletions(-)
 create mode 100644 airflow/utils/yaml.py
 copy airflow/contrib/operators/bigquery_table_delete_operator.py => tests/utils/test_session.py (50%)

[airflow] 02/03: Don't create unittest.cfg when not running in unit test mode (#14420)

Posted by as...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit cbd181cf65ac4888c0d0ee37703976418a9895b1
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Fri Feb 26 10:20:38 2021 +0000

    Don't create unittest.cfg when not running in unit test mode (#14420)
    
    Right now on airflow startup it would _always_ create a unittest.cfg,
    even if unit test mode was not enabled.
    
    This PR changes it so that this file is only created when unit tests
    mode is enabled (via the environment variable or in airflow.cfg).
    
    I have also refactored the mess of top-level code that was interspersed
    between functions so that it all lives in one place -- at the end of the file.
    
    The bulk of the config loading code now lives in the `initialize_config`
    function.
    
    Adding lazy module attributes to airflow.configuration via the Pep562
    class caused some weird behaviour where AirflowConfigParser became
    unpickleable because of some "leaking" closed-over scope. (Pickling this
    class is used almost exclusively by PythonVirtualEnvOperator). The fix
    here is to add custom get/set state methods to not store more than we
    need.
    
    (cherry picked from commit b16b9ee6894711a8af7143286189c4a3cc31d1c4)
---
 airflow/configuration.py         | 293 ++++++++++++++++++++++++++-------------
 tests/core/test_configuration.py |  13 +-
 tests/test_utils/config.py       |   2 +
 3 files changed, 207 insertions(+), 101 deletions(-)

diff --git a/airflow/configuration.py b/airflow/configuration.py
index 4155d64..f1c48ed 100644
--- a/airflow/configuration.py
+++ b/airflow/configuration.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import functools
 import json
 import logging
 import multiprocessing
@@ -30,16 +31,11 @@ from collections import OrderedDict
 
 # Ignored Mypy on configparser because it thinks the configparser module has no _UNSET attribute
 from configparser import _UNSET, ConfigParser, NoOptionError, NoSectionError  # type: ignore
-from distutils.version import StrictVersion
 from json.decoder import JSONDecodeError
-from typing import Dict, List, Optional, Tuple, Union
-
-import yaml
-from cryptography.fernet import Fernet
+from typing import Dict, List, Optional, Union
 
 from airflow.exceptions import AirflowConfigException
 from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH, BaseSecretsBackend
-from airflow.utils.docs import get_docs_url
 from airflow.utils.module_loading import import_string
 
 log = logging.getLogger(__name__)
@@ -90,15 +86,9 @@ def _get_config_value_from_secret_backend(config_key):
     return secrets_client.get_config(config_key)
 
 
-def _read_default_config_file(file_name: str) -> Tuple[str, str]:
+def _default_config_file_path(file_name: str):
     templates_dir = os.path.join(os.path.dirname(__file__), 'config_templates')
-    file_path = os.path.join(templates_dir, file_name)
-    with open(file_path, encoding='utf-8') as config_file:
-        return config_file.read(), file_path
-
-
-DEFAULT_CONFIG, DEFAULT_CONFIG_FILE_PATH = _read_default_config_file('default_airflow.cfg')
-TEST_CONFIG, TEST_CONFIG_FILE_PATH = _read_default_config_file('default_test.cfg')
+    return os.path.join(templates_dir, file_name)
 
 
 def default_config_yaml() -> dict:
@@ -107,10 +97,9 @@ def default_config_yaml() -> dict:
 
     :return: Python dictionary containing configs & their info
     """
-    templates_dir = os.path.join(os.path.dirname(__file__), 'config_templates')
-    file_path = os.path.join(templates_dir, "config.yml")
+    import yaml
 
-    with open(file_path) as config_file:
+    with open(_default_config_file_path('config.yml')) as config_file:
         return yaml.safe_load(config_file)
 
 
@@ -240,6 +229,9 @@ class AirflowConfigParser(ConfigParser):  # pylint: disable=too-many-ancestors
             raise AirflowConfigException(f"error: cannot use sqlite with the {self.get('core', 'executor')}")
         if is_sqlite:
             import sqlite3
+            from distutils.version import StrictVersion
+
+            from airflow.utils.docs import get_docs_url
 
             # Some of the features in storing rendered fields require sqlite version >= 3.15.0
             min_sqlite_version = '3.15.0'
@@ -683,12 +675,15 @@ class AirflowConfigParser(ConfigParser):  # pylint: disable=too-many-ancestors
 
         Note: this is not reversible.
         """
-        # override any custom settings with defaults
-        log.info("Overriding settings with defaults from %s", DEFAULT_CONFIG_FILE_PATH)
-        self.read_string(parameterized_config(DEFAULT_CONFIG))
+        # remove all sections, falling back to defaults
+        for section in self.sections():
+            self.remove_section(section)
+
         # then read test config
-        log.info("Reading default test configuration from %s", TEST_CONFIG_FILE_PATH)
-        self.read_string(parameterized_config(TEST_CONFIG))
+
+        path = _default_config_file_path('default_test.cfg')
+        log.info("Reading default test configuration from %s", path)
+        self.read_string(_parameterized_config_from_template('default_test.cfg'))
         # then read any "custom" test settings
         log.info("Reading test configuration from %s", TEST_CONFIG_FILE)
         self.read(TEST_CONFIG_FILE)
@@ -719,6 +714,22 @@ class AirflowConfigParser(ConfigParser):  # pylint: disable=too-many-ancestors
                 stacklevel=3,
             )
 
+    def __getstate__(self):
+        return {
+            name: getattr(self, name)
+            for name in [
+                '_sections',
+                'is_validated',
+                'airflow_defaults',
+            ]
+        }
+
+    def __setstate__(self, state):
+        self.__init__()
+        config = state.pop('_sections')
+        self.read_dict(config)
+        self.__dict__.update(state)
+
 
 def get_airflow_home():
     """Get path to Airflow Home"""
@@ -732,32 +743,16 @@ def get_airflow_config(airflow_home):
     return expand_env_var(os.environ['AIRFLOW_CONFIG'])
 
 
-# Setting AIRFLOW_HOME and AIRFLOW_CONFIG from environment variables, using
-# "~/airflow" and "$AIRFLOW_HOME/airflow.cfg" respectively as defaults.
-
-AIRFLOW_HOME = get_airflow_home()
-AIRFLOW_CONFIG = get_airflow_config(AIRFLOW_HOME)
-pathlib.Path(AIRFLOW_HOME).mkdir(parents=True, exist_ok=True)
-
+def _parameterized_config_from_template(filename) -> str:
+    TEMPLATE_START = '# ----------------------- TEMPLATE BEGINS HERE -----------------------\n'
 
-# Set up dags folder for unit tests
-# this directory won't exist if users install via pip
-_TEST_DAGS_FOLDER = os.path.join(
-    os.path.dirname(os.path.dirname(os.path.realpath(__file__))), 'tests', 'dags'
-)
-if os.path.exists(_TEST_DAGS_FOLDER):
-    TEST_DAGS_FOLDER = _TEST_DAGS_FOLDER
-else:
-    TEST_DAGS_FOLDER = os.path.join(AIRFLOW_HOME, 'dags')
-
-# Set up plugins folder for unit tests
-_TEST_PLUGINS_FOLDER = os.path.join(
-    os.path.dirname(os.path.dirname(os.path.realpath(__file__))), 'tests', 'plugins'
-)
-if os.path.exists(_TEST_PLUGINS_FOLDER):
-    TEST_PLUGINS_FOLDER = _TEST_PLUGINS_FOLDER
-else:
-    TEST_PLUGINS_FOLDER = os.path.join(AIRFLOW_HOME, 'plugins')
+    path = _default_config_file_path(filename)
+    with open(path) as fh:
+        for line in fh:
+            if line != TEMPLATE_START:
+                continue
+            return parameterized_config(fh.read().strip())
+    raise RuntimeError(f"Template marker not found in {path!r}")
 
 
 def parameterized_config(template):
@@ -778,65 +773,93 @@ def get_airflow_test_config(airflow_home):
     return expand_env_var(os.environ['AIRFLOW_TEST_CONFIG'])
 
 
-TEST_CONFIG_FILE = get_airflow_test_config(AIRFLOW_HOME)
+def _generate_fernet_key():
+    from cryptography.fernet import Fernet
 
-# only generate a Fernet key if we need to create a new config file
-if not os.path.isfile(TEST_CONFIG_FILE) or not os.path.isfile(AIRFLOW_CONFIG):
-    FERNET_KEY = Fernet.generate_key().decode()
-else:
-    FERNET_KEY = ''
+    return Fernet.generate_key().decode()
 
-SECRET_KEY = b64encode(os.urandom(16)).decode('utf-8')
 
-TEMPLATE_START = '# ----------------------- TEMPLATE BEGINS HERE -----------------------'
-if not os.path.isfile(TEST_CONFIG_FILE):
-    log.info('Creating new Airflow config file for unit tests in: %s', TEST_CONFIG_FILE)
-    with open(TEST_CONFIG_FILE, 'w') as file:
-        cfg = parameterized_config(TEST_CONFIG)
-        file.write(cfg.split(TEMPLATE_START)[-1].strip())
-if not os.path.isfile(AIRFLOW_CONFIG):
-    log.info('Creating new Airflow config file in: %s', AIRFLOW_CONFIG)
-    with open(AIRFLOW_CONFIG, 'w') as file:
-        cfg = parameterized_config(DEFAULT_CONFIG)
-        cfg = cfg.split(TEMPLATE_START)[-1].strip()
-        file.write(cfg)
-
-log.info("Reading the config from %s", AIRFLOW_CONFIG)
-
-conf = AirflowConfigParser(default_config=parameterized_config(DEFAULT_CONFIG))
-
-conf.read(AIRFLOW_CONFIG)
-
-if conf.has_option('core', 'AIRFLOW_HOME'):
-    msg = (
-        'Specifying both AIRFLOW_HOME environment variable and airflow_home '
-        'in the config file is deprecated. Please use only the AIRFLOW_HOME '
-        'environment variable and remove the config file entry.'
-    )
-    if 'AIRFLOW_HOME' in os.environ:
-        warnings.warn(msg, category=DeprecationWarning)
-    elif conf.get('core', 'airflow_home') == AIRFLOW_HOME:
-        warnings.warn(
-            'Specifying airflow_home in the config file is deprecated. As you '
-            'have left it at the default value you should remove the setting '
-            'from your airflow.cfg and suffer no change in behaviour.',
-            category=DeprecationWarning,
-        )
+def initialize_config():
+    """
+    Load the Airflow config files.
+
+    Called for you automatically as part of the Airflow boot process.
+    """
+    global FERNET_KEY, AIRFLOW_HOME
+
+    default_config = _parameterized_config_from_template('default_airflow.cfg')
+
+    conf = AirflowConfigParser(default_config=default_config)
+
+    if conf.getboolean('core', 'unit_test_mode'):
+        # Load test config only
+        if not os.path.isfile(TEST_CONFIG_FILE):
+            from cryptography.fernet import Fernet
+
+            log.info('Creating new Airflow config file for unit tests in: %s', TEST_CONFIG_FILE)
+            pathlib.Path(AIRFLOW_HOME).mkdir(parents=True, exist_ok=True)
+
+            FERNET_KEY = Fernet.generate_key().decode()
+
+            with open(TEST_CONFIG_FILE, 'w') as file:
+                cfg = _parameterized_config_from_template('default_test.cfg')
+                file.write(cfg)
+
+        conf.load_test_config()
     else:
-        AIRFLOW_HOME = conf.get('core', 'airflow_home')
-        warnings.warn(msg, category=DeprecationWarning)
+        # Load normal config
+        if not os.path.isfile(AIRFLOW_CONFIG):
+            from cryptography.fernet import Fernet
 
+            log.info('Creating new Airflow config file in: %s', AIRFLOW_CONFIG)
+            pathlib.Path(AIRFLOW_HOME).mkdir(parents=True, exist_ok=True)
 
-WEBSERVER_CONFIG = AIRFLOW_HOME + '/webserver_config.py'
+            FERNET_KEY = Fernet.generate_key().decode()
 
-if not os.path.isfile(WEBSERVER_CONFIG):
-    log.info('Creating new FAB webserver config file in: %s', WEBSERVER_CONFIG)
-    DEFAULT_WEBSERVER_CONFIG, _ = _read_default_config_file('default_webserver_config.py')
-    with open(WEBSERVER_CONFIG, 'w') as file:
-        file.write(DEFAULT_WEBSERVER_CONFIG)
+            with open(AIRFLOW_CONFIG, 'w') as file:
+                file.write(default_config)
 
-if conf.getboolean('core', 'unit_test_mode'):
-    conf.load_test_config()
+        log.info("Reading the config from %s", AIRFLOW_CONFIG)
+
+        conf.read(AIRFLOW_CONFIG)
+
+        if conf.has_option('core', 'AIRFLOW_HOME'):
+            msg = (
+                'Specifying both AIRFLOW_HOME environment variable and airflow_home '
+                'in the config file is deprecated. Please use only the AIRFLOW_HOME '
+                'environment variable and remove the config file entry.'
+            )
+            if 'AIRFLOW_HOME' in os.environ:
+                warnings.warn(msg, category=DeprecationWarning)
+            elif conf.get('core', 'airflow_home') == AIRFLOW_HOME:
+                warnings.warn(
+                    'Specifying airflow_home in the config file is deprecated. As you '
+                    'have left it at the default value you should remove the setting '
+                    'from your airflow.cfg and suffer no change in behaviour.',
+                    category=DeprecationWarning,
+                )
+            else:
+                AIRFLOW_HOME = conf.get('core', 'airflow_home')
+                warnings.warn(msg, category=DeprecationWarning)
+
+        # They _might_ have set unit_test_mode in the airflow.cfg, we still
+        # want to respect that and then load the unittests.cfg
+        if conf.getboolean('core', 'unit_test_mode'):
+            conf.load_test_config()
+
+    # Make it no longer a proxy variable, just set it to an actual string
+    global WEBSERVER_CONFIG
+    WEBSERVER_CONFIG = AIRFLOW_HOME + '/webserver_config.py'
+
+    if not os.path.isfile(WEBSERVER_CONFIG):
+        import shutil
+
+        log.info('Creating new FAB webserver config file in: %s', WEBSERVER_CONFIG)
+        shutil.copy(_default_config_file_path('default_webserver_config.py'), WEBSERVER_CONFIG)
+
+    conf.validate()
+
+    return conf
 
 
 # Historical convenience functions to access config entries
@@ -1006,6 +1029,78 @@ def initialize_secrets_backends() -> List[BaseSecretsBackend]:
     return backend_list
 
 
+@functools.lru_cache(maxsize=None)
+def _DEFAULT_CONFIG():
+    path = _default_config_file_path('default_airflow.cfg')
+    with open(path) as fh:
+        return fh.read()
+
+
+@functools.lru_cache(maxsize=None)
+def _TEST_CONFIG():
+    path = _default_config_file_path('default_test.cfg')
+    with open(path) as fh:
+        return fh.read()
+
+
+_deprecated = {
+    'DEFAULT_CONFIG': _DEFAULT_CONFIG,
+    'TEST_CONFIG': _TEST_CONFIG,
+    'TEST_CONFIG_FILE_PATH': functools.partial(_default_config_file_path, ('default_test.cfg')),
+    'DEFAULT_CONFIG_FILE_PATH': functools.partial(_default_config_file_path, ('default_airflow.cfg')),
+}
+
+
+def __getattr__(name):
+    if name in _deprecated:
+        warnings.warn(
+            f"{__name__}.{name} is deprecated and will be removed in future",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return _deprecated[name]()
+    raise AttributeError(f"module {__name__} has no attribute {name}")
+
+
+# Setting AIRFLOW_HOME and AIRFLOW_CONFIG from environment variables, using
+# "~/airflow" and "$AIRFLOW_HOME/airflow.cfg" respectively as defaults.
+
+AIRFLOW_HOME = get_airflow_home()
+AIRFLOW_CONFIG = get_airflow_config(AIRFLOW_HOME)
+
+
+# Set up dags folder for unit tests
+# this directory won't exist if users install via pip
+_TEST_DAGS_FOLDER = os.path.join(
+    os.path.dirname(os.path.dirname(os.path.realpath(__file__))), 'tests', 'dags'
+)
+if os.path.exists(_TEST_DAGS_FOLDER):
+    TEST_DAGS_FOLDER = _TEST_DAGS_FOLDER
+else:
+    TEST_DAGS_FOLDER = os.path.join(AIRFLOW_HOME, 'dags')
+
+# Set up plugins folder for unit tests
+_TEST_PLUGINS_FOLDER = os.path.join(
+    os.path.dirname(os.path.dirname(os.path.realpath(__file__))), 'tests', 'plugins'
+)
+if os.path.exists(_TEST_PLUGINS_FOLDER):
+    TEST_PLUGINS_FOLDER = _TEST_PLUGINS_FOLDER
+else:
+    TEST_PLUGINS_FOLDER = os.path.join(AIRFLOW_HOME, 'plugins')
+
+
+TEST_CONFIG_FILE = get_airflow_test_config(AIRFLOW_HOME)
+
+SECRET_KEY = b64encode(os.urandom(16)).decode('utf-8')
+FERNET_KEY = ''  # Set only if needed when generating a new file
+WEBSERVER_CONFIG = ''  # Set by initialize_config
+
+conf = initialize_config()
 secrets_backend_list = initialize_secrets_backends()
 
-conf.validate()
+
+PY37 = sys.version_info >= (3, 7)
+if not PY37:
+    from pep562 import Pep562
+
+    Pep562(__name__)
diff --git a/tests/core/test_configuration.py b/tests/core/test_configuration.py
index 4e32c5e..a64407e 100644
--- a/tests/core/test_configuration.py
+++ b/tests/core/test_configuration.py
@@ -19,6 +19,7 @@ import io
 import os
 import re
 import tempfile
+import textwrap
 import unittest
 import warnings
 from collections import OrderedDict
@@ -28,7 +29,6 @@ import pytest
 
 from airflow import configuration
 from airflow.configuration import (
-    DEFAULT_CONFIG,
     AirflowConfigException,
     AirflowConfigParser,
     conf,
@@ -561,8 +561,17 @@ notacommand = OK
             assert test_cmdenv_conf.get('testcmdenv', 'notacommand') == 'OK'
 
     def test_parameterized_config_gen(self):
+        config = textwrap.dedent(
+            """
+            [core]
+            dags_folder = {AIRFLOW_HOME}/dags
+            sql_alchemy_conn = sqlite:///{AIRFLOW_HOME}/airflow.db
+            parallelism = 32
+            fernet_key = {FERNET_KEY}
+        """
+        )
 
-        cfg = parameterized_config(DEFAULT_CONFIG)
+        cfg = parameterized_config(config)
 
         # making sure some basic building blocks are present:
         assert "[core]" in cfg
diff --git a/tests/test_utils/config.py b/tests/test_utils/config.py
index c55a2b5..7bffb47 100644
--- a/tests/test_utils/config.py
+++ b/tests/test_utils/config.py
@@ -38,6 +38,8 @@ def conf_vars(overrides):
         else:
             original[(section, key)] = None
         if value is not None:
+            if not conf.has_section(section):
+                conf.add_section(section)
             conf.set(section, key, value)
         else:
             conf.remove_option(section, key)

[airflow] 01/03: Fix error when running tasks with Sentry integration enabled. (#13929)

Posted by as...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 14f978e9f27c8e223dd2f8e0d121ef699d8da663
Author: Jun <Ju...@users.noreply.github.com>
AuthorDate: Sat Mar 20 05:40:22 2021 +0800

    Fix error when running tasks with Sentry integration enabled. (#13929)
    
    Co-authored-by: Ash Berlin-Taylor <as...@apache.org>
    (cherry picked from commit 0e8698d3edb3712eba0514a39d1d30fbfeeaec09)
---
 airflow/sentry.py           | 13 +++++++++---
 airflow/utils/session.py    | 21 +++++++++++-------
 tests/utils/test_session.py | 52 +++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 75 insertions(+), 11 deletions(-)

diff --git a/airflow/sentry.py b/airflow/sentry.py
index 8dc9091..62eac9a 100644
--- a/airflow/sentry.py
+++ b/airflow/sentry.py
@@ -21,7 +21,7 @@ import logging
 from functools import wraps
 
 from airflow.configuration import conf
-from airflow.utils.session import provide_session
+from airflow.utils.session import find_session_idx, provide_session
 from airflow.utils.state import State
 
 log = logging.getLogger(__name__)
@@ -149,14 +149,21 @@ if conf.getboolean("sentry", 'sentry_on', fallback=False):
 
         def enrich_errors(self, func):
             """Wrap TaskInstance._run_raw_task to support task specific tags and breadcrumbs."""
+            session_args_idx = find_session_idx(func)
 
             @wraps(func)
-            def wrapper(task_instance, *args, session=None, **kwargs):
+            def wrapper(task_instance, *args, **kwargs):
                 # Wrapping the _run_raw_task function with push_scope to contain
                 # tags and breadcrumbs to a specific Task Instance
+
+                try:
+                    session = kwargs.get('session', args[session_args_idx])
+                except IndexError:
+                    session = None
+
                 with sentry_sdk.push_scope():
                     try:
-                        return func(task_instance, *args, session=session, **kwargs)
+                        return func(task_instance, *args, **kwargs)
                     except Exception as e:
                         self.add_tagging(task_instance)
                         self.add_breadcrumbs(task_instance, session=session)
diff --git a/airflow/utils/session.py b/airflow/utils/session.py
index 4001a0f..f8b9bcd 100644
--- a/airflow/utils/session.py
+++ b/airflow/utils/session.py
@@ -40,6 +40,18 @@ def create_session():
 RT = TypeVar("RT")  # pylint: disable=invalid-name
 
 
+def find_session_idx(func: Callable[..., RT]) -> int:
+    """Find session index in function call parameter."""
+    func_params = signature(func).parameters
+    try:
+        # func_params is an ordered dict -- this is the "recommended" way of getting the position
+        session_args_idx = tuple(func_params).index("session")
+    except ValueError:
+        raise ValueError(f"Function {func.__qualname__} has no `session` argument") from None
+
+    return session_args_idx
+
+
 def provide_session(func: Callable[..., RT]) -> Callable[..., RT]:
     """
     Function decorator that provides a session if it isn't provided.
@@ -47,14 +59,7 @@ def provide_session(func: Callable[..., RT]) -> Callable[..., RT]:
     database transaction, you pass it to the function, if not this wrapper
     will create one and close it for you.
     """
-    func_params = signature(func).parameters
-    try:
-        # func_params is an ordered dict -- this is the "recommended" way of getting the position
-        session_args_idx = tuple(func_params).index("session")
-    except ValueError:
-        raise ValueError(f"Function {func.__qualname__} has no `session` argument") from None
-    # We don't need this anymore -- ensure we don't keep a reference to it by mistake
-    del func_params
+    session_args_idx = find_session_idx(func)
 
     @wraps(func)
     def wrapper(*args, **kwargs) -> RT:
diff --git a/tests/utils/test_session.py b/tests/utils/test_session.py
new file mode 100644
index 0000000..08f317f
--- /dev/null
+++ b/tests/utils/test_session.py
@@ -0,0 +1,52 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import pytest
+
+from airflow.utils.session import provide_session
+
+
+class TestSession:
+    def dummy_session(self, session=None):
+        return session
+
+    def test_raised_provide_session(self):
+        with pytest.raises(ValueError, match="Function .*dummy has no `session` argument"):
+
+            @provide_session
+            def dummy():
+                pass
+
+    def test_provide_session_without_args_and_kwargs(self):
+        assert self.dummy_session() is None
+
+        wrapper = provide_session(self.dummy_session)
+
+        assert wrapper() is not None
+
+    def test_provide_session_with_args(self):
+        wrapper = provide_session(self.dummy_session)
+
+        session = object()
+        assert wrapper(session) is session
+
+    def test_provide_session_with_kwargs(self):
+        wrapper = provide_session(self.dummy_session)
+
+        session = object()
+        assert wrapper(session=session) is session

[airflow] 03/03: Use libyaml C library when available. (#14577)

Posted by as...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit a857c905f8a8d944c59c55ef530e6e301f88ef99
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Fri Mar 5 10:16:58 2021 +0000

    Use libyaml C library when available. (#14577)
    
    This makes loading local providers 1/3 quicker -- down to 2s from 3s
    on my local SSD.
    
    The `airflow.utils.yaml` module can be used in place of the normal yaml
    module, with the bonus that `safe_load` will use libyaml where available
    instead of always using the pure python version.
    
    This shaves 3 minutes off the "WWW" tests - down to 8 minutes from
    11 minutes.
    
    I have not used this module in tests/docs code etc., as I don't want to
    force importing `airflow` (and everything it currently brings in) into
    those contexts.
    
    (cherry picked from commit 7daebefd15355b3f1331c6c58f66f3f88d38a10a)
---
 airflow/cli/commands/connection_command.py         |  2 +-
 airflow/cli/commands/kubernetes_command.py         |  2 +-
 airflow/cli/simple_table.py                        |  2 +-
 airflow/configuration.py                           |  2 +-
 airflow/kubernetes/pod_generator.py                |  2 +-
 airflow/kubernetes/refresh_config.py               |  3 +-
 .../providers/cncf/kubernetes/hooks/kubernetes.py  |  6 +-
 .../cncf/kubernetes/operators/kubernetes_pod.py    |  6 +-
 .../google/cloud/operators/cloud_build.py          |  5 +-
 airflow/providers_manager.py                       |  2 +-
 airflow/secrets/local_filesystem.py                |  3 +-
 airflow/utils/yaml.py                              | 76 ++++++++++++++++++++++
 airflow/www/views.py                               |  2 +-
 .../copy_provider_package_sources.py               |  1 +
 dev/provider_packages/prepare_provider_packages.py |  7 +-
 docs/conf.py                                       |  7 +-
 docs/exts/docs_build/lint_checks.py                |  8 ++-
 docs/exts/provider_yaml_utils.py                   |  8 ++-
 .../pre_commit_check_pre_commit_hook_names.py      |  7 +-
 .../pre_commit_check_provider_yaml_files.py        |  7 +-
 20 files changed, 138 insertions(+), 20 deletions(-)

diff --git a/airflow/cli/commands/connection_command.py b/airflow/cli/commands/connection_command.py
index 202a252..435395b 100644
--- a/airflow/cli/commands/connection_command.py
+++ b/airflow/cli/commands/connection_command.py
@@ -22,9 +22,9 @@ import sys
 from typing import Any, Dict, List
 from urllib.parse import urlparse, urlunparse
 
-import yaml
 from sqlalchemy.orm import exc
 
+import airflow.utils.yaml as yaml
 from airflow.cli.simple_table import AirflowConsole
 from airflow.exceptions import AirflowNotFoundException
 from airflow.hooks.base import BaseHook
diff --git a/airflow/cli/commands/kubernetes_command.py b/airflow/cli/commands/kubernetes_command.py
index 23864ae..f98c45e 100644
--- a/airflow/cli/commands/kubernetes_command.py
+++ b/airflow/cli/commands/kubernetes_command.py
@@ -18,11 +18,11 @@
 import os
 import sys
 
-import yaml
 from kubernetes import client
 from kubernetes.client.api_client import ApiClient
 from kubernetes.client.rest import ApiException
 
+import airflow.utils.yaml as yaml
 from airflow.executors.kubernetes_executor import KubeConfig, create_pod_id
 from airflow.kubernetes import pod_generator
 from airflow.kubernetes.kube_client import get_kube_client
diff --git a/airflow/cli/simple_table.py b/airflow/cli/simple_table.py
index 3851272..2aa4707 100644
--- a/airflow/cli/simple_table.py
+++ b/airflow/cli/simple_table.py
@@ -18,13 +18,13 @@ import inspect
 import json
 from typing import Any, Callable, Dict, List, Optional, Union
 
-import yaml
 from rich.box import ASCII_DOUBLE_HEAD
 from rich.console import Console
 from rich.syntax import Syntax
 from rich.table import Table
 from tabulate import tabulate
 
+import airflow.utils.yaml as yaml
 from airflow.plugins_manager import PluginsDirectorySource
 
 
diff --git a/airflow/configuration.py b/airflow/configuration.py
index f1c48ed..dc01d4a 100644
--- a/airflow/configuration.py
+++ b/airflow/configuration.py
@@ -97,7 +97,7 @@ def default_config_yaml() -> dict:
 
     :return: Python dictionary containing configs & their info
     """
-    import yaml
+    import airflow.utils.yaml as yaml
 
     with open(_default_config_file_path('config.yml')) as config_file:
         return yaml.safe_load(config_file)
diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py
index 0782f1a..da2001b 100644
--- a/airflow/kubernetes/pod_generator.py
+++ b/airflow/kubernetes/pod_generator.py
@@ -30,11 +30,11 @@ import warnings
 from functools import reduce
 from typing import List, Optional, Union
 
-import yaml
 from dateutil import parser
 from kubernetes.client import models as k8s
 from kubernetes.client.api_client import ApiClient
 
+import airflow.utils.yaml as yaml
 from airflow.exceptions import AirflowConfigException
 from airflow.kubernetes.pod_generator_deprecated import PodGenerator as PodGeneratorDeprecated
 from airflow.version import version as airflow_version
diff --git a/airflow/kubernetes/refresh_config.py b/airflow/kubernetes/refresh_config.py
index 0004cac..37b3bec 100644
--- a/airflow/kubernetes/refresh_config.py
+++ b/airflow/kubernetes/refresh_config.py
@@ -27,11 +27,12 @@ import time
 from typing import Optional, cast
 
 import pendulum
-import yaml
 from kubernetes.client import Configuration
 from kubernetes.config.exec_provider import ExecProvider
 from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION, KubeConfigLoader
 
+import airflow.utils.yaml as yaml
+
 
 def _parse_timestamp(ts_str: str) -> int:
     parsed_dt = cast(pendulum.DateTime, pendulum.parse(ts_str))
diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
index cf27713..10c7510 100644
--- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -17,10 +17,14 @@
 import tempfile
 from typing import Any, Dict, Generator, Optional, Tuple, Union
 
-import yaml
 from cached_property import cached_property
 from kubernetes import client, config, watch
 
+try:
+    import airflow.utils.yaml as yaml
+except ImportError:
+    import yaml
+
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 
diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
index 3f42ab1..e6d1bae 100644
--- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
@@ -19,9 +19,13 @@ import re
 import warnings
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
-import yaml
 from kubernetes.client import CoreV1Api, models as k8s
 
+try:
+    import airflow.utils.yaml as yaml
+except ImportError:
+    import yaml
+
 from airflow.exceptions import AirflowException
 from airflow.kubernetes import kube_client, pod_generator, pod_launcher
 from airflow.kubernetes.pod_generator import PodGenerator
diff --git a/airflow/providers/google/cloud/operators/cloud_build.py b/airflow/providers/google/cloud/operators/cloud_build.py
index 101c04d..b4c0cf7 100644
--- a/airflow/providers/google/cloud/operators/cloud_build.py
+++ b/airflow/providers/google/cloud/operators/cloud_build.py
@@ -22,7 +22,10 @@ from copy import deepcopy
 from typing import Any, Dict, Optional, Sequence, Union
 from urllib.parse import unquote, urlparse
 
-import yaml
+try:
+    import airflow.utils.yaml as yaml
+except ImportError:
+    import yaml
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index d3b7ffb..d29ec70 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -25,9 +25,9 @@ from collections import OrderedDict
 from typing import Any, Dict, NamedTuple, Set
 
 import jsonschema
-import yaml
 from wtforms import Field
 
+import airflow.utils.yaml as yaml
 from airflow.utils.entry_points import entry_points_with_dist
 
 try:
diff --git a/airflow/secrets/local_filesystem.py b/airflow/secrets/local_filesystem.py
index c63bb91..3ec20e1 100644
--- a/airflow/secrets/local_filesystem.py
+++ b/airflow/secrets/local_filesystem.py
@@ -25,8 +25,7 @@ from inspect import signature
 from json import JSONDecodeError
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
 
-import yaml
-
+import airflow.utils.yaml as yaml
 from airflow.exceptions import (
     AirflowException,
     AirflowFileParseException,
diff --git a/airflow/utils/yaml.py b/airflow/utils/yaml.py
new file mode 100644
index 0000000..e3be61c
--- /dev/null
+++ b/airflow/utils/yaml.py
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Use libyaml for YAML dump/load operations where possible.
+
+If libyaml is available it is used, since it is significantly faster.
+This module delegates all other attributes to the yaml module, so it can be imported as:
+
+.. code-block:: python
+
+    import airflow.utils.yaml as yaml
+
+and then used as a drop-in replacement for the regular ``yaml`` module.
+"""
+import sys
+from typing import TYPE_CHECKING, Any, BinaryIO, TextIO, Union, cast
+
+if TYPE_CHECKING:
+    from yaml.error import MarkedYAMLError  # noqa
+
+
+def safe_load(stream: Union[bytes, str, BinaryIO, TextIO]) -> Any:
+    """Like yaml.safe_load, but parse with libyaml's C loader when it is available."""
+    # Import lazily: PyYAML is only loaded the first time a document is parsed.
+    import yaml
+
+    # ``CSafeLoader`` only exists when PyYAML was built against libyaml;
+    # otherwise fall back to the pure-Python ``SafeLoader``.
+    loader_cls = getattr(yaml, "CSafeLoader", yaml.SafeLoader)
+
+    # The chosen loader class is passed positionally as yaml.load's ``Loader``.
+    return yaml.load(stream, loader_cls)
+
+
+def dump(data: Any, **kwargs) -> str:
+    """Like yaml.safe_dump, but emit with libyaml's C dumper when it is available."""
+    # Import lazily: PyYAML is only loaded the first time something is dumped.
+    import yaml
+
+    # ``CSafeDumper`` only exists when PyYAML was built against libyaml.
+    dumper_cls = getattr(yaml, "CSafeDumper", yaml.SafeDumper)
+
+    # kwargs go straight to yaml.dump; NOTE: passing ``stream=`` makes it return None.
+    output = yaml.dump(data, Dumper=dumper_cls, **kwargs)
+    return cast(str, output)
+
+
+def __getattr__(name):
+    """Module-level attribute hook (PEP 562): delegate unknown names to the yaml module."""
+    import yaml
+
+    if name == "FullLoader":
+        # Prefer the C-accelerated CFullLoader when PyYAML was built against libyaml
+        return getattr(yaml, "CFullLoader", yaml.FullLoader)
+
+    return getattr(yaml, name)
+
+
+if sys.version_info[:2] < (3, 7):
+    # Module-level __getattr__ (PEP 562) is only native from Python 3.7; emulate it before that.
+    from pep562 import Pep562
+    Pep562(__name__)
diff --git a/airflow/www/views.py b/airflow/www/views.py
index f9f7f5c..6d19208 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -34,7 +34,6 @@ from urllib.parse import parse_qsl, unquote, urlencode, urlparse
 import lazy_object_proxy
 import nvd3
 import sqlalchemy as sqla
-import yaml
 from flask import (
     Markup,
     Response,
@@ -66,6 +65,7 @@ from wtforms import SelectField, validators
 from wtforms.validators import InputRequired
 
 import airflow
+import airflow.utils.yaml as yaml
 from airflow import models, plugins_manager, settings
 from airflow.api.common.experimental.mark_tasks import (
     set_dag_run_state_to_failed,
diff --git a/dev/provider_packages/copy_provider_package_sources.py b/dev/provider_packages/copy_provider_package_sources.py
index c7f75f5..4c504b6 100755
--- a/dev/provider_packages/copy_provider_package_sources.py
+++ b/dev/provider_packages/copy_provider_package_sources.py
@@ -167,6 +167,7 @@ class RefactorBackportPackages:
             ("airflow.sensors.time_delta", "airflow.sensors.time_delta_sensor"),
             ("airflow.sensors.weekday", "airflow.contrib.sensors.weekday_sensor"),
             ("airflow.utils.session", "airflow.utils.db"),
+            ("airflow.utils.yaml", "yaml"),
         ]
         for new, old in changes:
             self.qry.select_module(new).rename(old)
diff --git a/dev/provider_packages/prepare_provider_packages.py b/dev/provider_packages/prepare_provider_packages.py
index 3cfc39f..49408d8 100755
--- a/dev/provider_packages/prepare_provider_packages.py
+++ b/dev/provider_packages/prepare_provider_packages.py
@@ -48,6 +48,11 @@ from rich import print
 from rich.console import Console
 from rich.syntax import Syntax
 
+try:
+    from yaml import CSafeLoader as SafeLoader
+except ImportError:
+    from yaml import SafeLoader  # type: ignore[no-redef]
+
 INITIAL_CHANGELOG_CONTENT = """
 
 
@@ -1264,7 +1269,7 @@ def get_provider_info_from_provider_yaml(provider_package_id: str) -> Dict[str,
     if not os.path.exists(provider_yaml_file_name):
         raise Exception(f"The provider.yaml file is missing: {provider_yaml_file_name}")
     with open(provider_yaml_file_name) as provider_file:
-        provider_yaml_dict = yaml.safe_load(provider_file.read())
+        provider_yaml_dict = yaml.load(provider_file, SafeLoader)
     provider_info = convert_to_provider_info(provider_yaml_dict)
     validate_provider_info_with_2_0_0_schema(provider_info)
     validate_provider_info_with_runtime_schema(provider_info)
diff --git a/docs/conf.py b/docs/conf.py
index cc6dad8..c68f6ea 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -39,6 +39,11 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import yaml
 
+try:
+    from yaml import CSafeLoader as SafeLoader
+except ImportError:
+    from yaml import SafeLoader  # type: ignore[misc]
+
 import airflow
 from airflow.configuration import AirflowConfigParser, default_config_yaml
 from docs.exts.docs_build.third_party_inventories import (  # pylint: disable=no-name-in-module,wrong-import-order
@@ -334,7 +339,7 @@ elif PACKAGE_NAME.startswith('apache-airflow-providers-'):
             return {}
 
         with open(file_path) as config_file:
-            return yaml.safe_load(config_file)
+            return yaml.load(config_file, SafeLoader)
 
     config = _load_config()
     if config:
diff --git a/docs/exts/docs_build/lint_checks.py b/docs/exts/docs_build/lint_checks.py
index 54d9705..88aa1bb 100644
--- a/docs/exts/docs_build/lint_checks.py
+++ b/docs/exts/docs_build/lint_checks.py
@@ -24,7 +24,11 @@ from typing import Iterable, List, Optional, Set
 
 import yaml
 
-# pylint: disable=wrong-import-order
+try:
+    from yaml import CSafeLoader as SafeLoader
+except ImportError:
+    from yaml import SafeLoader  # type: ignore[misc]
+
 import airflow
 from docs.exts.docs_build.docs_builder import ALL_PROVIDER_YAMLS  # pylint: disable=no-name-in-module
 from docs.exts.docs_build.errors import DocBuildError  # pylint: disable=no-name-in-module
@@ -330,7 +334,7 @@ def check_docker_image_tag_in_quick_start_guide() -> List[DocBuildError]:
     # master tag is little outdated.
     expected_image = f'apache/airflow:{expected_tag}'
     with open(compose_file_path) as yaml_file:
-        content = yaml.safe_load(yaml_file)
+        content = yaml.load(yaml_file, SafeLoader)
         current_image_expression = content['x-airflow-common']['image']
         if expected_image not in current_image_expression:
             build_errors.append(
diff --git a/docs/exts/provider_yaml_utils.py b/docs/exts/provider_yaml_utils.py
index 130084c..a6d1ee2 100644
--- a/docs/exts/provider_yaml_utils.py
+++ b/docs/exts/provider_yaml_utils.py
@@ -23,6 +23,12 @@ from typing import Any, Dict, List
 import jsonschema
 import yaml
 
+try:
+    from yaml import CSafeLoader as SafeLoader
+except ImportError:
+    from yaml import SafeLoader  # type: ignore[misc]
+
+
 ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
 PROVIDER_DATA_SCHEMA_PATH = os.path.join(ROOT_DIR, "airflow", "provider.yaml.schema.json")
 
@@ -53,7 +59,7 @@ def load_package_data() -> List[Dict[str, Any]]:
     result = []
     for provider_yaml_path in get_provider_yaml_paths():
         with open(provider_yaml_path) as yaml_file:
-            provider = yaml.safe_load(yaml_file)
+            provider = yaml.load(yaml_file, SafeLoader)
         try:
             jsonschema.validate(provider, schema=schema)
         except jsonschema.ValidationError:
diff --git a/scripts/ci/pre_commit/pre_commit_check_pre_commit_hook_names.py b/scripts/ci/pre_commit/pre_commit_check_pre_commit_hook_names.py
index fae42c8..5e16f42 100755
--- a/scripts/ci/pre_commit/pre_commit_check_pre_commit_hook_names.py
+++ b/scripts/ci/pre_commit/pre_commit_check_pre_commit_hook_names.py
@@ -24,6 +24,11 @@ import sys
 
 import yaml
 
+try:
+    from yaml import CSafeLoader as SafeLoader
+except ImportError:
+    from yaml import SafeLoader  # type: ignore[no-redef]
+
 
 def main() -> int:
     parser = argparse.ArgumentParser()
@@ -34,7 +39,7 @@ def main() -> int:
     retval = 0
 
     with open('.pre-commit-config.yaml', 'rb') as f:
-        content = yaml.safe_load(f)
+        content = yaml.load(f, SafeLoader)
         errors = get_errors(content, max_length)
     if errors:
         retval = 1
diff --git a/scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py b/scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py
index 038e027..35c80e9 100755
--- a/scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py
+++ b/scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py
@@ -29,6 +29,11 @@ import jsonschema
 import yaml
 from tabulate import tabulate
 
+try:
+    from yaml import CSafeLoader as SafeLoader
+except ImportError:
+    from yaml import SafeLoader  # type: ignore[no-redef]
+
 if __name__ != "__main__":
     raise Exception(
         "This file is intended to be executed as an executable program. You cannot use it as a module."
@@ -60,7 +65,7 @@ def _load_package_data(package_paths: Iterable[str]):
     result = {}
     for provider_yaml_path in package_paths:
         with open(provider_yaml_path) as yaml_file:
-            provider = yaml.safe_load(yaml_file)
+            provider = yaml.load(yaml_file, SafeLoader)
         rel_path = os.path.relpath(provider_yaml_path, ROOT_DIR)
         try:
             jsonschema.validate(provider, schema=schema)