You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original page) was not preserved in this extract.
Posted to commits@airflow.apache.org by po...@apache.org on 2023/08/03 17:29:20 UTC

[airflow] branch main updated: Validate database URL passed to create_engine of Drill hook's connection (#33074)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 394a727ac2 Validate database URL passed to create_engine of Drill hook's connection (#33074)
394a727ac2 is described below

commit 394a727ac2c18d58978bf186a7a92923460ec110
Author: Pankaj Koti <pa...@gmail.com>
AuthorDate: Thu Aug 3 22:59:13 2023 +0530

    Validate database URL passed to create_engine of Drill hook's connection (#33074)
    
    The database URL passed as an argument to create_engine should
    not contain query parameters; they are not intended to be passed this way.
---
 airflow/providers/apache/drill/hooks/drill.py    | 19 +++++++++++++------
 tests/providers/apache/drill/hooks/test_drill.py |  4 +---
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/airflow/providers/apache/drill/hooks/drill.py b/airflow/providers/apache/drill/hooks/drill.py
index ab15ba6b66..2f2bfa3273 100644
--- a/airflow/providers/apache/drill/hooks/drill.py
+++ b/airflow/providers/apache/drill/hooks/drill.py
@@ -49,13 +49,14 @@ class DrillHook(DbApiHook):
         """Establish a connection to Drillbit."""
         conn_md = self.get_connection(getattr(self, self.conn_name_attr))
         creds = f"{conn_md.login}:{conn_md.password}@" if conn_md.login else ""
-        if "/" in conn_md.host or "&" in conn_md.host:
-            raise ValueError("Drill host should not contain '/&' characters")
-        engine = create_engine(
-            f'{conn_md.extra_dejson.get("dialect_driver", "drill+sadrill")}://{creds}'
+        database_url = (
+            f"{conn_md.extra_dejson.get('dialect_driver', 'drill+sadrill')}://{creds}"
             f"{conn_md.host}:{conn_md.port}/"
             f'{conn_md.extra_dejson.get("storage_plugin", "dfs")}'
         )
+        if "?" in database_url:
+            raise ValueError("Drill database_url should not contain a '?'")
+        engine = create_engine(database_url)
 
         self.log.info(
             "Connected to the Drillbit at %s:%s as user %s", conn_md.host, conn_md.port, conn_md.login
@@ -77,10 +78,16 @@ class DrillHook(DbApiHook):
         storage_plugin = conn_md.extra_dejson.get("storage_plugin", "dfs")
         return f"{conn_type}://{host}/{storage_plugin}?dialect_driver={dialect_driver}"
 
-    def set_autocommit(self, conn: Connection, autocommit: bool) -> NotImplementedError:
+    # The superclass DbApiHook's method implementation has a return type `None` and mypy fails saying
+    # return type `NotImplementedError` is incompatible with it. Hence, we ignore the mypy error here.
+    def set_autocommit(  # type: ignore[override]
+        self, conn: Connection, autocommit: bool
+    ) -> NotImplementedError:
         raise NotImplementedError("There are no transactions in Drill.")
 
-    def insert_rows(
+    # The superclass DbApiHook's method implementation has a return type `None` and mypy fails saying
+    # return type `NotImplementedError` is incompatible with it. Hence, we ignore the mypy error here.
+    def insert_rows(  # type: ignore[override]
         self,
         table: str,
         rows: Iterable[tuple[str]],
diff --git a/tests/providers/apache/drill/hooks/test_drill.py b/tests/providers/apache/drill/hooks/test_drill.py
index 241f50fce5..bfedffd3d7 100644
--- a/tests/providers/apache/drill/hooks/test_drill.py
+++ b/tests/providers/apache/drill/hooks/test_drill.py
@@ -24,9 +24,7 @@ import pytest
 from airflow.providers.apache.drill.hooks.drill import DrillHook
 
 
-@pytest.mark.parametrize(
-    "host, expect_error", [("host_with/", True), ("host_with&", True), ("good_host", False)]
-)
+@pytest.mark.parametrize("host, expect_error", [("host_with?", True), ("good_host", False)])
 def test_get_host(host, expect_error):
     with patch(
         "airflow.providers.apache.drill.hooks.drill.DrillHook.get_connection"