You are viewing a plain-text version of this content; the hyperlink to the canonical version was lost in the conversion.
Posted to commits@airflow.apache.org by "ASF GitHub Bot (JIRA)" <ji...@apache.org> on 2018/12/09 03:55:00 UTC

[jira] [Commented] (AIRFLOW-3390) Add connections to api

    [ https://issues.apache.org/jira/browse/AIRFLOW-3390?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16713849#comment-16713849 ] 

ASF GitHub Bot commented on AIRFLOW-3390:
-----------------------------------------

jmcarp closed pull request #4232: [AIRFLOW-3390] Add connections to api.
URL: https://github.com/apache/incubator-airflow/pull/4232
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/airflow/api/common/experimental/connection.py b/airflow/api/common/experimental/connection.py
new file mode 100644
index 0000000000..3365b1cc6f
--- /dev/null
+++ b/airflow/api/common/experimental/connection.py
@@ -0,0 +1,77 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.exceptions import AirflowBadRequest, ConnectionNotFound
+from airflow.models import Connection
+from airflow.utils.db import provide_session
+
+
+@provide_session
+def get_connection(conn_id, session=None):
+    """Return the Connection for ``conn_id``; raise AirflowBadRequest or ConnectionNotFound."""
+    if not (conn_id and conn_id.strip()):  # rejects None, '' and whitespace-only IDs
+        raise AirflowBadRequest("Connection ID shouldn't be empty")
+
+    connection = session.query(Connection).filter_by(conn_id=conn_id).first()
+    if connection is None:
+        raise ConnectionNotFound("Connection '%s' doesn't exist" % conn_id)
+
+    return connection
+
+
+@provide_session
+def get_connections(session=None):
+    """Return every Connection row as a list; the query applies no ordering."""
+    return session.query(Connection).all()
+
+
+@provide_session
+def create_connection(conn_id, session=None, **kwargs):
+    """Create connection ``conn_id`` from ``kwargs``, or update it in place if it already exists."""
+    if not (conn_id and conn_id.strip()):  # rejects None, '' and whitespace-only IDs
+        raise AirflowBadRequest("Connection ID shouldn't be empty")
+
+    session.expire_on_commit = False  # keep the returned object's attributes loaded after commit
+    connection = session.query(Connection).filter_by(conn_id=conn_id).first()
+    if connection is None:
+        connection = Connection(conn_id=conn_id, **kwargs)  # NOTE(review): kwargs are unvalidated
+        session.add(connection)
+    else:
+        for key, value in kwargs.items():  # overwrite only the fields supplied
+            setattr(connection, key, value)  # NOTE(review): unknown keys may be set silently
+
+    session.commit()
+
+    return connection
+
+
+@provide_session
+def delete_connection(conn_id, session=None):
+    """Delete connection ``conn_id`` and return the removed Connection object."""
+    if not (conn_id and conn_id.strip()):  # rejects None, '' and whitespace-only IDs
+        raise AirflowBadRequest("Connection ID shouldn't be empty")
+
+    connection = session.query(Connection).filter_by(conn_id=conn_id).first()
+    if connection is None:
+        raise ConnectionNotFound("Connection '%s' doesn't exist" % conn_id)
+
+    session.delete(connection)
+    session.commit()
+
+    return connection  # NOTE(review): expire_on_commit is not forced off here as in create_connection
diff --git a/airflow/api/common/experimental/pool.py b/airflow/api/common/experimental/pool.py
index f125036188..40eb8ecd08 100644
--- a/airflow/api/common/experimental/pool.py
+++ b/airflow/api/common/experimental/pool.py
@@ -43,7 +43,7 @@ def get_pools(session=None):
 
 @provide_session
 def create_pool(name, slots, description, session=None):
-    """Create a pool with a given parameters."""
+    """Create a pool with the given parameters."""
     if not (name and name.strip()):
         raise AirflowBadRequest("Pool name shouldn't be empty")
 
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index d4098c4a32..953c7264e6 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -107,3 +107,8 @@ class TaskInstanceNotFound(AirflowNotFoundException):
 class PoolNotFound(AirflowNotFoundException):
     """Raise when a Pool is not available in the system"""
     pass
+
+
+class ConnectionNotFound(AirflowNotFoundException):  # the experimental API tests expect HTTP 404
+    """Raise when no Connection with the requested conn_id exists in the system"""
+    pass
diff --git a/airflow/models.py b/airflow/models.py
index 9ab2348cc2..25f4aa25d2 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -827,6 +827,16 @@ def get_hook(self):
     def __repr__(self):
         return self.conn_id
 
+    def to_json(self):  # summary dict returned by the experimental REST API endpoints
+        return {  # schema, login, password and extra are not included
+            'conn_id': self.conn_id,
+            'conn_type': self.conn_type,
+            'host': self.host,
+            'port': self.port,
+            'is_encrypted': self.is_encrypted,
+            'is_extra_encrypted': self.is_extra_encrypted,
+        }
+
     def debug_info(self):
         return ("id: {}. Host: {}, Port: {}, Schema: {}, "
                 "Login: {}, Password: {}, extra: {}".
diff --git a/airflow/www/api/experimental/endpoints.py b/airflow/www/api/experimental/endpoints.py
index f0bc319eb6..c0dd52c333 100644
--- a/airflow/www/api/experimental/endpoints.py
+++ b/airflow/www/api/experimental/endpoints.py
@@ -23,6 +23,7 @@
 import airflow.api
 from airflow.api.common.experimental import delete_dag as delete
 from airflow.api.common.experimental import pool as pool_api
+from airflow.api.common.experimental import connection as conn_api
 from airflow.api.common.experimental import trigger_dag as trigger
 from airflow.api.common.experimental.get_task import get_task
 from airflow.api.common.experimental.get_task_instance import get_task_instance
@@ -253,3 +254,62 @@ def delete_pool(name):
         return response
     else:
         return jsonify(pool.to_json())
+
+
+@api_experimental.route('/connections/<string:conn_id>', methods=['GET'])
+@requires_authentication
+def get_connection(conn_id):
+    """Return connection ``conn_id`` as JSON, or an error body with the exception's status."""
+    try:
+        conn = conn_api.get_connection(conn_id=conn_id)
+    except AirflowException as err:
+        _log.error(err)
+        response = jsonify(error="{}".format(err))
+        response.status_code = err.status_code
+        return response
+    return jsonify(conn.to_json())
+
+
+@api_experimental.route('/connections', methods=['GET'])
+@requires_authentication
+def get_connections():
+    """Return all connections as a JSON list."""
+    try:
+        connections = conn_api.get_connections()
+    except AirflowException as err:
+        _log.error(err)
+        response = jsonify(error="{}".format(err))
+        response.status_code = err.status_code
+        return response
+    return jsonify([conn.to_json() for conn in connections])
+
+
+@csrf.exempt
+@api_experimental.route('/connections', methods=['POST'])
+@requires_authentication
+def create_connection():
+    """Create or update a connection from the JSON request body and return it as JSON."""
+    params = request.get_json(force=True)  # force: parse as JSON whatever the Content-Type
+    try:
+        conn = conn_api.create_connection(**params)  # NOTE(review): unvalidated JSON keys become kwargs
+    except AirflowException as err:  # TypeError from a malformed body is not caught here
+        _log.error(err)
+        response = jsonify(error="{}".format(err))
+        response.status_code = err.status_code
+        return response
+    return jsonify(conn.to_json())
+
+
+@csrf.exempt
+@api_experimental.route('/connections/<string:conn_id>', methods=['DELETE'])
+@requires_authentication
+def delete_connection(conn_id):
+    """Delete connection ``conn_id`` and return the removed connection as JSON."""
+    try:
+        conn = conn_api.delete_connection(conn_id=conn_id)
+    except AirflowException as err:
+        _log.error(err)
+        response = jsonify(error="{}".format(err))
+        response.status_code = err.status_code
+        return response
+    return jsonify(conn.to_json())
diff --git a/airflow/www_rbac/api/experimental/endpoints.py b/airflow/www_rbac/api/experimental/endpoints.py
index 21aa7e8a1a..53c3572acd 100644
--- a/airflow/www_rbac/api/experimental/endpoints.py
+++ b/airflow/www_rbac/api/experimental/endpoints.py
@@ -19,6 +19,7 @@
 import airflow.api
 
 from airflow.api.common.experimental import pool as pool_api
+from airflow.api.common.experimental import connection as conn_api
 from airflow.api.common.experimental import trigger_dag as trigger
 from airflow.api.common.experimental.get_dag_runs import get_dag_runs
 from airflow.api.common.experimental.get_task import get_task
@@ -319,3 +320,62 @@ def delete_pool(name):
         return response
     else:
         return jsonify(pool.to_json())
+
+
+@api_experimental.route('/connections/<string:conn_id>', methods=['GET'])
+@requires_authentication
+def get_connection(conn_id):
+    """Return connection ``conn_id`` as JSON, or an error body with the exception's status."""
+    try:
+        conn = conn_api.get_connection(conn_id=conn_id)
+    except AirflowException as err:
+        _log.error(err)
+        response = jsonify(error="{}".format(err))
+        response.status_code = err.status_code
+        return response
+    return jsonify(conn.to_json())
+
+
+@api_experimental.route('/connections', methods=['GET'])
+@requires_authentication
+def get_connections():
+    """Return all connections as a JSON list."""
+    try:
+        connections = conn_api.get_connections()
+    except AirflowException as err:
+        _log.error(err)
+        response = jsonify(error="{}".format(err))
+        response.status_code = err.status_code
+        return response
+    return jsonify([conn.to_json() for conn in connections])
+
+
+@csrf.exempt
+@api_experimental.route('/connections', methods=['POST'])
+@requires_authentication
+def create_connection():
+    """Create or update a connection from the JSON request body and return it as JSON."""
+    params = request.get_json(force=True)  # force: parse as JSON whatever the Content-Type
+    try:
+        conn = conn_api.create_connection(**params)  # NOTE(review): unvalidated JSON keys become kwargs
+    except AirflowException as err:  # TypeError from a malformed body is not caught here
+        _log.error(err)
+        response = jsonify(error="{}".format(err))
+        response.status_code = err.status_code
+        return response
+    return jsonify(conn.to_json())
+
+
+@csrf.exempt
+@api_experimental.route('/connections/<string:conn_id>', methods=['DELETE'])
+@requires_authentication
+def delete_connection(conn_id):
+    """Delete connection ``conn_id`` and return the removed connection as JSON."""
+    try:
+        conn = conn_api.delete_connection(conn_id=conn_id)
+    except AirflowException as err:
+        _log.error(err)
+        response = jsonify(error="{}".format(err))
+        response.status_code = err.status_code
+        return response
+    return jsonify(conn.to_json())
diff --git a/tests/www/api/experimental/test_endpoints.py b/tests/www/api/experimental/test_endpoints.py
index 187ca921d5..f68a90a2a0 100644
--- a/tests/www/api/experimental/test_endpoints.py
+++ b/tests/www/api/experimental/test_endpoints.py
@@ -24,7 +24,7 @@
 
 from airflow import configuration
 from airflow.api.common.experimental.trigger_dag import trigger_dag
-from airflow.models import DagBag, DagModel, DagRun, Pool, TaskInstance
+from airflow.models import DagBag, DagModel, DagRun, Pool, Connection, TaskInstance
 from airflow.settings import Session
 from airflow.utils.timezone import datetime, utcnow
 from airflow.www import app as application
@@ -323,5 +323,114 @@ def test_delete_pool_non_existing(self):
                          "Pool 'foo' doesn't exist")
 
 
+class TestConnectionApiExperimental(unittest.TestCase):  # covers /api/experimental/connections
+
+    @classmethod
+    def setUpClass(cls):
+        super(TestConnectionApiExperimental, cls).setUpClass()
+        session = Session()
+        session.query(Connection).delete()  # start from an empty connection table
+        session.commit()
+        session.close()
+
+    def setUp(self):
+        super(TestConnectionApiExperimental, self).setUp()
+        configuration.load_test_config()
+        app = application.create_app(testing=True)
+        self.app = app.test_client()
+        self.session = Session()
+        self.conns = []  # fixtures experimental_1 and experimental_2
+        for i in range(2):
+            conn_id = 'experimental_%s' % (i + 1)
+            conn = Connection(
+                conn_id=conn_id,
+                password=str(i),
+            )
+            self.session.add(conn)
+            self.conns.append(conn)
+        self.session.commit()
+        self.conn = self.conns[0]
+
+    def tearDown(self):
+        self.session.query(Connection).delete()  # NOTE(review): removes every Connection row
+        self.session.commit()
+        self.session.close()
+        super(TestConnectionApiExperimental, self).tearDown()
+
+    def _get_connection_count(self):  # size of the list returned by GET /connections
+        response = self.app.get('/api/experimental/connections')
+        self.assertEqual(response.status_code, 200)
+        return len(json.loads(response.data.decode('utf-8')))
+
+    def test_get_connection(self):
+        response = self.app.get(
+            '/api/experimental/connections/{}'.format(self.conn.conn_id),
+        )
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(json.loads(response.data.decode('utf-8')),
+                         self.conn.to_json())
+
+    def test_get_connection_non_existing(self):
+        response = self.app.get('/api/experimental/connections/foo')
+        self.assertEqual(response.status_code, 404)
+        self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
+                         "Connection 'foo' doesn't exist")
+
+    def test_get_connections(self):
+        response = self.app.get('/api/experimental/connections')
+        self.assertEqual(response.status_code, 200)
+        conns = json.loads(response.data.decode('utf-8'))
+        self.assertEqual(len(conns), 2)
+        for i, connection in enumerate(sorted(conns, key=lambda p: p['conn_id'])):
+            self.assertDictEqual(connection, self.conns[i].to_json())
+
+    def test_create_connection(self):
+        response = self.app.post(
+            '/api/experimental/connections',
+            data=json.dumps({
+                'conn_id': 'foo',
+                'password': '1',
+            }),
+            content_type='application/json',
+        )
+        self.assertEqual(response.status_code, 200)
+        connection = json.loads(response.data.decode('utf-8'))
+        self.assertEqual(connection['conn_id'], 'foo')
+        self.assertEqual(self._get_connection_count(), 3)  # two fixtures plus 'foo'
+
+    def test_create_connection_with_bad_name(self):
+        for name in ('', '    '):  # empty and whitespace-only IDs
+            response = self.app.post(
+                '/api/experimental/connections',
+                data=json.dumps({
+                    'conn_id': name,
+                }),
+                content_type='application/json',
+            )
+            self.assertEqual(response.status_code, 400)
+            self.assertEqual(
+                json.loads(response.data.decode('utf-8'))['error'],
+                "Connection ID shouldn't be empty",
+            )
+        self.assertEqual(self._get_connection_count(), 2)  # no connection was created
+
+    def test_delete_connection(self):
+        response = self.app.delete(
+            '/api/experimental/connections/{}'.format(self.conn.conn_id),
+        )
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(json.loads(response.data.decode('utf-8')),
+                         self.conn.to_json())
+        self.assertEqual(self._get_connection_count(), 1)  # only experimental_2 remains
+
+    def test_delete_connection_non_existing(self):
+        response = self.app.delete(
+            '/api/experimental/connections/foo',
+        )
+        self.assertEqual(response.status_code, 404)
+        self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
+                         "Connection 'foo' doesn't exist")
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/www_rbac/api/experimental/test_endpoints.py b/tests/www_rbac/api/experimental/test_endpoints.py
index 059ae0eabb..9e3d2e627f 100644
--- a/tests/www_rbac/api/experimental/test_endpoints.py
+++ b/tests/www_rbac/api/experimental/test_endpoints.py
@@ -26,7 +26,7 @@
 from airflow import configuration as conf
 from airflow import settings
 from airflow.api.common.experimental.trigger_dag import trigger_dag
-from airflow.models import DagBag, DagRun, Pool, TaskInstance
+from airflow.models import DagBag, DagRun, Pool, Connection, TaskInstance
 from airflow.settings import Session
 from airflow.utils.timezone import datetime, utcnow
 from airflow.www_rbac import app as application
@@ -372,5 +372,110 @@ def test_delete_pool_non_existing(self):
                          "Pool 'foo' doesn't exist")
 
 
+class TestConnectionApiExperimental(TestBase):  # covers /api/experimental/connections (RBAC UI)
+
+    @classmethod
+    def setUpClass(cls):
+        super(TestConnectionApiExperimental, cls).setUpClass()
+        session = Session()
+        session.query(Connection).delete()  # start from an empty connection table
+        session.commit()
+        session.close()
+
+    def setUp(self):
+        super(TestConnectionApiExperimental, self).setUp()
+        self.conns = []  # fixtures experimental_1 and experimental_2
+        for i in range(2):
+            conn_id = 'experimental_%s' % (i + 1)
+            conn = Connection(
+                conn_id=conn_id,
+                password=str(i),
+            )
+            self.session.add(conn)  # presumably self.session is created by TestBase.setUp
+            self.conns.append(conn)
+        self.session.commit()
+        self.conn = self.conns[0]
+
+    def tearDown(self):
+        self.session.query(Connection).delete()  # NOTE(review): removes every Connection row
+        self.session.commit()
+        self.session.close()
+        super(TestConnectionApiExperimental, self).tearDown()
+
+    def _get_connection_count(self):  # size of the list returned by GET /connections
+        response = self.client.get('/api/experimental/connections')
+        self.assertEqual(response.status_code, 200)
+        return len(json.loads(response.data.decode('utf-8')))
+
+    def test_get_connection(self):
+        response = self.client.get(
+            '/api/experimental/connections/{}'.format(self.conn.conn_id),
+        )
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(json.loads(response.data.decode('utf-8')),
+                         self.conn.to_json())
+
+    def test_get_connection_non_existing(self):
+        response = self.client.get('/api/experimental/connections/foo')
+        self.assertEqual(response.status_code, 404)
+        self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
+                         "Connection 'foo' doesn't exist")
+
+    def test_get_connections(self):
+        response = self.client.get('/api/experimental/connections')
+        self.assertEqual(response.status_code, 200)
+        conns = json.loads(response.data.decode('utf-8'))
+        self.assertEqual(len(conns), 2)
+        for i, connection in enumerate(sorted(conns, key=lambda p: p['conn_id'])):
+            self.assertDictEqual(connection, self.conns[i].to_json())
+
+    def test_create_connection(self):
+        response = self.client.post(
+            '/api/experimental/connections',
+            data=json.dumps({
+                'conn_id': 'foo',
+                'password': '1',
+            }),
+            content_type='application/json',
+        )
+        self.assertEqual(response.status_code, 200)
+        connection = json.loads(response.data.decode('utf-8'))
+        self.assertEqual(connection['conn_id'], 'foo')
+        self.assertEqual(self._get_connection_count(), 3)  # two fixtures plus 'foo'
+
+    def test_create_connection_with_bad_name(self):
+        for name in ('', '    '):  # empty and whitespace-only IDs
+            response = self.client.post(
+                '/api/experimental/connections',
+                data=json.dumps({
+                    'conn_id': name,
+                }),
+                content_type='application/json',
+            )
+            self.assertEqual(response.status_code, 400)
+            self.assertEqual(
+                json.loads(response.data.decode('utf-8'))['error'],
+                "Connection ID shouldn't be empty",
+            )
+        self.assertEqual(self._get_connection_count(), 2)  # no connection was created
+
+    def test_delete_connection(self):
+        response = self.client.delete(
+            '/api/experimental/connections/{}'.format(self.conn.conn_id),
+        )
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(json.loads(response.data.decode('utf-8')),
+                         self.conn.to_json())
+        self.assertEqual(self._get_connection_count(), 1)  # only experimental_2 remains
+
+    def test_delete_connection_non_existing(self):
+        response = self.client.delete(
+            '/api/experimental/connections/foo',
+        )
+        self.assertEqual(response.status_code, 404)
+        self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
+                         "Connection 'foo' doesn't exist")
+
+
 if __name__ == '__main__':
     unittest.main()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


> Add connections to api
> ----------------------
>
>                 Key: AIRFLOW-3390
>                 URL: https://issues.apache.org/jira/browse/AIRFLOW-3390
>             Project: Apache Airflow
>          Issue Type: Improvement
>            Reporter: Josh Carp
>            Priority: Minor
>
> It would be useful to be able to read and write connections via the api.



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)