You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by sa...@apache.org on 2016/10/06 03:18:58 UTC

incubator-airflow git commit: [AIRFLOW-358][AIRFLOW-430] Add `connections` cli

Repository: incubator-airflow
Updated Branches:
  refs/heads/master fe5eaabb2 -> f50926425


[AIRFLOW-358][AIRFLOW-430] Add `connections` cli

This PR adds a `connections` command to Airflow's CLI. The new
`connections` command is intended to make it easier to automate
Airflow's deployment to different environments: users won't have to
interact directly with the database or enter connections manually in
the UI.

Closes #1802 from PedroMDuarte/connections-cli


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/f5092642
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/f5092642
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/f5092642

Branch: refs/heads/master
Commit: f50926425389e42c8dd42eb8d6d5eee5ae21abc9
Parents: fe5eaab
Author: Pedro M Duarte <pm...@gmail.com>
Authored: Wed Oct 5 20:19:04 2016 -0700
Committer: Siddharth Anand <si...@yahoo.com>
Committed: Wed Oct 5 20:19:04 2016 -0700

----------------------------------------------------------------------
 airflow/bin/cli.py   | 141 +++++++++++++++++++++++++++++++++-
 airflow/models.py    |  24 ++++++
 airflow/www/views.py |  24 +-----
 tests/core.py        | 188 ++++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 352 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5092642/airflow/bin/cli.py
----------------------------------------------------------------------
diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py
index 66d8a26..324f869 100755
--- a/airflow/bin/cli.py
+++ b/airflow/bin/cli.py
@@ -15,6 +15,7 @@
 
 from __future__ import print_function
 import logging
+import reprlib
 import os
 import subprocess
 import textwrap
@@ -26,6 +27,7 @@ from builtins import input
 from collections import namedtuple
 from dateutil.parser import parse as parsedate
 import json
+from tabulate import tabulate
 
 import daemon
 from daemon.pidfile import TimeoutPIDLockFile
@@ -43,7 +45,7 @@ from airflow.exceptions import AirflowException
 from airflow.executors import DEFAULT_EXECUTOR
 from airflow.models import (DagModel, DagBag, TaskInstance,
                             DagPickle, DagRun, Variable, DagStat,
-                            Pool)
+                            Pool, Connection)
 from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_DEPS)
 from airflow.utils import db as db_utils
 from airflow.utils import logging as logging_utils
@@ -51,6 +53,7 @@ from airflow.utils.state import State
 from airflow.www.app import cached_app
 
 from sqlalchemy import func
+from sqlalchemy.orm import exc
 
 DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
 
@@ -889,10 +892,114 @@ def upgradedb(args):  # noqa
             session.add(DagStat(dag_id=dag_id, state=state, count=count))
         session.commit()
 
+
 def version(args):  # noqa
     print(settings.HEADER + "  v" + airflow.__version__)
 
 
+def connections(args):
+    """List, add or delete Airflow connections from the command line.
+
+    The ``--list``, ``--delete`` and ``--add`` flags are checked in that order
+    and only the first one set is handled; feedback is printed to stdout.
+
+    :param args: parsed CLI namespace providing ``list``, ``delete``, ``add``,
+        ``conn_id``, ``conn_uri`` and ``conn_extra``
+    """
+    if args.list:
+        # --list does not accept any of the connection arguments
+        invalid_args = [arg for arg in ('conn_id', 'conn_uri', 'conn_extra')
+                        if getattr(args, arg) is not None]
+        if invalid_args:
+            print('\n\tThe following args are not compatible with the '
+                  '--list flag: {invalid!r}\n'.format(invalid=invalid_args))
+            return
+
+        session = settings.Session()
+        try:
+            conns = session.query(Connection.conn_id, Connection.conn_type,
+                                  Connection.host, Connection.port,
+                                  Connection.is_encrypted,
+                                  Connection.is_extra_encrypted,
+                                  Connection.extra).all()
+        finally:
+            session.close()
+        # reprlib.repr abbreviates long values so the table stays readable
+        conns = [[reprlib.repr(col) for col in conn] for conn in conns]
+        print(tabulate(conns, ['Conn Id', 'Conn Type', 'Host', 'Port',
+                               'Is Encrypted', 'Is Extra Encrypted', 'Extra'],
+                       tablefmt="fancy_grid"))
+        return
+
+    if args.delete:
+        # --delete accepts --conn_id only
+        invalid_args = [arg for arg in ('conn_uri', 'conn_extra')
+                        if getattr(args, arg) is not None]
+        if invalid_args:
+            print('\n\tThe following args are not compatible with the '
+                  '--delete flag: {invalid!r}\n'.format(invalid=invalid_args))
+            return
+
+        if args.conn_id is None:
+            print('\n\tTo delete a connection, you Must provide a value for '
+                  'the --conn_id flag.\n')
+            return
+
+        session = settings.Session()
+        try:
+            to_delete = (session
+                         .query(Connection)
+                         .filter(Connection.conn_id == args.conn_id)
+                         .one())
+        except exc.NoResultFound:
+            msg = '\n\tDid not find a connection with `conn_id`={conn_id}\n'
+            print(msg.format(conn_id=args.conn_id))
+        except exc.MultipleResultsFound:
+            msg = ('\n\tFound more than one connection with '
+                   '`conn_id`={conn_id}\n')
+            print(msg.format(conn_id=args.conn_id))
+        else:
+            # Capture the id before the row is deleted and committed
+            deleted_conn_id = to_delete.conn_id
+            session.delete(to_delete)
+            session.commit()
+            msg = '\n\tSuccessfully deleted `conn_id`={conn_id}\n'
+            print(msg.format(conn_id=deleted_conn_id))
+        finally:
+            session.close()
+        return
+
+    if args.add:
+        # --add requires both --conn_id and --conn_uri
+        missing_args = [arg for arg in ('conn_id', 'conn_uri')
+                        if getattr(args, arg) is None]
+        if missing_args:
+            print('\n\tThe following args are required to add a connection:'
+                  ' {missing!r}\n'.format(missing=missing_args))
+            return
+
+        new_conn = Connection(conn_id=args.conn_id, uri=args.conn_uri)
+        if args.conn_extra is not None:
+            new_conn.set_extra(args.conn_extra)
+
+        session = settings.Session()
+        try:
+            if not (session
+                    .query(Connection)
+                    .filter(Connection.conn_id == new_conn.conn_id).first()):
+                session.add(new_conn)
+                session.commit()
+                msg = '\n\tSuccessfully added `conn_id`={conn_id} : {uri}\n'
+                print(msg.format(conn_id=new_conn.conn_id, uri=args.conn_uri))
+            else:
+                msg = ('\n\tA connection with `conn_id`={conn_id} '
+                       'already exists\n')
+                print(msg.format(conn_id=new_conn.conn_id))
+        finally:
+            session.close()
+        return
+
+
 def flower(args):
     broka = conf.get('celery', 'BROKER_URL')
     address = '--address={}'.format(args.hostname)
@@ -1240,6 +1347,31 @@ class CLIFactory(object):
         'task_params': Arg(
             ("-tp", "--task_params"),
             help="Sends a JSON params dict to the task"),
+        # connections
+        'list_connections': Arg(
+            ('-l', '--list'),
+            help='List all connections',
+            action='store_true'),
+        'add_connection': Arg(
+            ('-a', '--add'),
+            help='Add a connection',
+            action='store_true'),
+        'delete_connection': Arg(
+            ('-d', '--delete'),
+            help='Delete a connection',
+            action='store_true'),
+        'conn_id': Arg(
+            ('--conn_id',),
+            help='Connection id, required to add/delete a connection',
+            type=str),
+        'conn_uri': Arg(
+            ('--conn_uri',),
+            help='Connection URI, required to add a connection',
+            type=str),
+        'conn_extra': Arg(
+            ('--conn_extra',),
+            help='Connection `Extra` field, optional when adding a connection',
+            type=str),
     }
     subparsers = (
         {
@@ -1348,7 +1480,7 @@ class CLIFactory(object):
             'func': upgradedb,
             'help': "Upgrade the metadata database to latest version",
             'args': tuple(),
-        }, {
+        },{
             'func': scheduler,
             'help': "Start a scheduler instance",
             'args': ('dag_id_opt', 'subdir', 'run_duration', 'num_runs',
@@ -1368,6 +1500,11 @@ class CLIFactory(object):
             'func': version,
             'help': "Show the version",
             'args': tuple(),
+        }, {
+            'func': connections,
+            'help': "List/Add/Delete connections",
+            'args': ('list_connections', 'add_connection', 'delete_connection',
+                     'conn_id', 'conn_uri', 'conn_extra'),
         },
     )
     subparsers_dict = {sp['func'].__name__: sp for sp in subparsers}

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5092642/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index 15bbc30..c6aa56b 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -513,6 +513,30 @@ class Connection(Base):
     is_extra_encrypted = Column(Boolean, unique=False, default=False)
     _extra = Column('extra', String(5000))
 
+    _types = [
+        ('fs', 'File (path)'),
+        ('ftp', 'FTP',),
+        ('google_cloud_platform', 'Google Cloud Platform'),
+        ('hdfs', 'HDFS',),
+        ('http', 'HTTP',),
+        ('hive_cli', 'Hive Client Wrapper',),
+        ('hive_metastore', 'Hive Metastore Thrift',),
+        ('hiveserver2', 'Hive Server 2 Thrift',),
+        ('jdbc', 'Jdbc Connection',),
+        ('mysql', 'MySQL',),
+        ('postgres', 'Postgres',),
+        ('oracle', 'Oracle',),
+        ('vertica', 'Vertica',),
+        ('presto', 'Presto',),
+        ('s3', 'S3',),
+        ('samba', 'Samba',),
+        ('sqlite', 'Sqlite',),
+        ('ssh', 'SSH',),
+        ('cloudant', 'IBM Cloudant',),
+        ('mssql', 'Microsoft SQL Server'),
+        ('mesos_framework-id', 'Mesos Framework ID'),
+    ]
+
     def __init__(
             self, conn_id=None, conn_type=None,
             host=None, login=None, password=None,

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5092642/airflow/www/views.py
----------------------------------------------------------------------
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 1370a06..3614e45 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -2422,29 +2422,7 @@ class ConnectionModelView(wwwutils.SuperUserMixin, AirflowModelView):
 
     }
     form_choices = {
-        'conn_type': [
-            ('fs', 'File (path)'),
-            ('ftp', 'FTP',),
-            ('google_cloud_platform', 'Google Cloud Platform'),
-            ('hdfs', 'HDFS',),
-            ('http', 'HTTP',),
-            ('hive_cli', 'Hive Client Wrapper',),
-            ('hive_metastore', 'Hive Metastore Thrift',),
-            ('hiveserver2', 'Hive Server 2 Thrift',),
-            ('jdbc', 'Jdbc Connection',),
-            ('mysql', 'MySQL',),
-            ('postgres', 'Postgres',),
-            ('oracle', 'Oracle',),
-            ('vertica', 'Vertica',),
-            ('presto', 'Presto',),
-            ('s3', 'S3',),
-            ('samba', 'Samba',),
-            ('sqlite', 'Sqlite',),
-            ('ssh', 'SSH',),
-            ('cloudant', 'IBM Cloudant',),
-            ('mssql', 'Microsoft SQL Server'),
-            ('mesos_framework-id', 'Mesos Framework ID'),
-        ]
+        'conn_type': models.Connection._types
     }
 
     def on_model_change(self, form, model, is_created):

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5092642/tests/core.py
----------------------------------------------------------------------
diff --git a/tests/core.py b/tests/core.py
index cffdc1f..bafcd7e 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -917,6 +917,194 @@ class CliTests(unittest.TestCase):
     def test_cli_initdb(self):
         cli.initdb(self.parser.parse_args(['initdb']))
 
+    def test_cli_connections_list(self):
+        """Check ``connections --list`` output and its flag validation."""
+        with mock.patch('sys.stdout',
+                        new_callable=six.StringIO) as mock_stdout:
+            cli.connections(self.parser.parse_args(['connections', '--list']))
+            stdout = mock_stdout.getvalue()
+        # Odd output lines carry the grid rows; keep their first two fields
+        conns = [[x.strip("'") for x in re.findall(r"'\w+'", line)[:2]]
+                 for ii, line in enumerate(stdout.split('\n'))
+                 if ii % 2 == 1]
+        conns = [conn for conn in conns if conn]
+
+        # Assert that some of the connections are present in the output:
+        self.assertIn(['aws_default', 'aws'], conns)
+        self.assertIn(['beeline_default', 'beeline'], conns)
+        self.assertIn(['bigquery_default', 'bigquery'], conns)
+        self.assertIn(['emr_default', 'emr'], conns)
+        self.assertIn(['mssql_default', 'mssql'], conns)
+        self.assertIn(['mysql_default', 'mysql'], conns)
+        self.assertIn(['postgres_default', 'postgres'], conns)
+
+        # Attempt to list connections with invalid cli args
+        with mock.patch('sys.stdout',
+                        new_callable=six.StringIO) as mock_stdout:
+            cli.connections(self.parser.parse_args(
+                ['connections', '--list', '--conn_id=fake',
+                 '--conn_uri=fake-uri']))
+            stdout = mock_stdout.getvalue()
+        # Check list attempt stdout
+        lines = [line for line in stdout.split('\n') if line]
+        self.assertListEqual(lines, [
+            ("\tThe following args are not compatible with the " +
+             "--list flag: ['conn_id', 'conn_uri']"),
+        ])
+
+    def test_cli_connections_add_delete(self):
+        """Exercise ``connections --add`` and ``--delete`` end to end."""
+        # Add connections:
+        uri = 'postgresql://airflow:airflow@host:5432/airflow'
+        with mock.patch('sys.stdout',
+                        new_callable=six.StringIO) as mock_stdout:
+            cli.connections(self.parser.parse_args(
+                ['connections', '--add', '--conn_id=new1',
+                 '--conn_uri=%s' % uri]))
+            cli.connections(self.parser.parse_args(
+                ['connections', '-a', '--conn_id=new2',
+                 '--conn_uri=%s' % uri]))
+            cli.connections(self.parser.parse_args(
+                ['connections', '--add', '--conn_id=new3',
+                 '--conn_uri=%s' % uri, '--conn_extra', "{'extra': 'yes'}"]))
+            cli.connections(self.parser.parse_args(
+                ['connections', '-a', '--conn_id=new4',
+                 '--conn_uri=%s' % uri, '--conn_extra', "{'extra': 'yes'}"]))
+            stdout = mock_stdout.getvalue()
+
+        # Check addition stdout
+        lines = [line for line in stdout.split('\n') if line]
+        self.assertListEqual(lines, [
+            ("\tSuccessfully added `conn_id`=new1 : " +
+             "postgresql://airflow:airflow@host:5432/airflow"),
+            ("\tSuccessfully added `conn_id`=new2 : " +
+             "postgresql://airflow:airflow@host:5432/airflow"),
+            ("\tSuccessfully added `conn_id`=new3 : " +
+             "postgresql://airflow:airflow@host:5432/airflow"),
+            ("\tSuccessfully added `conn_id`=new4 : " +
+             "postgresql://airflow:airflow@host:5432/airflow"),
+        ])
+
+        # Attempt to add duplicate
+        with mock.patch('sys.stdout',
+                        new_callable=six.StringIO) as mock_stdout:
+            cli.connections(self.parser.parse_args(
+                ['connections', '--add', '--conn_id=new1',
+                 '--conn_uri=%s' % uri]))
+            stdout = mock_stdout.getvalue()
+
+        # Check stdout for addition attempt
+        lines = [line for line in stdout.split('\n') if line]
+        self.assertListEqual(lines, [
+            "\tA connection with `conn_id`=new1 already exists",
+        ])
+
+        # Attempt to add without providing conn_id
+        with mock.patch('sys.stdout',
+                        new_callable=six.StringIO) as mock_stdout:
+            cli.connections(self.parser.parse_args(
+                ['connections', '--add', '--conn_uri=%s' % uri]))
+            stdout = mock_stdout.getvalue()
+
+        # Check stdout for addition attempt
+        lines = [line for line in stdout.split('\n') if line]
+        self.assertListEqual(lines, [
+            ("\tThe following args are required to add a connection:" +
+             " ['conn_id']"),
+        ])
+
+        # Attempt to add without providing conn_uri
+        with mock.patch('sys.stdout',
+                        new_callable=six.StringIO) as mock_stdout:
+            cli.connections(self.parser.parse_args(
+                ['connections', '--add', '--conn_id=new']))
+            stdout = mock_stdout.getvalue()
+
+        # Check stdout for addition attempt
+        lines = [line for line in stdout.split('\n') if line]
+        self.assertListEqual(lines, [
+            ("\tThe following args are required to add a connection:" +
+             " ['conn_uri']"),
+        ])
+
+        # Prepare to check the added connections
+        session = settings.Session()
+        self.addCleanup(session.close)  # closed even if an assert fails
+        extra = {'new1': None,
+                 'new2': None,
+                 'new3': "{'extra': 'yes'}",
+                 'new4': "{'extra': 'yes'}"}
+
+        # Check additions
+        for conn_id in ['new1', 'new2', 'new3', 'new4']:
+            result = (session
+                      .query(models.Connection)
+                      .filter(models.Connection.conn_id == conn_id)
+                      .first())
+            result = (result.conn_id, result.conn_type, result.host,
+                      result.port, result.get_extra())
+            self.assertEqual(result, (conn_id, 'postgres', 'host', 5432,
+                                      extra[conn_id]))
+
+        # Delete connections
+        with mock.patch('sys.stdout',
+                        new_callable=six.StringIO) as mock_stdout:
+            cli.connections(self.parser.parse_args(
+                ['connections', '--delete', '--conn_id=new1']))
+            cli.connections(self.parser.parse_args(
+                ['connections', '--delete', '--conn_id=new2']))
+            cli.connections(self.parser.parse_args(
+                ['connections', '--delete', '--conn_id=new3']))
+            cli.connections(self.parser.parse_args(
+                ['connections', '--delete', '--conn_id=new4']))
+            stdout = mock_stdout.getvalue()
+
+        # Check deletion stdout
+        lines = [line for line in stdout.split('\n') if line]
+        self.assertListEqual(lines, [
+            "\tSuccessfully deleted `conn_id`=new1",
+            "\tSuccessfully deleted `conn_id`=new2",
+            "\tSuccessfully deleted `conn_id`=new3",
+            "\tSuccessfully deleted `conn_id`=new4"
+        ])
+
+        # Check deletions
+        for conn_id in ['new1', 'new2', 'new3', 'new4']:
+            result = (session
+                      .query(models.Connection)
+                      .filter(models.Connection.conn_id == conn_id)
+                      .first())
+
+            self.assertIsNone(result)
+
+        # Attempt to delete a non-existing connection
+        with mock.patch('sys.stdout',
+                        new_callable=six.StringIO) as mock_stdout:
+            cli.connections(self.parser.parse_args(
+                ['connections', '--delete', '--conn_id=fake']))
+            stdout = mock_stdout.getvalue()
+
+        # Check deletion attempt stdout
+        lines = [line for line in stdout.split('\n') if line]
+        self.assertListEqual(lines, [
+            "\tDid not find a connection with `conn_id`=fake",
+        ])
+
+        # Attempt to delete with invalid cli args
+        with mock.patch('sys.stdout',
+                        new_callable=six.StringIO) as mock_stdout:
+            cli.connections(self.parser.parse_args(
+                ['connections', '--delete', '--conn_id=fake',
+                 '--conn_uri=%s' % uri]))
+            stdout = mock_stdout.getvalue()
+
+        # Check deletion attempt stdout
+        lines = [line for line in stdout.split('\n') if line]
+        self.assertListEqual(lines, [
+            ("\tThe following args are not compatible with the " +
+             "--delete flag: ['conn_uri']"),
+        ])
+
     def test_cli_test(self):
         cli.test(self.parser.parse_args([
             'test', 'example_bash_operator', 'runme_0',