You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2023/02/24 09:48:09 UTC

[airflow] branch main updated: Support multiple mount points in Vault backend secret (#29734)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new dff425bc3d Support multiple mount points in Vault backend secret (#29734)
dff425bc3d is described below

commit dff425bc3d92697bb447010aa9f3b56519a59f1e
Author: Hussein Awala <ho...@gmail.com>
AuthorDate: Fri Feb 24 10:48:01 2023 +0100

    Support multiple mount points in Vault backend secret (#29734)
---
 .../hashicorp/_internal_client/vault_client.py     |  40 +++--
 airflow/providers/hashicorp/secrets/vault.py       |  41 +++--
 .../_internal_client/test_vault_client.py          |  79 ++++++++-
 tests/providers/hashicorp/secrets/test_vault.py    | 179 +++++++++++++++++++++
 4 files changed, 310 insertions(+), 29 deletions(-)

diff --git a/airflow/providers/hashicorp/_internal_client/vault_client.py b/airflow/providers/hashicorp/_internal_client/vault_client.py
index 076a869666..ea8aaf0071 100644
--- a/airflow/providers/hashicorp/_internal_client/vault_client.py
+++ b/airflow/providers/hashicorp/_internal_client/vault_client.py
@@ -89,7 +89,7 @@ class _VaultClient(LoggingMixin):
         url: str | None = None,
         auth_type: str = "token",
         auth_mount_point: str | None = None,
-        mount_point: str = "secret",
+        mount_point: str | None = "secret",
         kv_engine_version: int | None = None,
         token: str | None = None,
         token_path: str | None = None,
@@ -324,6 +324,15 @@ class _VaultClient(LoggingMixin):
         else:
             _client.token = self.token
 
+    def _parse_secret_path(self, secret_path: str) -> tuple[str, str]:
+        if not self.mount_point:
+            split_secret_path = secret_path.split("/", 1)
+            if len(split_secret_path) < 2:
+                raise InvalidPath
+            return split_secret_path[0], split_secret_path[1]
+        else:
+            return self.mount_point, secret_path
+
     def get_secret(self, secret_path: str, secret_version: int | None = None) -> dict | None:
         """
         Get secret value from the KV engine.
@@ -337,19 +346,19 @@ class _VaultClient(LoggingMixin):
 
         :return: secret stored in the vault as a dictionary
         """
+        mount_point = None
         try:
+            mount_point, secret_path = self._parse_secret_path(secret_path)
             if self.kv_engine_version == 1:
                 if secret_version:
                     raise VaultError("Secret version can only be used with version 2 of the KV engine")
-                response = self.client.secrets.kv.v1.read_secret(
-                    path=secret_path, mount_point=self.mount_point
-                )
+                response = self.client.secrets.kv.v1.read_secret(path=secret_path, mount_point=mount_point)
             else:
                 response = self.client.secrets.kv.v2.read_secret_version(
-                    path=secret_path, mount_point=self.mount_point, version=secret_version
+                    path=secret_path, mount_point=mount_point, version=secret_version
                 )
         except InvalidPath:
-            self.log.debug("Secret not found %s with mount point %s", secret_path, self.mount_point)
+            self.log.debug("Secret not found %s with mount point %s", secret_path, mount_point)
             return None
 
         return_data = response["data"] if self.kv_engine_version == 1 else response["data"]["data"]
@@ -367,12 +376,12 @@ class _VaultClient(LoggingMixin):
         """
         if self.kv_engine_version == 1:
             raise VaultError("Metadata might only be used with version 2 of the KV engine.")
+        mount_point = None
         try:
-            return self.client.secrets.kv.v2.read_secret_metadata(
-                path=secret_path, mount_point=self.mount_point
-            )
+            mount_point, secret_path = self._parse_secret_path(secret_path)
+            return self.client.secrets.kv.v2.read_secret_metadata(path=secret_path, mount_point=mount_point)
         except InvalidPath:
-            self.log.debug("Secret not found %s with mount point %s", secret_path, self.mount_point)
+            self.log.debug("Secret not found %s with mount point %s", secret_path, mount_point)
             return None
 
     def get_secret_including_metadata(
@@ -391,15 +400,17 @@ class _VaultClient(LoggingMixin):
         """
         if self.kv_engine_version == 1:
             raise VaultError("Metadata might only be used with version 2 of the KV engine.")
+        mount_point = None
         try:
+            mount_point, secret_path = self._parse_secret_path(secret_path)
             return self.client.secrets.kv.v2.read_secret_version(
-                path=secret_path, mount_point=self.mount_point, version=secret_version
+                path=secret_path, mount_point=mount_point, version=secret_version
             )
         except InvalidPath:
             self.log.debug(
                 "Secret not found %s with mount point %s and version %s",
                 secret_path,
-                self.mount_point,
+                mount_point,
                 secret_version,
             )
             return None
@@ -429,12 +440,13 @@ class _VaultClient(LoggingMixin):
             raise VaultError("The method parameter is only valid for version 1")
         if self.kv_engine_version == 1 and cas:
             raise VaultError("The cas parameter is only valid for version 2")
+        mount_point, secret_path = self._parse_secret_path(secret_path)
         if self.kv_engine_version == 1:
             response = self.client.secrets.kv.v1.create_or_update_secret(
-                secret_path=secret_path, secret=secret, mount_point=self.mount_point, method=method
+                secret_path=secret_path, secret=secret, mount_point=mount_point, method=method
             )
         else:
             response = self.client.secrets.kv.v2.create_or_update_secret(
-                secret_path=secret_path, secret=secret, mount_point=self.mount_point, cas=cas
+                secret_path=secret_path, secret=secret, mount_point=mount_point, cas=cas
             )
         return response
diff --git a/airflow/providers/hashicorp/secrets/vault.py b/airflow/providers/hashicorp/secrets/vault.py
index 9c22ff71d6..79943aacd7 100644
--- a/airflow/providers/hashicorp/secrets/vault.py
+++ b/airflow/providers/hashicorp/secrets/vault.py
@@ -59,7 +59,8 @@ class VaultBackend(BaseSecretsBackend, LoggingMixin):
           Default depends on the authentication method used.
     :param mount_point: The "path" the secret engine was mounted on. Default is "secret". Note that
          this mount_point is not used for authentication if authentication is done via a
-         different engine. For authentication mount_points see, auth_mount_point.
+         different engine. If set to None, the mount point should be provided as a prefix for each
+         variable/connection_id. For authentication mount_points see, auth_mount_point.
     :param kv_engine_version: Select the version of the engine to run (``1`` or ``2``, default: ``2``).
     :param token: Authentication token to include in requests sent to Vault.
         (for ``token`` and ``github`` auth_type)
@@ -94,7 +95,7 @@ class VaultBackend(BaseSecretsBackend, LoggingMixin):
         url: str | None = None,
         auth_type: str = "token",
         auth_mount_point: str | None = None,
-        mount_point: str = "secret",
+        mount_point: str | None = "secret",
         kv_engine_version: int = 2,
         token: str | None = None,
         token_path: str | None = None,
@@ -156,17 +157,29 @@ class VaultBackend(BaseSecretsBackend, LoggingMixin):
             **kwargs,
         )
 
+    def _parse_path(self, secret_path: str) -> tuple[str | None, str | None]:
+        if not self.mount_point:
+            split_secret_path = secret_path.split("/", 1)
+            if len(split_secret_path) < 2:
+                return None, None
+            return split_secret_path[0], split_secret_path[1]
+        else:
+            return "", secret_path
+
     def get_response(self, conn_id: str) -> dict | None:
         """
         Get data from Vault
 
         :return: The data from the Vault path if exists
         """
-        if self.connections_path is None:
+        mount_point, conn_key = self._parse_path(conn_id)
+        if self.connections_path is None or conn_key is None:
             return None
 
-        secret_path = self.build_path(self.connections_path, conn_id)
-        return self.vault_client.get_secret(secret_path=secret_path)
+        secret_path = self.build_path(self.connections_path, conn_key)
+        return self.vault_client.get_secret(
+            secret_path=(mount_point + "/" if mount_point else "") + secret_path
+        )
 
     def get_conn_uri(self, conn_id: str) -> str | None:
         """
@@ -219,11 +232,14 @@ class VaultBackend(BaseSecretsBackend, LoggingMixin):
         :param key: Variable Key
         :return: Variable Value retrieved from the vault
         """
-        if self.variables_path is None:
+        mount_point, variable_key = self._parse_path(key)
+        if self.variables_path is None or variable_key is None:
             return None
         else:
-            secret_path = self.build_path(self.variables_path, key)
-            response = self.vault_client.get_secret(secret_path=secret_path)
+            secret_path = self.build_path(self.variables_path, variable_key)
+            response = self.vault_client.get_secret(
+                secret_path=(mount_point + "/" if mount_point else "") + secret_path
+            )
             return response.get("value") if response else None
 
     def get_config(self, key: str) -> str | None:
@@ -233,9 +249,12 @@ class VaultBackend(BaseSecretsBackend, LoggingMixin):
         :param key: Configuration Option Key
         :return: Configuration Option Value retrieved from the vault
         """
-        if self.config_path is None:
+        mount_point, config_key = self._parse_path(key)
+        if self.config_path is None or config_key is None:
             return None
         else:
-            secret_path = self.build_path(self.config_path, key)
-            response = self.vault_client.get_secret(secret_path=secret_path)
+            secret_path = self.build_path(self.config_path, config_key)
+            response = self.vault_client.get_secret(
+                secret_path=(mount_point + "/" if mount_point else "") + secret_path
+            )
             return response.get("value") if response else None
diff --git a/tests/providers/hashicorp/_internal_client/test_vault_client.py b/tests/providers/hashicorp/_internal_client/test_vault_client.py
index 1bab652dc0..6fd8dcf6a7 100644
--- a/tests/providers/hashicorp/_internal_client/test_vault_client.py
+++ b/tests/providers/hashicorp/_internal_client/test_vault_client.py
@@ -661,10 +661,48 @@ class TestVaultClient:
             radius_secret="pass",
             url="http://localhost:8180",
         )
-        secret = vault_client.get_secret(secret_path="missing")
+        secret = vault_client.get_secret(secret_path="path/to/secret")
         assert {"secret_key": "secret_value"} == secret
         mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with(
-            mount_point="secret", path="missing", version=None
+            mount_point="secret", path="path/to/secret", version=None
+        )
+
+    @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
+    def test_get_existing_key_v2_without_preconfigured_mount_point(self, mock_hvac):
+        mock_client = mock.MagicMock()
+        mock_hvac.Client.return_value = mock_client
+
+        mock_client.secrets.kv.v2.read_secret_version.return_value = {
+            "request_id": "94011e25-f8dc-ec29-221b-1f9c1d9ad2ae",
+            "lease_id": "",
+            "renewable": False,
+            "lease_duration": 0,
+            "data": {
+                "data": {"secret_key": "secret_value"},
+                "metadata": {
+                    "created_time": "2020-03-16T21:01:43.331126Z",
+                    "deletion_time": "",
+                    "destroyed": False,
+                    "version": 1,
+                },
+            },
+            "wrap_info": None,
+            "warnings": None,
+            "auth": None,
+        }
+
+        vault_client = _VaultClient(
+            auth_type="radius",
+            radius_host="radhost",
+            radius_port=8110,
+            radius_secret="pass",
+            url="http://localhost:8180",
+            mount_point=None,
+        )
+        secret = vault_client.get_secret(secret_path="mount_point/path/to/secret")
+        assert {"secret_key": "secret_value"} == secret
+        mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with(
+            mount_point="mount_point", path="path/to/secret", version=None
         )
 
     @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -728,9 +766,42 @@ class TestVaultClient:
             kv_engine_version=1,
             url="http://localhost:8180",
         )
-        secret = vault_client.get_secret(secret_path="missing")
+        secret = vault_client.get_secret(secret_path="/path/to/secret")
         assert {"value": "world"} == secret
-        mock_client.secrets.kv.v1.read_secret.assert_called_once_with(mount_point="secret", path="missing")
+        mock_client.secrets.kv.v1.read_secret.assert_called_once_with(
+            mount_point="secret", path="/path/to/secret"
+        )
+
+    @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
+    def test_get_existing_key_v1_without_preconfigured_mount_point(self, mock_hvac):
+        mock_client = mock.MagicMock()
+        mock_hvac.Client.return_value = mock_client
+
+        mock_client.secrets.kv.v1.read_secret.return_value = {
+            "request_id": "182d0673-618c-9889-4cba-4e1f4cfe4b4b",
+            "lease_id": "",
+            "renewable": False,
+            "lease_duration": 2764800,
+            "data": {"value": "world"},
+            "wrap_info": None,
+            "warnings": None,
+            "auth": None,
+        }
+
+        vault_client = _VaultClient(
+            auth_type="radius",
+            radius_host="radhost",
+            radius_port=8110,
+            radius_secret="pass",
+            kv_engine_version=1,
+            url="http://localhost:8180",
+            mount_point=None,
+        )
+        secret = vault_client.get_secret(secret_path="mount_point/path/to/secret")
+        assert {"value": "world"} == secret
+        mock_client.secrets.kv.v1.read_secret.assert_called_once_with(
+            mount_point="mount_point", path="path/to/secret"
+        )
 
     @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
     def test_get_existing_key_v1_different_auth_mount_point(self, mock_hvac):
diff --git a/tests/providers/hashicorp/secrets/test_vault.py b/tests/providers/hashicorp/secrets/test_vault.py
index a29e6dc21e..309dbd9d6a 100644
--- a/tests/providers/hashicorp/secrets/test_vault.py
+++ b/tests/providers/hashicorp/secrets/test_vault.py
@@ -60,6 +60,41 @@ class TestVaultSecrets:
         returned_uri = test_client.get_conn_uri(conn_id="test_postgres")
         assert "postgresql://airflow:airflow@host:5432/airflow" == returned_uri
 
+    @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
+    def test_get_conn_uri_without_predefined_mount_point(self, mock_hvac):
+        mock_client = mock.MagicMock()
+        mock_hvac.Client.return_value = mock_client
+        mock_client.secrets.kv.v2.read_secret_version.return_value = {
+            "request_id": "94011e25-f8dc-ec29-221b-1f9c1d9ad2ae",
+            "lease_id": "",
+            "renewable": False,
+            "lease_duration": 0,
+            "data": {
+                "data": {"conn_uri": "postgresql://airflow:airflow@host:5432/airflow"},
+                "metadata": {
+                    "created_time": "2020-03-16T21:01:43.331126Z",
+                    "deletion_time": "",
+                    "destroyed": False,
+                    "version": 1,
+                },
+            },
+            "wrap_info": None,
+            "warnings": None,
+            "auth": None,
+        }
+
+        kwargs = {
+            "connections_path": "connections",
+            "mount_point": None,
+            "auth_type": "token",
+            "url": "http://127.0.0.1:8200",
+            "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS",
+        }
+
+        test_client = VaultBackend(**kwargs)
+        returned_uri = test_client.get_conn_uri(conn_id="airflow/test_postgres")
+        assert "postgresql://airflow:airflow@host:5432/airflow" == returned_uri
+
     @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
     def test_get_connection(self, mock_hvac):
         mock_client = mock.MagicMock()
@@ -103,6 +138,49 @@ class TestVaultSecrets:
         connection = test_client.get_connection(conn_id="test_postgres")
         assert "postgresql://airflow:airflow@host:5432/airflow?foo=bar&baz=taz" == connection.get_uri()
 
+    @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
+    def test_get_connection_without_predefined_mount_point(self, mock_hvac):
+        mock_client = mock.MagicMock()
+        mock_hvac.Client.return_value = mock_client
+        mock_client.secrets.kv.v2.read_secret_version.return_value = {
+            "request_id": "94011e25-f8dc-ec29-221b-1f9c1d9ad2ae",
+            "lease_id": "",
+            "renewable": False,
+            "lease_duration": 0,
+            "data": {
+                "data": {
+                    "conn_type": "postgresql",
+                    "login": "airflow",
+                    "password": "airflow",
+                    "host": "host",
+                    "port": "5432",
+                    "schema": "airflow",
+                    "extra": '{"foo":"bar","baz":"taz"}',
+                },
+                "metadata": {
+                    "created_time": "2020-03-16T21:01:43.331126Z",
+                    "deletion_time": "",
+                    "destroyed": False,
+                    "version": 1,
+                },
+            },
+            "wrap_info": None,
+            "warnings": None,
+            "auth": None,
+        }
+
+        kwargs = {
+            "connections_path": "connections",
+            "mount_point": None,
+            "auth_type": "token",
+            "url": "http://127.0.0.1:8200",
+            "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS",
+        }
+
+        test_client = VaultBackend(**kwargs)
+        connection = test_client.get_connection(conn_id="airflow/test_postgres")
+        assert "postgresql://airflow:airflow@host:5432/airflow?foo=bar&baz=taz" == connection.get_uri()
+
     @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
     def test_get_conn_uri_engine_version_1(self, mock_hvac):
         mock_client = mock.MagicMock()
@@ -234,6 +312,41 @@ class TestVaultSecrets:
         returned_uri = test_client.get_variable("hello")
         assert "world" == returned_uri
 
+    @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
+    def test_get_variable_value_without_predefined_mount_point(self, mock_hvac):
+        mock_client = mock.MagicMock()
+        mock_hvac.Client.return_value = mock_client
+        mock_client.secrets.kv.v2.read_secret_version.return_value = {
+            "request_id": "2d48a2ad-6bcb-e5b6-429d-da35fdf31f56",
+            "lease_id": "",
+            "renewable": False,
+            "lease_duration": 0,
+            "data": {
+                "data": {"value": "world"},
+                "metadata": {
+                    "created_time": "2020-03-28T02:10:54.301784Z",
+                    "deletion_time": "",
+                    "destroyed": False,
+                    "version": 1,
+                },
+            },
+            "wrap_info": None,
+            "warnings": None,
+            "auth": None,
+        }
+
+        kwargs = {
+            "variables_path": "variables",
+            "mount_point": None,
+            "auth_type": "token",
+            "url": "http://127.0.0.1:8200",
+            "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS",
+        }
+
+        test_client = VaultBackend(**kwargs)
+        returned_uri = test_client.get_variable("airflow/hello")
+        assert "world" == returned_uri
+
     @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
     def test_get_variable_value_engine_version_1(self, mock_hvac):
         mock_client = mock.MagicMock()
@@ -265,6 +378,37 @@ class TestVaultSecrets:
         )
         assert "world" == returned_uri
 
+    @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
+    def test_get_variable_value_engine_version_1_without_predefined_mount_point(self, mock_hvac):
+        mock_client = mock.MagicMock()
+        mock_hvac.Client.return_value = mock_client
+        mock_client.secrets.kv.v1.read_secret.return_value = {
+            "request_id": "182d0673-618c-9889-4cba-4e1f4cfe4b4b",
+            "lease_id": "",
+            "renewable": False,
+            "lease_duration": 2764800,
+            "data": {"value": "world"},
+            "wrap_info": None,
+            "warnings": None,
+            "auth": None,
+        }
+
+        kwargs = {
+            "variables_path": "variables",
+            "mount_point": None,
+            "auth_type": "token",
+            "url": "http://127.0.0.1:8200",
+            "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS",
+            "kv_engine_version": 1,
+        }
+
+        test_client = VaultBackend(**kwargs)
+        returned_uri = test_client.get_variable("airflow/hello")
+        mock_client.secrets.kv.v1.read_secret.assert_called_once_with(
+            mount_point="airflow", path="variables/hello"
+        )
+        assert "world" == returned_uri
+
     @mock.patch.dict(
         "os.environ",
         {
@@ -361,6 +505,41 @@ class TestVaultSecrets:
         returned_uri = test_client.get_config("sql_alchemy_conn")
         assert "sqlite:////Users/airflow/airflow/airflow.db" == returned_uri
 
+    @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
+    def test_get_config_value_without_predefined_mount_point(self, mock_hvac):
+        mock_client = mock.MagicMock()
+        mock_hvac.Client.return_value = mock_client
+        mock_client.secrets.kv.v2.read_secret_version.return_value = {
+            "request_id": "2d48a2ad-6bcb-e5b6-429d-da35fdf31f56",
+            "lease_id": "",
+            "renewable": False,
+            "lease_duration": 0,
+            "data": {
+                "data": {"value": "sqlite:////Users/airflow/airflow/airflow.db"},
+                "metadata": {
+                    "created_time": "2020-03-28T02:10:54.301784Z",
+                    "deletion_time": "",
+                    "destroyed": False,
+                    "version": 1,
+                },
+            },
+            "wrap_info": None,
+            "warnings": None,
+            "auth": None,
+        }
+
+        kwargs = {
+            "configs_path": "configurations",
+            "mount_point": None,
+            "auth_type": "token",
+            "url": "http://127.0.0.1:8200",
+            "token": "s.FnL7qg0YnHZDpf4zKKuFy0UK",
+        }
+
+        test_client = VaultBackend(**kwargs)
+        returned_uri = test_client.get_config("airflow/sql_alchemy_conn")
+        assert "sqlite:////Users/airflow/airflow/airflow.db" == returned_uri
+
     @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
     def test_connections_path_none_value(self, mock_hvac):
         mock_client = mock.MagicMock()