You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/07 20:11:25 UTC

[airflow] branch main updated: Add test connection functionality to `GithubHook` (#24903)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7f749b653c Add test connection functionality to `GithubHook` (#24903)
7f749b653c is described below

commit 7f749b653ce363b1450346b61c7f6c406f72cd66
Author: Josh Fell <48...@users.noreply.github.com>
AuthorDate: Thu Jul 7 16:11:16 2022 -0400

    Add test connection functionality to `GithubHook` (#24903)
---
 airflow/providers/github/hooks/github.py        | 39 +++++++++++++++----------
 tests/providers/github/hooks/test_github.py     | 38 +++++++++++++++++++-----
 tests/providers/github/operators/test_github.py | 10 +++----
 tests/providers/github/sensors/test_github.py   | 10 +++----
 4 files changed, 62 insertions(+), 35 deletions(-)

diff --git a/airflow/providers/github/hooks/github.py b/airflow/providers/github/hooks/github.py
index 07a8566a7f..9a71ef5b38 100644
--- a/airflow/providers/github/hooks/github.py
+++ b/airflow/providers/github/hooks/github.py
@@ -16,17 +16,18 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""This module allows to connect to a Github."""
-from typing import Dict, Optional
+"""This module allows you to connect to GitHub."""
+from typing import Dict, Optional, Tuple
 
 from github import Github as GithubClient
 
+from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 
 
 class GithubHook(BaseHook):
     """
-    Interact with Github.
+    Interact with GitHub.
 
     Performs a connection to GitHub and retrieves client.
 
@@ -36,7 +37,7 @@ class GithubHook(BaseHook):
     conn_name_attr = 'github_conn_id'
     default_conn_name = 'github_default'
     conn_type = 'github'
-    hook_name = 'Github'
+    hook_name = 'GitHub'
 
     def __init__(self, github_conn_id: str = default_conn_name, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
@@ -45,10 +46,7 @@ class GithubHook(BaseHook):
         self.get_conn()
 
     def get_conn(self) -> GithubClient:
-        """
-        Function that initiates a new GitHub connection
-        with token and hostname ( for GitHub Enterprise )
-        """
+        """Function that initiates a new GitHub connection with token and hostname (for GitHub Enterprise)."""
         if self.client is not None:
             return self.client
 
@@ -56,6 +54,12 @@ class GithubHook(BaseHook):
         access_token = conn.password
         host = conn.host
 
+        # Currently the only method of authenticating to GitHub in Airflow is via a token. This is not the
+        # only means available, but raising an exception to enforce this method for now.
+        # TODO: When/If other auth methods are implemented this exception should be removed/modified.
+        if not access_token:
+            raise AirflowException("An access token is required to authenticate to GitHub.")
+
         if not host:
             self.client = GithubClient(login_or_token=access_token)
         else:
@@ -68,12 +72,15 @@ class GithubHook(BaseHook):
         """Returns custom field behaviour"""
         return {
             "hidden_fields": ['schema', 'port', 'login', 'extra'],
-            "relabeling": {
-                'host': 'GitHub Enterprise Url (Optional)',
-                'password': 'GitHub Access Token',
-            },
-            "placeholders": {
-                'host': 'https://{hostname}/api/v3 (for GitHub Enterprise Connection)',
-                'password': 'token credentials auth',
-            },
+            "relabeling": {'host': 'GitHub Enterprise URL (Optional)', 'password': 'GitHub Access Token'},
+            "placeholders": {'host': 'https://{hostname}/api/v3 (for GitHub Enterprise)'},
         }
+
+    def test_connection(self) -> Tuple[bool, str]:
+        """Test GitHub connection."""
+        try:
+            assert self.client  # For mypy union-attr check of Optional[GithubClient].
+            self.client.get_user().id
+            return True, "Successfully connected to GitHub."
+        except Exception as e:
+            return False, str(e)
diff --git a/tests/providers/github/hooks/test_github.py b/tests/providers/github/hooks/test_github.py
index b4feab3183..4bad1d8e8c 100644
--- a/tests/providers/github/hooks/test_github.py
+++ b/tests/providers/github/hooks/test_github.py
@@ -17,9 +17,10 @@
 # under the License.
 #
 
-import unittest
 from unittest.mock import Mock, patch
 
+from github import BadCredentialsException, Github, NamedUser
+
 from airflow.models import Connection
 from airflow.providers.github.hooks.github import GithubHook
 from airflow.utils import db
@@ -27,15 +28,14 @@ from airflow.utils import db
 github_client_mock = Mock(name="github_client_for_test")
 
 
-class TestGithubHook(unittest.TestCase):
-    def setUp(self):
+class TestGithubHook:
+    def setup_class(self):
         db.merge_conn(
             Connection(
-                conn_id='github_default',
+                conn_id="github_default",
                 conn_type='github',
-                host='https://localhost/github/',
-                port=443,
-                extra='{"verify": "False", "project": "AIRFLOW"}',
+                password='my-access-token',
+                host='https://mygithub.com/api/v3',
             )
         )
 
@@ -48,3 +48,27 @@ class TestGithubHook(unittest.TestCase):
         assert github_mock.called
         assert isinstance(github_hook.client, Mock)
         assert github_hook.client.name == github_mock.return_value.name
+
+    def test_connection_success(self):
+        hook = GithubHook()
+        hook.client = Mock(spec=Github)
+        hook.client.get_user.return_value = NamedUser.NamedUser
+
+        status, msg = hook.test_connection()
+
+        assert status is True
+        assert msg == "Successfully connected to GitHub."
+
+    def test_connection_failure(self):
+        hook = GithubHook()
+        hook.client.get_user = Mock(
+            side_effect=BadCredentialsException(
+                status=401,
+                data={"message": "Bad credentials"},
+                headers={},
+            )
+        )
+        status, msg = hook.test_connection()
+
+        assert status is False
+        assert msg == '401 {"message": "Bad credentials"}'
diff --git a/tests/providers/github/operators/test_github.py b/tests/providers/github/operators/test_github.py
index 23461cbbdf..8b3b626321 100644
--- a/tests/providers/github/operators/test_github.py
+++ b/tests/providers/github/operators/test_github.py
@@ -17,7 +17,6 @@
 # under the License.
 #
 
-import unittest
 from unittest.mock import Mock, patch
 
 from airflow.models import Connection
@@ -29,8 +28,8 @@ DEFAULT_DATE = timezone.datetime(2017, 1, 1)
 github_client_mock = Mock(name="github_client_for_test")
 
 
-class TestGithubOperator(unittest.TestCase):
-    def setUp(self):
+class TestGithubOperator:
+    def setup_class(self):
         args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
         dag = DAG('test_dag_id', default_args=args)
         self.dag = dag
@@ -38,9 +37,8 @@ class TestGithubOperator(unittest.TestCase):
             Connection(
                 conn_id='github_default',
                 conn_type='github',
-                host='https://localhost/github/',
-                port=443,
-                extra='{"verify": "False", "project": "AIRFLOW"}',
+                password='my-access-token',
+                host='https://mygithub.com/api/v3',
             )
         )
 
diff --git a/tests/providers/github/sensors/test_github.py b/tests/providers/github/sensors/test_github.py
index 71cb0a75ca..14d168fa46 100644
--- a/tests/providers/github/sensors/test_github.py
+++ b/tests/providers/github/sensors/test_github.py
@@ -17,7 +17,6 @@
 # under the License.
 #
 
-import unittest
 from unittest.mock import Mock, patch
 
 from airflow.models import Connection
@@ -29,8 +28,8 @@ DEFAULT_DATE = timezone.datetime(2017, 1, 1)
 github_client_mock = Mock(name="github_client_for_test")
 
 
-class TestGithubSensor(unittest.TestCase):
-    def setUp(self):
+class TestGithubSensor:
+    def setup_class(self):
         args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
         dag = DAG('test_dag_id', default_args=args)
         self.dag = dag
@@ -38,9 +37,8 @@ class TestGithubSensor(unittest.TestCase):
             Connection(
                 conn_id='github_default',
                 conn_type='github',
-                host='https://localhost/github/',
-                port=443,
-                extra='{"verify": "False", "project": "AIRFLOW"}',
+                password='my-access-token',
+                host='https://mygithub.com/api/v3',
             )
         )