You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by hu...@apache.org on 2023/06/19 09:32:48 UTC

[airflow] branch main updated: fix connection uri parsing when the host includes a scheme (#31465)

This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0560881f0e fix connection uri parsing when the host includes a scheme (#31465)
0560881f0e is described below

commit 0560881f0eaef9c583b11e937bf1f79d13e5ac7c
Author: Hussein Awala <hu...@awala.fr>
AuthorDate: Mon Jun 19 11:32:41 2023 +0200

    fix connection uri parsing when the host includes a scheme (#31465)
    
    * update _parse_from_uri and get_uri methods, and add tests for connection model
    
    * some fixes from review
---
 airflow/models/connection.py    |  35 +++++++-
 tests/models/test_connection.py | 188 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 220 insertions(+), 3 deletions(-)

diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index a565341209..0bc3ca38d4 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -187,10 +187,22 @@ class Connection(Base, LoggingMixin):
         return conn_type
 
     def _parse_from_uri(self, uri: str):
+        schemes_count_in_uri = uri.count("://")
+        if schemes_count_in_uri > 2:
+            raise AirflowException(f"Invalid connection string: {uri}.")
+        host_with_protocol = schemes_count_in_uri == 2
         uri_parts = urlsplit(uri)
         conn_type = uri_parts.scheme
         self.conn_type = self._normalize_conn_type(conn_type)
-        self.host = _parse_netloc_to_hostname(uri_parts)
+        rest_of_the_url = uri.replace(f"{conn_type}://", ("" if host_with_protocol else "//"))
+        if host_with_protocol:
+            uri_splits = rest_of_the_url.split("://", 1)
+            if "@" in uri_splits[0] or ":" in uri_splits[0]:
+                raise AirflowException(f"Invalid connection string: {uri}.")
+        uri_parts = urlsplit(rest_of_the_url)
+        protocol = uri_parts.scheme if host_with_protocol else None
+        host = _parse_netloc_to_hostname(uri_parts)
+        self.host = self._create_host(protocol, host)
         quoted_schema = uri_parts.path[1:]
         self.schema = unquote(quoted_schema) if quoted_schema else quoted_schema
         self.login = unquote(uri_parts.username) if uri_parts.username else uri_parts.username
@@ -203,6 +215,15 @@ class Connection(Base, LoggingMixin):
             else:
                 self.extra = json.dumps(query)
 
+    @staticmethod
+    def _create_host(protocol, host) -> str | None:
+        """Returns the connection host with the protocol."""
+        if not host:
+            return host
+        if protocol:
+            return f"{protocol}://{host}"
+        return host
+
     def get_uri(self) -> str:
         """Return connection in URI format."""
         if self.conn_type and "_" in self.conn_type:
@@ -216,6 +237,14 @@ class Connection(Base, LoggingMixin):
         else:
             uri = "//"
 
+        if self.host and "://" in self.host:
+            protocol, host = self.host.split("://", 1)
+        else:
+            protocol, host = None, self.host
+
+        if protocol:
+            uri += f"{protocol}://"
+
         authority_block = ""
         if self.login is not None:
             authority_block += quote(self.login, safe="")
@@ -229,8 +258,8 @@ class Connection(Base, LoggingMixin):
             uri += authority_block
 
         host_block = ""
-        if self.host:
-            host_block += quote(self.host, safe="")
+        if host:
+            host_block += quote(host, safe="")
 
         if self.port:
             if host_block == "" and authority_block == "":
diff --git a/tests/models/test_connection.py b/tests/models/test_connection.py
new file mode 100644
index 0000000000..0223cffb8d
--- /dev/null
+++ b/tests/models/test_connection.py
@@ -0,0 +1,188 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import re
+
+import pytest
+
+from airflow import AirflowException
+from airflow.models import Connection
+
+
+class TestConnection:
+    @pytest.mark.parametrize(
+        "uri, expected_conn_type, expected_host, expected_login, expected_password,"
+        " expected_port, expected_schema, expected_extra_dict, expected_exception_message",
+        [
+            (
+                "type://user:pass@host:100/schema",
+                "type",
+                "host",
+                "user",
+                "pass",
+                100,
+                "schema",
+                {},
+                None,
+            ),
+            (
+                "type://user:pass@host/schema",
+                "type",
+                "host",
+                "user",
+                "pass",
+                None,
+                "schema",
+                {},
+                None,
+            ),
+            (
+                "type://user:pass@host/schema?param1=val1&param2=val2",
+                "type",
+                "host",
+                "user",
+                "pass",
+                None,
+                "schema",
+                {"param1": "val1", "param2": "val2"},
+                None,
+            ),
+            (
+                "type://host",
+                "type",
+                "host",
+                None,
+                None,
+                None,
+                "",
+                {},
+                None,
+            ),
+            (
+                "spark://mysparkcluster.com:80?deploy-mode=cluster&spark_binary=command&namespace=kube+namespace",
+                "spark",
+                "mysparkcluster.com",
+                None,
+                None,
+                80,
+                "",
+                {"deploy-mode": "cluster", "spark_binary": "command", "namespace": "kube namespace"},
+                None,
+            ),
+            (
+                "spark://k8s://100.68.0.1:443?deploy-mode=cluster",
+                "spark",
+                "k8s://100.68.0.1",
+                None,
+                None,
+                443,
+                "",
+                {"deploy-mode": "cluster"},
+                None,
+            ),
+            (
+                "type://protocol://user:pass@host:123?param=value",
+                "type",
+                "protocol://host",
+                "user",
+                "pass",
+                123,
+                "",
+                {"param": "value"},
+                None,
+            ),
+            (
+                "type://user:pass@protocol://host:port?param=value",
+                None,
+                None,
+                None,
+                None,
+                None,
+                None,
+                None,
+                r"Invalid connection string: type://user:pass@protocol://host:port?param=value.",
+            ),
+        ],
+    )
+    def test_parse_from_uri(
+        self,
+        uri,
+        expected_conn_type,
+        expected_host,
+        expected_login,
+        expected_password,
+        expected_port,
+        expected_schema,
+        expected_extra_dict,
+        expected_exception_message,
+    ):
+        if expected_exception_message is not None:
+            with pytest.raises(AirflowException, match=re.escape(expected_exception_message)):
+                Connection(uri=uri)
+        else:
+            conn = Connection(uri=uri)
+            assert conn.conn_type == expected_conn_type
+            assert conn.login == expected_login
+            assert conn.password == expected_password
+            assert conn.host == expected_host
+            assert conn.port == expected_port
+            assert conn.schema == expected_schema
+            assert conn.extra_dejson == expected_extra_dict
+
+    @pytest.mark.parametrize(
+        "connection, expected_uri",
+        [
+            (
+                Connection(
+                    conn_type="type",
+                    login="user",
+                    password="pass",
+                    host="host",
+                    port=100,
+                    schema="schema",
+                    extra={"param1": "val1", "param2": "val2"},
+                ),
+                "type://user:pass@host:100/schema?param1=val1&param2=val2",
+            ),
+            (
+                Connection(
+                    conn_type="type",
+                    host="protocol://host",
+                    port=100,
+                    schema="schema",
+                    extra={"param1": "val1", "param2": "val2"},
+                ),
+                "type://protocol://host:100/schema?param1=val1&param2=val2",
+            ),
+            (
+                Connection(
+                    conn_type="type",
+                    login="user",
+                    password="pass",
+                    host="protocol://host",
+                    port=100,
+                    schema="schema",
+                    extra={"param1": "val1", "param2": "val2"},
+                ),
+                "type://protocol://user:pass@host:100/schema?param1=val1&param2=val2",
+            ),
+        ],
+    )
+    def test_get_uri(self, connection, expected_uri):
+        assert connection.get_uri() == expected_uri