You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2023/08/10 21:47:11 UTC

[airflow] branch main updated: Allow PythonVenvOperator using other index url (#33017)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7d87d71e05 Allow PythonVenvOperator using other index url (#33017)
7d87d71e05 is described below

commit 7d87d71e053712a96c5a14da38b310a69667de7d
Author: Jens Scheffler <95...@users.noreply.github.com>
AuthorDate: Thu Aug 10 23:47:04 2023 +0200

    Allow PythonVenvOperator using other index url (#33017)
    
    * Extended PythonVirtualEnvOperator for extra index URL
    
    * Separate-out BranchPythonVirtualenvOperator from this PR
---
 airflow/decorators/__init__.pyi               |  12 ++-
 airflow/hooks/package_index.py                |  94 ++++++++++++++++++++
 airflow/operators/python.py                   |  54 +++++++-----
 airflow/providers_manager.py                  |   3 +-
 airflow/utils/python_virtualenv.py            |  19 ++++
 docs/apache-airflow/howto/operator/python.rst |  14 +++
 docs/spelling_wordlist.txt                    |   3 +
 tests/hooks/test_package_index.py             | 122 ++++++++++++++++++++++++++
 tests/operators/test_python.py                |  13 +++
 tests/utils/test_python_virtualenv.py         |  37 +++++++-
 10 files changed, 347 insertions(+), 24 deletions(-)

diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi
index 7b2ac1c9ce..4ee655f954 100644
--- a/airflow/decorators/__init__.pyi
+++ b/airflow/decorators/__init__.pyi
@@ -21,7 +21,7 @@
 from __future__ import annotations
 
 from datetime import timedelta
-from typing import Any, Callable, Iterable, Mapping, overload
+from typing import Any, Callable, Collection, Container, Iterable, Mapping, overload
 
 from kubernetes.client import models as k8s
 
@@ -107,6 +107,9 @@ class TaskDecoratorCollection:
         use_dill: bool = False,
         system_site_packages: bool = True,
         templates_dict: Mapping[str, Any] | None = None,
+        pip_install_options: list[str] | None = None,
+        skip_on_exit_code: int | Container[int] | None = None,
+        index_urls: None | Collection[str] | str = None,
         show_return_value_in_logs: bool = True,
         **kwargs,
     ) -> TaskDecorator:
@@ -124,6 +127,13 @@ class TaskDecoratorCollection:
         :param system_site_packages: Whether to include
             system_site_packages in your virtualenv.
             See virtualenv documentation for more information.
+        :param pip_install_options: a list of pip install options when installing requirements
+            See 'pip install -h' for available options
+        :param skip_on_exit_code: If python_callable exits with this exit code, leave the task
+            in ``skipped`` state (default: None). If set to ``None``, any non-zero
+            exit code will be treated as a failure.
+        :param index_urls: an optional list of index urls to load Python packages from.
+            If not provided the system pip conf will be used to source packages from.
         :param templates_dict: a dictionary where the values are templates that
             will get templated by the Airflow engine sometime between
             ``__init__`` and ``execute`` takes place and are made available
diff --git a/airflow/hooks/package_index.py b/airflow/hooks/package_index.py
new file mode 100644
index 0000000000..5c940506a1
--- /dev/null
+++ b/airflow/hooks/package_index.py
@@ -0,0 +1,94 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Hook for additional Package Indexes (Python)."""
+from __future__ import annotations
+
+import subprocess
+from typing import Any
+from urllib.parse import quote, urlparse
+
+from airflow.hooks.base import BaseHook
+
+
+class PackageIndexHook(BaseHook):
+    """Specify package indexes/Python package sources using Airflow connections."""
+
+    conn_name_attr = "pi_conn_id"
+    default_conn_name = "package_index_default"
+    conn_type = "package_index"
+    hook_name = "Package Index (Python)"
+
+    def __init__(self, pi_conn_id: str = default_conn_name) -> None:
+        super().__init__()
+        self.pi_conn_id = pi_conn_id
+        self.conn = None
+
+    @staticmethod
+    def get_ui_field_behaviour() -> dict[str, Any]:
+        """Returns custom field behaviour."""
+        return {
+            "hidden_fields": ["schema", "port", "extra"],
+            "relabeling": {"host": "Package Index URL"},
+            "placeholders": {
+                "host": "Example: https://my-package-mirror.net/pypi/repo-name/simple",
+                "login": "Username for package index",
+                "password": "Password for package index (will be masked)",
+            },
+        }
+
+    @staticmethod
+    def _get_basic_auth_conn_url(index_url: str, user: str | None, password: str | None) -> str:
+        """Returns a connection URL with basic auth credentials based on connection config."""
+        url = urlparse(index_url)
+        host = url.netloc.split("@")[-1]
+        if user:
+            if password:
+                host = f"{quote(user)}:{quote(password)}@{host}"
+            else:
+                host = f"{quote(user)}@{host}"
+        return url._replace(netloc=host).geturl()
+
+    def get_conn(self) -> Any:
+        """Returns connection for the hook."""
+        return self.get_connection_url()
+
+    def get_connection_url(self) -> Any:
+        """Returns a connection URL with embedded credentials."""
+        conn = self.get_connection(self.pi_conn_id)
+        index_url = conn.host
+        if not index_url:
+            raise Exception("Please provide an index URL.")
+        return self._get_basic_auth_conn_url(index_url, conn.login, conn.password)
+
+    def test_connection(self) -> tuple[bool, str]:
+        """Test connection to package index url."""
+        conn_url = self.get_connection_url()
+        proc = subprocess.run(
+            ["pip", "search", "not-existing-test-package", "--no-input", "--index", conn_url],
+            check=False,
+            capture_output=True,
+        )
+        conn = self.get_connection(self.pi_conn_id)
+        if proc.returncode not in [
+            0,  # executed successfully, found package
+            23,  # executed successfully, didn't find any packages
+            #      (but we do not expect it to find 'not-existing-test-package')
+        ]:
+            return False, f"Connection test to {conn.host} failed. Error: {str(proc.stderr)}"
+
+        return True, f"Connection to {conn.host} tested successfully!"
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index 479fb600c4..b35e16f8b8 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -519,6 +519,8 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
     :param skip_on_exit_code: If python_callable exits with this exit code, leave the task
         in ``skipped`` state (default: None). If set to ``None``, any non-zero
         exit code will be treated as a failure.
+    :param index_urls: an optional list of index urls to load Python packages from.
+        If not provided the system pip conf will be used to source packages from.
     """
 
     template_fields: Sequence[str] = tuple({"requirements"} | set(PythonOperator.template_fields))
@@ -540,6 +542,7 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
         templates_exts: list[str] | None = None,
         expect_airflow: bool = True,
         skip_on_exit_code: int | Container[int] | None = None,
+        index_urls: None | Collection[str] | str = None,
         **kwargs,
     ):
         if (
@@ -555,14 +558,20 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
         if not is_venv_installed():
             raise AirflowException("PythonVirtualenvOperator requires virtualenv, please install it.")
         if not requirements:
-            self.requirements: list[str] | str = []
+            self.requirements: list[str] = []
         elif isinstance(requirements, str):
-            self.requirements = requirements
+            self.requirements = [requirements]
         else:
             self.requirements = list(requirements)
         self.python_version = python_version
         self.system_site_packages = system_site_packages
         self.pip_install_options = pip_install_options
+        if isinstance(index_urls, str):
+            self.index_urls: list[str] | None = [index_urls]
+        elif isinstance(index_urls, Collection):
+            self.index_urls = list(index_urls)
+        else:
+            self.index_urls = None
         super().__init__(
             python_callable=python_callable,
             use_dill=use_dill,
@@ -576,28 +585,31 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
             **kwargs,
         )
 
+    def _requirements_list(self) -> list[str]:
+        """Prepares a list of requirements that need to be installed for the venv."""
+        requirements = [str(dependency) for dependency in self.requirements]
+        if not self.system_site_packages and self.use_dill and "dill" not in requirements:
+            requirements.append("dill")
+        requirements.sort()  # Ensure a hash is stable
+        return requirements
+
+    def _prepare_venv(self, venv_path: Path) -> None:
+        """Prepares the requirements and installs the venv."""
+        requirements_file = venv_path / "requirements.txt"
+        requirements_file.write_text("\n".join(self._requirements_list()))
+        prepare_virtualenv(
+            venv_directory=str(venv_path),
+            python_bin=f"python{self.python_version}" if self.python_version else "python",
+            system_site_packages=self.system_site_packages,
+            requirements_file_path=str(requirements_file),
+            pip_install_options=self.pip_install_options,
+            index_urls=self.index_urls,
+        )
+
     def execute_callable(self):
         with TemporaryDirectory(prefix="venv") as tmp_dir:
             tmp_path = Path(tmp_dir)
-            requirements_file_name = f"{tmp_dir}/requirements.txt"
-
-            if not isinstance(self.requirements, str):
-                requirements_file_contents = "\n".join(str(dependency) for dependency in self.requirements)
-            else:
-                requirements_file_contents = self.requirements
-
-            if not self.system_site_packages and self.use_dill:
-                requirements_file_contents += "\ndill"
-
-            with open(requirements_file_name, "w") as file:
-                file.write(requirements_file_contents)
-            prepare_virtualenv(
-                venv_directory=tmp_dir,
-                python_bin=f"python{self.python_version}" if self.python_version else None,
-                system_site_packages=self.system_site_packages,
-                requirements_file_path=requirements_file_name,
-                pip_install_options=self.pip_install_options,
-            )
+            self._prepare_venv(tmp_path)
             python_path = tmp_path / "bin" / "python"
             result = self._execute_python_callable_in_subprocess(python_path, tmp_path)
             return result
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index e25e82c265..19867f052c 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -37,6 +37,7 @@ from packaging.utils import canonicalize_name
 
 from airflow.exceptions import AirflowOptionalProviderFeatureException
 from airflow.hooks.filesystem import FSHook
+from airflow.hooks.package_index import PackageIndexHook
 from airflow.typing_compat import Literal
 from airflow.utils import yaml
 from airflow.utils.entry_points import entry_points_with_dist
@@ -458,7 +459,7 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
                 connection_type=None,
                 connection_testable=False,
             )
-        for cls in [FSHook]:
+        for cls in [FSHook, PackageIndexHook]:
             package_name = cls.__module__
             hook_class_name = f"{cls.__module__}.{cls.__name__}"
             hook_info = self._import_hook(
diff --git a/airflow/utils/python_virtualenv.py b/airflow/utils/python_virtualenv.py
index 1adabacaff..be8bbe0d22 100644
--- a/airflow/utils/python_virtualenv.py
+++ b/airflow/utils/python_virtualenv.py
@@ -21,6 +21,7 @@ from __future__ import annotations
 import os
 import sys
 import warnings
+from pathlib import Path
 
 import jinja2
 
@@ -51,6 +52,16 @@ def _generate_pip_install_cmd_from_list(
     return cmd + requirements
 
 
+def _generate_pip_conf(conf_file: Path, index_urls: list[str]) -> None:
+    if len(index_urls) == 0:
+        pip_conf_options = "no-index = true"
+    else:
+        pip_conf_options = f"index-url = {index_urls[0]}"
+        if len(index_urls) > 1:
+            pip_conf_options += f"\nextra-index-url = {' '.join(x for x in index_urls[1:])}"
+    conf_file.write_text(f"[global]\n{pip_conf_options}")
+
+
 def remove_task_decorator(python_source: str, task_decorator_name: str) -> str:
     warnings.warn(
         "Import remove_task_decorator from airflow.utils.decorators instead",
@@ -67,6 +78,7 @@ def prepare_virtualenv(
     requirements: list[str] | None = None,
     requirements_file_path: str | None = None,
     pip_install_options: list[str] | None = None,
+    index_urls: list[str] | None = None,
 ) -> str:
     """Creates a virtual environment and installs the additional python packages.
 
@@ -76,11 +88,18 @@ def prepare_virtualenv(
         See virtualenv documentation for more information.
     :param requirements: List of additional python packages.
     :param requirements_file_path: Path to the ``requirements.txt`` file.
+    :param pip_install_options: a list of pip install options when installing requirements
+        See 'pip install -h' for available options
+    :param index_urls: an optional list of index urls to load Python packages from.
+        If not provided the system pip conf will be used to source packages from.
     :return: Path to a binary file with Python in a virtual environment.
     """
     if pip_install_options is None:
         pip_install_options = []
 
+    if index_urls is not None:
+        _generate_pip_conf(Path(venv_directory) / "pip.conf", index_urls)
+
     virtualenv_cmd = _generate_virtualenv_cmd(venv_directory, python_bin, system_site_packages)
     execute_in_subprocess(virtualenv_cmd)
 
diff --git a/docs/apache-airflow/howto/operator/python.rst b/docs/apache-airflow/howto/operator/python.rst
index e2bf4b11e5..2e01a93743 100644
--- a/docs/apache-airflow/howto/operator/python.rst
+++ b/docs/apache-airflow/howto/operator/python.rst
@@ -112,6 +112,20 @@ If additional parameters for package installation are needed pass them in ``requ
 
 All supported options are listed in the `requirements file format <https://pip.pypa.io/en/stable/reference/requirements-file-format/#supported-options>`_.
 
+Virtualenv setup options
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+The virtualenv is created based on the global python pip configuration on your worker. You can influence it by using additional ENVs in your environment or
+by adjusting the general pip configuration as described in `pip config <https://pip.pypa.io/en/stable/topics/configuration/>`_.
+
+If you want to use additional task specific private python repositories to set up the virtualenv, you can pass the ``index_urls`` parameter which will adjust the
+pip install configurations. Passed index urls replace the standard system configured index url settings.
+To avoid adding the secrets for the private repository to your DAG code, you can use the Airflow
+:doc:`../../authoring-and-scheduling/connections`. For this purpose the connection type ``Package Index (Python)`` can be used.
+
+In the special case that you want to prevent remote calls during setup of a virtualenv, pass ``index_urls`` as an empty list (``index_urls=[]``), which
+forces the pip installer to use the ``--no-index`` option.
+
 
 .. _howto/operator:ExternalPythonOperator:
 
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index fbca728f8e..cd62151566 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1617,6 +1617,7 @@ Url
 url
 urlencoded
 urlparse
+urls
 useHCatalog
 useLegacySQL
 useQueryCache
@@ -1637,6 +1638,7 @@ vCPU
 ve
 vendored
 Vendorize
+venv
 venvs
 versionable
 Vertica
@@ -1646,6 +1648,7 @@ Vevo
 videointelligence
 VideoIntelligenceServiceClient
 virtualenv
+virtualenvs
 vm
 VolumeMount
 volumeMounts
diff --git a/tests/hooks/test_package_index.py b/tests/hooks/test_package_index.py
new file mode 100644
index 0000000000..375898eb06
--- /dev/null
+++ b/tests/hooks/test_package_index.py
@@ -0,0 +1,122 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test for Package Index Hook."""
+from __future__ import annotations
+
+from pytest import FixtureRequest, MonkeyPatch, fixture, mark, raises
+
+from airflow.hooks.package_index import PackageIndexHook
+from airflow.models.connection import Connection
+
+
+class MockConnection(Connection):
+    """Mock for the Connection class."""
+
+    def __init__(self, host: str | None, login: str | None, password: str | None):
+        super().__init__()
+        self.host = host
+        self.login = login
+        self.password = password
+
+
+PI_MOCK_TESTDATA = {
+    "missing-url": {},
+    "anonymous-https": {
+        "host": "https://site/path",
+        "expected_result": "https://site/path",
+    },
+    "no_password-http": {
+        "host": "http://site/path",
+        "login": "any_user",
+        "expected_result": "http://any_user@site/path",
+    },
+    "with_password-http": {
+        "host": "http://site/path",
+        "login": "any_user",
+        "password": "secret@_%1234!",
+        "expected_result": "http://any_user:secret%40_%251234%21@site/path",
+    },
+    "with_password-https": {
+        "host": "https://old_user:pass@site/path",
+        "login": "any_user",
+        "password": "secret@_%1234!",
+        "expected_result": "https://any_user:secret%40_%251234%21@site/path",
+    },
+}
+
+
+@fixture(
+    params=list(PI_MOCK_TESTDATA.values()),
+    ids=list(PI_MOCK_TESTDATA.keys()),
+)
+def mock_get_connection(monkeypatch: MonkeyPatch, request: FixtureRequest) -> str | None:
+    """Pytest Fixture."""
+    testdata: dict[str, str | None] = request.param
+    host: str | None = testdata.get("host", None)
+    login: str | None = testdata.get("login", None)
+    password: str | None = testdata.get("password", None)
+    expected_result: str | None = testdata.get("expected_result", None)
+    monkeypatch.setattr(
+        "airflow.hooks.package_index.PackageIndexHook.get_connection",
+        lambda *_: MockConnection(host, login, password),
+    )
+    return expected_result
+
+
+def test_get_connection_url(mock_get_connection: str | None):
+    """Test if connection url is assembled correctly from credentials and index_url."""
+    expected_result = mock_get_connection
+    hook_instance = PackageIndexHook()
+    if expected_result:
+        connection_url = hook_instance.get_connection_url()
+        assert connection_url == expected_result
+    else:
+        with raises(Exception):
+            hook_instance.get_connection_url()
+
+
+@mark.parametrize("success", [0, 1])
+def test_test_connection(monkeypatch: MonkeyPatch, mock_get_connection: str | None, success: int):
+    """Test if connection test responds correctly to return code."""
+
+    def mock_run(*_, **__):
+        class MockProc:
+            """Mock class."""
+
+            returncode = success
+            stderr = "some error text"
+
+        return MockProc()
+
+    monkeypatch.setattr("airflow.hooks.package_index.subprocess.run", mock_run)
+
+    hook_instance = PackageIndexHook()
+    if mock_get_connection:
+        result = hook_instance.test_connection()
+        assert result[0] == (success == 0)
+    else:
+        with raises(Exception):
+            hook_instance.test_connection()
+
+
+def test_get_ui_field_behaviour():
+    """Tests UI field result structure"""
+    ui_field_behavior = PackageIndexHook.get_ui_field_behaviour()
+    assert "hidden_fields" in ui_field_behavior
+    assert "relabeling" in ui_field_behavior
+    assert "placeholders" in ui_field_behavior
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index 986f3ae510..0dcdac81b5 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -931,6 +931,7 @@ class TestPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
             pip_install_options=["--no-deps"],
         )
         mocked_prepare_virtualenv.assert_called_with(
+            index_urls=None,
             venv_directory=mock.ANY,
             python_bin=mock.ANY,
             system_site_packages=False,
@@ -971,6 +972,18 @@ class TestPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
 
         self.run_as_task(f, system_site_packages=False, use_dill=False, op_args=[4])
 
+    def test_with_index_urls(self):
+        def f(a):
+            import sys
+            from pathlib import Path
+
+            pip_conf = (Path(sys.executable).parent.parent / "pip.conf").read_text()
+            assert "abc.def.de" in pip_conf
+            assert "xyz.abc.de" in pip_conf
+            return a
+
+        self.run_as_task(f, index_urls=["https://abc.def.de", "http://xyz.abc.de"], op_args=[4])
+
     # This tests might take longer than default 60 seconds as it is serializing a lot of
     # context using dill (which is slow apparently).
     @pytest.mark.execution_timeout(120)
diff --git a/tests/utils/test_python_virtualenv.py b/tests/utils/test_python_virtualenv.py
index 6b492d8a1c..ea11f979d7 100644
--- a/tests/utils/test_python_virtualenv.py
+++ b/tests/utils/test_python_virtualenv.py
@@ -18,13 +18,48 @@
 from __future__ import annotations
 
 import sys
+from pathlib import Path
 from unittest import mock
 
+import pytest
+
 from airflow.utils.decorators import remove_task_decorator
-from airflow.utils.python_virtualenv import prepare_virtualenv
+from airflow.utils.python_virtualenv import _generate_pip_conf, prepare_virtualenv
 
 
 class TestPrepareVirtualenv:
+    @pytest.mark.parametrize(
+        ("index_urls", "expected_pip_conf_content", "unexpected_pip_conf_content"),
+        [
+            [[], ["[global]", "no-index ="], ["index-url", "extra", "http", "pypi"]],
+            [["http://mysite"], ["[global]", "index-url", "http://mysite"], ["no-index", "extra", "pypi"]],
+            [
+                ["http://mysite", "https://othersite"],
+                ["[global]", "index-url", "http://mysite", "extra", "https://othersite"],
+                ["no-index", "pypi"],
+            ],
+            [
+                ["http://mysite", "https://othersite", "http://site"],
+                ["[global]", "index-url", "http://mysite", "extra", "https://othersite http://site"],
+                ["no-index", "pypi"],
+            ],
+        ],
+    )
+    def test_generate_pip_conf(
+        self,
+        index_urls: list[str],
+        expected_pip_conf_content: list[str],
+        unexpected_pip_conf_content: list[str],
+        tmp_path: Path,
+    ):
+        tmp_file = tmp_path / "pip.conf"
+        _generate_pip_conf(tmp_file, index_urls)
+        generated_conf = tmp_file.read_text()
+        for term in expected_pip_conf_content:
+            assert term in generated_conf
+        for term in unexpected_pip_conf_content:
+            assert term not in generated_conf
+
     @mock.patch("airflow.utils.python_virtualenv.execute_in_subprocess")
     def test_should_create_virtualenv(self, mock_execute_in_subprocess):
         python_bin = prepare_virtualenv(