You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2021/04/15 11:28:49 UTC

[airflow] 08/08: Import Connection lazily in hooks to avoid cycles (#15361)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit ef876cfe33e723110b5fb9fb527d874eb757ab32
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Wed Apr 14 21:33:00 2021 +0800

    Import Connection lazily in hooks to avoid cycles (#15361)
    
    The current implementation imports Connection at import time, which
    causes a circular import when a model class needs to reference a hook
    class.
    
    By applying this fix, the airflow.hooks package is completely decoupled
    from airflow.models at import time, so model code can reference hooks.
    Hooks, on the other hand, generally don't reference model classes.
    
    Fix #15325.
    
    (cherry picked from commit 75603160848e4199ed368809dfd441dcc5ddbd82)
---
 airflow/hooks/base.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/airflow/hooks/base.py b/airflow/hooks/base.py
index b3c0c11..dee76dc 100644
--- a/airflow/hooks/base.py
+++ b/airflow/hooks/base.py
@@ -18,12 +18,14 @@
 """Base class for all hooks"""
 import logging
 import warnings
-from typing import Any, Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List
 
-from airflow.models.connection import Connection
 from airflow.typing_compat import Protocol
 from airflow.utils.log.logging_mixin import LoggingMixin
 
+if TYPE_CHECKING:
+    from airflow.models.connection import Connection  # Avoid circular imports.
+
 log = logging.getLogger(__name__)
 
 
@@ -37,7 +39,7 @@ class BaseHook(LoggingMixin):
     """
 
     @classmethod
-    def get_connections(cls, conn_id: str) -> List[Connection]:
+    def get_connections(cls, conn_id: str) -> List["Connection"]:
         """
         Get all connections as an iterable, given the connection id.
 
@@ -53,13 +55,15 @@ class BaseHook(LoggingMixin):
         return [cls.get_connection(conn_id)]
 
     @classmethod
-    def get_connection(cls, conn_id: str) -> Connection:
+    def get_connection(cls, conn_id: str) -> "Connection":
         """
         Get connection, given connection id.
 
         :param conn_id: connection id
         :return: connection
         """
+        from airflow.models.connection import Connection
+
         conn = Connection.get_connection_from_secrets(conn_id)
         if conn.host:
             log.info(