You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by hu...@apache.org on 2023/06/19 09:32:48 UTC
[airflow] branch main updated: fix connection uri parsing when the host includes a scheme (#31465)
This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0560881f0e fix connection uri parsing when the host includes a scheme (#31465)
0560881f0e is described below
commit 0560881f0eaef9c583b11e937bf1f79d13e5ac7c
Author: Hussein Awala <hu...@awala.fr>
AuthorDate: Mon Jun 19 11:32:41 2023 +0200
fix connection uri parsing when the host includes a scheme (#31465)
* update _parse_from_uri and get_uri methods, and add tests for connection model
* some fixes from review
---
airflow/models/connection.py | 35 +++++++-
tests/models/test_connection.py | 188 ++++++++++++++++++++++++++++++++++++++++
2 files changed, 220 insertions(+), 3 deletions(-)
diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index a565341209..0bc3ca38d4 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -187,10 +187,22 @@ class Connection(Base, LoggingMixin):
return conn_type
def _parse_from_uri(self, uri: str):
+ schemes_count_in_uri = uri.count("://")
+ if schemes_count_in_uri > 2:
+ raise AirflowException(f"Invalid connection string: {uri}.")
+ host_with_protocol = schemes_count_in_uri == 2
uri_parts = urlsplit(uri)
conn_type = uri_parts.scheme
self.conn_type = self._normalize_conn_type(conn_type)
- self.host = _parse_netloc_to_hostname(uri_parts)
+ rest_of_the_url = uri.replace(f"{conn_type}://", ("" if host_with_protocol else "//"))
+ if host_with_protocol:
+ uri_splits = rest_of_the_url.split("://", 1)
+ if "@" in uri_splits[0] or ":" in uri_splits[0]:
+ raise AirflowException(f"Invalid connection string: {uri}.")
+ uri_parts = urlsplit(rest_of_the_url)
+ protocol = uri_parts.scheme if host_with_protocol else None
+ host = _parse_netloc_to_hostname(uri_parts)
+ self.host = self._create_host(protocol, host)
quoted_schema = uri_parts.path[1:]
self.schema = unquote(quoted_schema) if quoted_schema else quoted_schema
self.login = unquote(uri_parts.username) if uri_parts.username else uri_parts.username
@@ -203,6 +215,15 @@ class Connection(Base, LoggingMixin):
else:
self.extra = json.dumps(query)
+ @staticmethod
+ def _create_host(protocol, host) -> str | None:
+ """Returns the connection host with the protocol."""
+ if not host:
+ return host
+ if protocol:
+ return f"{protocol}://{host}"
+ return host
+
def get_uri(self) -> str:
"""Return connection in URI format."""
if self.conn_type and "_" in self.conn_type:
@@ -216,6 +237,14 @@ class Connection(Base, LoggingMixin):
else:
uri = "//"
+ if self.host and "://" in self.host:
+ protocol, host = self.host.split("://", 1)
+ else:
+ protocol, host = None, self.host
+
+ if protocol:
+ uri += f"{protocol}://"
+
authority_block = ""
if self.login is not None:
authority_block += quote(self.login, safe="")
@@ -229,8 +258,8 @@ class Connection(Base, LoggingMixin):
uri += authority_block
host_block = ""
- if self.host:
- host_block += quote(self.host, safe="")
+ if host:
+ host_block += quote(host, safe="")
if self.port:
if host_block == "" and authority_block == "":
diff --git a/tests/models/test_connection.py b/tests/models/test_connection.py
new file mode 100644
index 0000000000..0223cffb8d
--- /dev/null
+++ b/tests/models/test_connection.py
@@ -0,0 +1,188 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import re
+
+import pytest
+
+from airflow import AirflowException
+from airflow.models import Connection
+
+
+class TestConnection:
+ @pytest.mark.parametrize(
+ "uri, expected_conn_type, expected_host, expected_login, expected_password,"
+ " expected_port, expected_schema, expected_extra_dict, expected_exception_message",
+ [
+ (
+ "type://user:pass@host:100/schema",
+ "type",
+ "host",
+ "user",
+ "pass",
+ 100,
+ "schema",
+ {},
+ None,
+ ),
+ (
+ "type://user:pass@host/schema",
+ "type",
+ "host",
+ "user",
+ "pass",
+ None,
+ "schema",
+ {},
+ None,
+ ),
+ (
+ "type://user:pass@host/schema?param1=val1&param2=val2",
+ "type",
+ "host",
+ "user",
+ "pass",
+ None,
+ "schema",
+ {"param1": "val1", "param2": "val2"},
+ None,
+ ),
+ (
+ "type://host",
+ "type",
+ "host",
+ None,
+ None,
+ None,
+ "",
+ {},
+ None,
+ ),
+ (
+ "spark://mysparkcluster.com:80?deploy-mode=cluster&spark_binary=command&namespace=kube+namespace",
+ "spark",
+ "mysparkcluster.com",
+ None,
+ None,
+ 80,
+ "",
+ {"deploy-mode": "cluster", "spark_binary": "command", "namespace": "kube namespace"},
+ None,
+ ),
+ (
+ "spark://k8s://100.68.0.1:443?deploy-mode=cluster",
+ "spark",
+ "k8s://100.68.0.1",
+ None,
+ None,
+ 443,
+ "",
+ {"deploy-mode": "cluster"},
+ None,
+ ),
+ (
+ "type://protocol://user:pass@host:123?param=value",
+ "type",
+ "protocol://host",
+ "user",
+ "pass",
+ 123,
+ "",
+ {"param": "value"},
+ None,
+ ),
+ (
+ "type://user:pass@protocol://host:port?param=value",
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ r"Invalid connection string: type://user:pass@protocol://host:port?param=value.",
+ ),
+ ],
+ )
+ def test_parse_from_uri(
+ self,
+ uri,
+ expected_conn_type,
+ expected_host,
+ expected_login,
+ expected_password,
+ expected_port,
+ expected_schema,
+ expected_extra_dict,
+ expected_exception_message,
+ ):
+ if expected_exception_message is not None:
+ with pytest.raises(AirflowException, match=re.escape(expected_exception_message)):
+ Connection(uri=uri)
+ else:
+ conn = Connection(uri=uri)
+ assert conn.conn_type == expected_conn_type
+ assert conn.login == expected_login
+ assert conn.password == expected_password
+ assert conn.host == expected_host
+ assert conn.port == expected_port
+ assert conn.schema == expected_schema
+ assert conn.extra_dejson == expected_extra_dict
+
+ @pytest.mark.parametrize(
+ "connection, expected_uri",
+ [
+ (
+ Connection(
+ conn_type="type",
+ login="user",
+ password="pass",
+ host="host",
+ port=100,
+ schema="schema",
+ extra={"param1": "val1", "param2": "val2"},
+ ),
+ "type://user:pass@host:100/schema?param1=val1&param2=val2",
+ ),
+ (
+ Connection(
+ conn_type="type",
+ host="protocol://host",
+ port=100,
+ schema="schema",
+ extra={"param1": "val1", "param2": "val2"},
+ ),
+ "type://protocol://host:100/schema?param1=val1&param2=val2",
+ ),
+ (
+ Connection(
+ conn_type="type",
+ login="user",
+ password="pass",
+ host="protocol://host",
+ port=100,
+ schema="schema",
+ extra={"param1": "val1", "param2": "val2"},
+ ),
+ "type://protocol://user:pass@host:100/schema?param1=val1&param2=val2",
+ ),
+ ],
+ )
+ def test_get_uri(self, connection, expected_uri):
+ assert connection.get_uri() == expected_uri