You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2019/01/07 18:23:20 UTC

[GitHub] morgendave closed pull request #4101: [AIRFLOW-3272] Add base grpc hook

morgendave closed pull request #4101: [AIRFLOW-3272] Add base grpc hook
URL: https://github.com/apache/airflow/pull/4101
 
 
   

This pull request was merged from a forked repository. Because GitHub
hides the original diff of such a foreign pull request once it is merged,
the diff is reproduced below for the sake of provenance:

diff --git a/airflow/contrib/hooks/grpc_hook.py b/airflow/contrib/hooks/grpc_hook.py
new file mode 100644
index 0000000000..46a9a0ca7e
--- /dev/null
+++ b/airflow/contrib/hooks/grpc_hook.py
@@ -0,0 +1,120 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import grpc
+from google import auth as google_auth
+from google.auth import jwt as google_auth_jwt
+from google.auth.transport import grpc as google_auth_transport_grpc
+from google.auth.transport import requests as google_auth_transport_requests
+
+from airflow.hooks.base_hook import BaseHook
+from airflow.exceptions import AirflowConfigException
+
+
+class GrpcHook(BaseHook):
+    """
+    General interaction with gRPC servers.
+    :param grpc_conn_id: The connection ID to use when fetching connection info.
+    :type grpc_conn_id: str
+    :param interceptors: a list of gRPC interceptor objects which would be applied
+        to the connected gRPC channel. None by default.
+    :type interceptors: a list of gRPC interceptors implementing or extending one of
+        the four official gRPC interceptor classes: UnaryUnaryClientInterceptor,
+        UnaryStreamClientInterceptor, StreamUnaryClientInterceptor, StreamStreamClientInterceptor.
+    :param custom_connection_func: The customized connection function that returns the gRPC channel.
+    :type custom_connection_func: a Python callable that accepts the connection as
+        its only argument, e.g. a functools.partial object or a lambda.
+    """
+
+    def __init__(self, grpc_conn_id, interceptors=None, custom_connection_func=None):
+        self.grpc_conn_id = grpc_conn_id
+        self.conn = self.get_connection(self.grpc_conn_id)
+        self.extras = self.conn.extra_dejson
+        self.interceptors = interceptors if interceptors else []
+        self.custom_connection_func = custom_connection_func
+
+    def get_conn(self):
+        base_url = self.conn.host
+
+        if self.conn.port:
+            base_url = base_url + ":" + str(self.conn.port)
+
+        auth_type = self._get_field("auth_type")
+
+        if auth_type == "NO_AUTH":
+            channel = grpc.insecure_channel(base_url)
+        elif auth_type == "SSL" or auth_type == "TLS":
+            credential_file_name = self._get_field("credential_pem_file")
+            creds = grpc.ssl_channel_credentials(open(credential_file_name).read())
+            channel = grpc.secure_channel(base_url, creds)
+        elif auth_type == "JWT_GOOGLE":
+            credentials, _ = google_auth.default()
+            jwt_creds = google_auth_jwt.OnDemandCredentials.from_signing_credentials(
+                credentials)
+            channel = google_auth_transport_grpc.secure_authorized_channel(
+                jwt_creds, None, base_url)
+        elif auth_type == "OATH_GOOGLE":
+            scopes = self._get_field("scopes").split(",")
+            credentials, _ = google_auth.default(scopes=scopes)
+            request = google_auth_transport_requests.Request()
+            channel = google_auth_transport_grpc.secure_authorized_channel(
+                credentials, request, base_url)
+        elif auth_type == "CUSTOM":
+            if not self.custom_connection_func:
+                raise AirflowConfigException(
+                    "Customized connection function not set, not able to establish a channel")
+            channel = self.custom_connection_func(self.conn)
+        else:
+            raise AirflowConfigException(
+                "auth_type not supported or not provided, channel cannot be established,\
+                given value: %s" % str(auth_type))
+
+        if self.interceptors:
+            for interceptor in self.interceptors:
+                channel = grpc.intercept_channel(channel,
+                                                 interceptor)
+
+        return channel
+
+    def run(self, stub_class, call_func, streaming=False, data={}):
+        with self.get_conn() as channel:
+            stub = stub_class(channel)
+            try:
+                rpc_func = getattr(stub, call_func)
+                response = rpc_func(**data)
+                if not streaming:
+                    yield response
+                else:
+                    for single_response in response:
+                        yield single_response
+            except grpc.RpcError as ex:
+                self.log.exception(
+                    "Error occured when calling the grpc service: {0}, method: {1} \
+                    status code: {2}, error details: {3}"
+                    .format(stub.__class__.__name__, call_func, ex.code(), ex.details()))
+                raise ex
+
+    def _get_field(self, field_name, default=None):
+        """
+        Fetches a field from the connection extras and returns it, or ``default``
+        if it is absent. The grpc connection type adds custom UI elements to the
+        connection page, which allow admins to specify scopes, credential pem files, etc.
+        These are stored in extras under keys of the form ``extra__grpc__<field_name>``.
+        """
+        full_field_name = 'extra__grpc__{}'.format(field_name)
+        if full_field_name in self.extras:
+            return self.extras[full_field_name]
+        else:
+            return default
diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index ff63560020..cf7b159257 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -89,6 +89,7 @@ class Connection(Base, LoggingMixin):
         ('qubole', 'Qubole'),
         ('mongo', 'MongoDB'),
         ('gcpcloudsql', 'Google Cloud SQL'),
+        ('grpc', 'GRPC Connection'),
     ]
 
     def __init__(
@@ -238,6 +239,9 @@ def get_hook(self):
             elif self.conn_type == 'gcpcloudsql':
                 from airflow.contrib.hooks.gcp_sql_hook import CloudSqlDatabaseHook
                 return CloudSqlDatabaseHook(gcp_cloudsql_conn_id=self.conn_id)
+            elif self.conn_type == 'grpc':
+                from airflow.contrib.hooks.grpc_hook import GrpcHook
+                return GrpcHook(grpc_conn_id=self.conn_id)
         except Exception:
             pass
 
diff --git a/airflow/www/views.py b/airflow/www/views.py
index bc8e392b3a..b54fdb8562 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -2968,6 +2968,9 @@ class ConnectionModelView(wwwutils.SuperUserMixin, AirflowModelView):
         'extra__google_cloud_platform__key_path',
         'extra__google_cloud_platform__keyfile_dict',
         'extra__google_cloud_platform__scope',
+        'extra__grpc__auth_type',
+        'extra__grpc__credential_pem_file',
+        'extra__grpc__scopes',
     )
     verbose_name = "Connection"
     verbose_name_plural = "Connections"
@@ -2990,6 +2993,9 @@ class ConnectionModelView(wwwutils.SuperUserMixin, AirflowModelView):
         'extra__google_cloud_platform__key_path': StringField('Keyfile Path'),
         'extra__google_cloud_platform__keyfile_dict': PasswordField('Keyfile JSON'),
         'extra__google_cloud_platform__scope': StringField('Scopes (comma separated)'),
+        'extra__grpc__auth_type': StringField('Authentication Type'),
+        'extra__grpc__credential_pem_file': StringField('Credential Pem File Path'),
+        'extra__grpc__scopes': StringField('Scopes (comma separated)'),
     }
     form_choices = {
         'conn_type': Connection._types
@@ -2997,7 +3003,7 @@ class ConnectionModelView(wwwutils.SuperUserMixin, AirflowModelView):
 
     def on_model_change(self, form, model, is_created):
         formdata = form.data
-        if formdata['conn_type'] in ['jdbc', 'google_cloud_platform']:
+        if formdata['conn_type'] in ['jdbc', 'google_cloud_platform', 'grpc']:
             extra = {
                 key: formdata[key]
                 for key in self.form_extra_fields.keys() if key in formdata}
diff --git a/airflow/www_rbac/views.py b/airflow/www_rbac/views.py
index 3177a5540b..fb5515c43d 100644
--- a/airflow/www_rbac/views.py
+++ b/airflow/www_rbac/views.py
@@ -1918,7 +1918,10 @@ class ConnectionModelView(AirflowModelView):
                     'extra__google_cloud_platform__project',
                     'extra__google_cloud_platform__key_path',
                     'extra__google_cloud_platform__keyfile_dict',
-                    'extra__google_cloud_platform__scope']
+                    'extra__google_cloud_platform__scope',
+                    'extra__grpc__auth_type',
+                    'extra__grpc__credential_pem_file',
+                    'extra__grpc__scopes']
     list_columns = ['conn_id', 'conn_type', 'host', 'port', 'is_encrypted',
                     'is_extra_encrypted']
     add_columns = edit_columns = ['conn_id', 'conn_type', 'host', 'schema',
@@ -1939,7 +1942,7 @@ def action_muldelete(self, items):
 
     def process_form(self, form, is_created):
         formdata = form.data
-        if formdata['conn_type'] in ['jdbc', 'google_cloud_platform']:
+        if formdata['conn_type'] in ['jdbc', 'google_cloud_platform', 'grpc']:
             extra = {
                 key: formdata[key]
                 for key in self.extra_fields if key in formdata}
diff --git a/setup.py b/setup.py
index 6dc452302f..10ae4e4b53 100644
--- a/setup.py
+++ b/setup.py
@@ -313,6 +313,7 @@ def do_setup():
             'funcsigs==1.0.0',
             'future>=0.16.0, <0.17',
             'gitpython>=2.0.2',
+            'grpcio>=1.15.0',
             'gunicorn>=19.4.0, <20.0',
             'iso8601>=0.1.12',
             'json-merge-patch==0.2',
diff --git a/tests/contrib/hooks/test_grpc_hook.py b/tests/contrib/hooks/test_grpc_hook.py
new file mode 100644
index 0000000000..0fe3fc4d28
--- /dev/null
+++ b/tests/contrib/hooks/test_grpc_hook.py
@@ -0,0 +1,312 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+import StringIO
+
+from airflow import configuration
+from airflow.exceptions import AirflowConfigException
+from airflow.contrib.hooks.grpc_hook import GrpcHook
+from airflow.models.connection import Connection
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+
+def get_airflow_connection(auth_type="NO_AUTH", credential_pem_file=None, scopes=None):
+    extra = \
+        '{{"extra__grpc__auth_type": "{auth_type}",' \
+        '"extra__grpc__credential_pem_file": "{credential_pem_file}",' \
+        '"extra__grpc__scopes": "{scopes}"}}' \
+        .format(auth_type=auth_type,
+                credential_pem_file=credential_pem_file,
+                scopes=scopes)
+
+    return Connection(
+        conn_id='grpc_default',
+        conn_type='grpc',
+        host='test:8080',
+        extra=extra
+    )
+
+
+def get_airflow_connection_with_port():
+    return Connection(
+        conn_id='grpc_default',
+        conn_type='grpc',
+        host='test.com',
+        port=1234,
+        extra='{"extra__grpc__auth_type": "NO_AUTH"}'
+    )
+
+
+class StubClass(object):
+    def __init__(self, channel):
+        pass
+
+    def single_call(self, data):
+        return data
+
+    def stream_call(self, data):
+        return ["streaming", "call"]
+
+
+class TestGrpcHook(unittest.TestCase):
+    def setUp(self):
+        configuration.load_test_config()
+        self.channel_mock = mock.patch('grpc.Channel').start()
+
+    def custom_conn_func(self, connection):
+        mocked_channel = self.channel_mock.return_value
+        return mocked_channel
+
+    @mock.patch('grpc.insecure_channel')
+    @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
+    def test_no_auth_connection(self, mock_get_connection, mock_insecure_channel):
+        conn = get_airflow_connection()
+        mock_get_connection.return_value = conn
+        hook = GrpcHook("grpc_default")
+        mocked_channel = self.channel_mock.return_value
+        mock_insecure_channel.return_value = mocked_channel
+
+        channel = hook.get_conn()
+        expected_url = "test:8080"
+
+        mock_insecure_channel.assert_called_once_with(expected_url)
+        self.assertEquals(channel, mocked_channel)
+
+    @mock.patch('grpc.insecure_channel')
+    @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
+    def test_connection_with_port(self, mock_get_connection, mock_insecure_channel):
+        conn = get_airflow_connection_with_port()
+        mock_get_connection.return_value = conn
+        hook = GrpcHook("grpc_default")
+        mocked_channel = self.channel_mock.return_value
+        mock_insecure_channel.return_value = mocked_channel
+
+        channel = hook.get_conn()
+        expected_url = "test.com:1234"
+
+        mock_insecure_channel.assert_called_once_with(expected_url)
+        self.assertEquals(channel, mocked_channel)
+
+    @mock.patch('airflow.contrib.hooks.grpc_hook.open')
+    @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
+    @mock.patch('grpc.ssl_channel_credentials')
+    @mock.patch('grpc.secure_channel')
+    def test_connection_with_ssl(self,
+                                 mock_secure_channel,
+                                 mock_channel_credentials,
+                                 mock_get_connection,
+                                 mock_open):
+        conn = get_airflow_connection(
+            auth_type="SSL",
+            credential_pem_file="pem"
+        )
+        mock_get_connection.return_value = conn
+        mock_open.return_value = StringIO.StringIO('credential')
+        hook = GrpcHook("grpc_default")
+        mocked_channel = self.channel_mock.return_value
+        mock_secure_channel.return_value = mocked_channel
+        mock_credential_object = "test_credential_object"
+        mock_channel_credentials.return_value = mock_credential_object
+
+        channel = hook.get_conn()
+        expected_url = "test:8080"
+
+        mock_open.assert_called_once_with("pem")
+        mock_channel_credentials.assert_called_once_with('credential')
+        mock_secure_channel.assert_called_once_with(
+            expected_url,
+            mock_credential_object
+        )
+        self.assertEquals(channel, mocked_channel)
+
+    @mock.patch('airflow.contrib.hooks.grpc_hook.open')
+    @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
+    @mock.patch('grpc.ssl_channel_credentials')
+    @mock.patch('grpc.secure_channel')
+    def test_connection_with_tls(self,
+                                 mock_secure_channel,
+                                 mock_channel_credentials,
+                                 mock_get_connection,
+                                 mock_open):
+        conn = get_airflow_connection(
+            auth_type="TLS",
+            credential_pem_file="pem"
+        )
+        mock_get_connection.return_value = conn
+        mock_open.return_value = StringIO.StringIO('credential')
+        hook = GrpcHook("grpc_default")
+        mocked_channel = self.channel_mock.return_value
+        mock_secure_channel.return_value = mocked_channel
+        mock_credential_object = "test_credential_object"
+        mock_channel_credentials.return_value = mock_credential_object
+
+        channel = hook.get_conn()
+        expected_url = "test:8080"
+
+        mock_open.assert_called_once_with("pem")
+        mock_channel_credentials.assert_called_once_with('credential')
+        mock_secure_channel.assert_called_once_with(
+            expected_url,
+            mock_credential_object
+        )
+        self.assertEquals(channel, mocked_channel)
+    
+    @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
+    @mock.patch('google.auth.jwt.OnDemandCredentials.from_signing_credentials')
+    @mock.patch('google.auth.default')
+    @mock.patch('google.auth.transport.grpc.secure_authorized_channel')
+    def test_connection_with_jwt(self,
+                                 mock_secure_channel,
+                                 mock_google_default_auth,
+                                 mock_google_cred,
+                                 mock_get_connection):
+        conn = get_airflow_connection(
+            auth_type="JWT_GOOGLE"
+        )
+        mock_get_connection.return_value = conn
+        hook = GrpcHook("grpc_default")
+        mocked_channel = self.channel_mock.return_value
+        mock_secure_channel.return_value = mocked_channel
+        mock_credential_object = "test_credential_object"
+        mock_google_default_auth.return_value = (mock_credential_object, "")
+        mock_google_cred.return_value = mock_credential_object
+
+        channel = hook.get_conn()
+        expected_url = "test:8080"
+
+        mock_google_cred.assert_called_once_with(mock_credential_object)
+        mock_secure_channel.assert_called_once_with(
+            mock_credential_object,
+            None,
+            expected_url
+        )
+        self.assertEquals(channel, mocked_channel)
+    
+    @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
+    @mock.patch('google.auth.transport.requests.Request')
+    @mock.patch('google.auth.default')
+    @mock.patch('google.auth.transport.grpc.secure_authorized_channel')
+    def test_connection_with_google_oauth(self,
+                                          mock_secure_channel,
+                                          mock_google_default_auth,
+                                          mock_google_auth_request,
+                                          mock_get_connection):
+        conn = get_airflow_connection(
+            auth_type="OATH_GOOGLE",
+            scopes="grpc,gcs"
+        )
+        mock_get_connection.return_value = conn
+        hook = GrpcHook("grpc_default")
+        mocked_channel = self.channel_mock.return_value
+        mock_secure_channel.return_value = mocked_channel
+        mock_credential_object = "test_credential_object"
+        mock_google_default_auth.return_value = (mock_credential_object, "")
+        mock_google_auth_request.return_value = "request"
+
+        channel = hook.get_conn()
+        expected_url = "test:8080"
+
+        mock_google_default_auth.assert_called_once_with(scopes=[u"grpc", u"gcs"])
+        mock_secure_channel.assert_called_once_with(
+            mock_credential_object,
+            "request",
+            expected_url
+        )
+        self.assertEquals(channel, mocked_channel)
+
+    @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
+    def test_custom_connection(self, mock_get_connection):
+        conn = get_airflow_connection("CUSTOM")
+        mock_get_connection.return_value = conn
+        mocked_channel = self.channel_mock.return_value
+        hook = GrpcHook("grpc_default", custom_connection_func=self.custom_conn_func)
+
+        channel = hook.get_conn()
+
+        self.assertEquals(channel, mocked_channel)
+
+    @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
+    def test_custom_connection_with_no_connection_func(self, mock_get_connection):
+        conn = get_airflow_connection("CUSTOM")
+        mock_get_connection.return_value = conn
+        hook = GrpcHook("grpc_default")
+
+        with self.assertRaises(AirflowConfigException):
+            channel = hook.get_conn()
+
+    @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
+    def test_connection_type_not_supported(self, mock_get_connection):
+        conn = get_airflow_connection("NOT_SUPPORT")
+        mock_get_connection.return_value = conn
+        hook = GrpcHook("grpc_default")
+
+        with self.assertRaises(AirflowConfigException):
+            channel = hook.get_conn()
+
+    @mock.patch('grpc.intercept_channel')
+    @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
+    @mock.patch('grpc.insecure_channel')
+    def test_connection_with_interceptors(self,
+                                          mock_insecure_channel,
+                                          mock_get_connection,
+                                          mock_intercept_channel):
+        conn = get_airflow_connection()
+        mock_get_connection.return_value = conn
+        mocked_channel = self.channel_mock.return_value
+        hook = GrpcHook("grpc_default", interceptors=["test1"])
+        mock_insecure_channel.return_value = mocked_channel
+        mock_intercept_channel.return_value = mocked_channel
+
+        channel = hook.get_conn()
+
+        self.assertEquals(channel, mocked_channel)
+        mock_intercept_channel.assert_called_once_with(mocked_channel, "test1")
+
+    @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
+    @mock.patch('airflow.contrib.hooks.grpc_hook.GrpcHook.get_conn')
+    def test_simple_run(self, mock_get_conn, mock_get_connection):
+        conn = get_airflow_connection()
+        mock_get_connection.return_value = conn
+        mocked_channel = mock.Mock()
+        mocked_channel.__enter__ = mock.Mock(return_value=(mock.Mock(), None))
+        mocked_channel.__exit__ = mock.Mock(return_value=None)
+        hook = GrpcHook("grpc_default")
+        mock_get_conn.return_value = mocked_channel
+
+        response = hook.run(StubClass, "single_call", data={'data': 'hello'})
+
+        self.assertEquals(response.next(), "hello")
+
+    @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
+    @mock.patch('airflow.contrib.hooks.grpc_hook.GrpcHook.get_conn')
+    def test_stream_run(self, mock_get_conn, mock_get_connection):
+        conn = get_airflow_connection()
+        mock_get_connection.return_value = conn
+        mocked_channel = mock.Mock()
+        mocked_channel.__enter__ = mock.Mock(return_value=(mock.Mock(), None))
+        mocked_channel.__exit__ = mock.Mock(return_value=None)
+        hook = GrpcHook("grpc_default")
+        mock_get_conn.return_value = mocked_channel
+
+        response = hook.run(StubClass, "stream_call", data={'data': ['hello!', "hi"]})
+
+        self.assertEquals(response.next(), ["streaming", "call"])
+ 
\ No newline at end of file


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services