You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by sa...@apache.org on 2016/10/06 03:18:58 UTC
incubator-airflow git commit: [AIRFLOW-358][AIRFLOW-430] Add
`connections` cli
Repository: incubator-airflow
Updated Branches:
refs/heads/master fe5eaabb2 -> f50926425
[AIRFLOW-358][AIRFLOW-430] Add `connections` cli
This PR adds a `connections` command to Airflow's CLI. The new
`connections` command aims to make it easier to automate Airflow's
deployment to different environments: users no longer have to interact
directly with the database or enter connections manually in the UI.
Closes #1802 from PedroMDuarte/connections-cli
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/f5092642
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/f5092642
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/f5092642
Branch: refs/heads/master
Commit: f50926425389e42c8dd42eb8d6d5eee5ae21abc9
Parents: fe5eaab
Author: Pedro M Duarte <pm...@gmail.com>
Authored: Wed Oct 5 20:19:04 2016 -0700
Committer: Siddharth Anand <si...@yahoo.com>
Committed: Wed Oct 5 20:19:04 2016 -0700
----------------------------------------------------------------------
airflow/bin/cli.py | 141 +++++++++++++++++++++++++++++++++-
airflow/models.py | 24 ++++++
airflow/www/views.py | 24 +-----
tests/core.py | 188 ++++++++++++++++++++++++++++++++++++++++++++++
4 files changed, 352 insertions(+), 25 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5092642/airflow/bin/cli.py
----------------------------------------------------------------------
diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py
index 66d8a26..324f869 100755
--- a/airflow/bin/cli.py
+++ b/airflow/bin/cli.py
@@ -15,6 +15,7 @@
from __future__ import print_function
import logging
+import reprlib
import os
import subprocess
import textwrap
@@ -26,6 +27,7 @@ from builtins import input
from collections import namedtuple
from dateutil.parser import parse as parsedate
import json
+from tabulate import tabulate
import daemon
from daemon.pidfile import TimeoutPIDLockFile
@@ -43,7 +45,7 @@ from airflow.exceptions import AirflowException
from airflow.executors import DEFAULT_EXECUTOR
from airflow.models import (DagModel, DagBag, TaskInstance,
DagPickle, DagRun, Variable, DagStat,
- Pool)
+ Pool, Connection)
from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_DEPS)
from airflow.utils import db as db_utils
from airflow.utils import logging as logging_utils
@@ -51,6 +53,7 @@ from airflow.utils.state import State
from airflow.www.app import cached_app
from sqlalchemy import func
+from sqlalchemy.orm import exc
DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
@@ -889,10 +892,114 @@ def upgradedb(args): # noqa
session.add(DagStat(dag_id=dag_id, state=state, count=count))
session.commit()
+
def version(args):  # noqa
    """Print the Airflow header followed by the installed version string.

    :param args: parsed CLI namespace; not read by this command.
    """
    banner = settings.HEADER + " v" + airflow.__version__
    print(banner)
def _supplied_args(args, names):
    """Return the entries of ``names`` whose attribute on ``args`` is set."""
    return [name for name in names if getattr(args, name) is not None]


def _list_connections(args):
    """Print every stored connection as a table (``--list``)."""
    # --list takes no other connection arguments.
    invalid_args = _supplied_args(args, ['conn_id', 'conn_uri', 'conn_extra'])
    if invalid_args:
        msg = ('\n\tThe following args are not compatible with the ' +
               '--list flag: {invalid!r}\n')
        print(msg.format(invalid=invalid_args))
        return

    session = settings.Session()
    try:
        conns = session.query(Connection.conn_id, Connection.conn_type,
                              Connection.host, Connection.port,
                              Connection.is_encrypted,
                              Connection.is_extra_encrypted,
                              Connection.extra).all()
    finally:
        session.close()

    # Build concrete lists (not lazy ``map`` objects) so each row is a real
    # sequence; reprlib.repr abbreviates long values to keep the table narrow.
    rows = [[reprlib.repr(field) for field in conn] for conn in conns]
    print(tabulate(rows, ['Conn Id', 'Conn Type', 'Host', 'Port',
                          'Is Encrypted', 'Is Extra Encrypted', 'Extra'],
                   tablefmt="fancy_grid"))


def _delete_connection(args):
    """Delete the connection named by ``args.conn_id`` (``--delete``)."""
    # Only --conn_id may accompany --delete.
    invalid_args = _supplied_args(args, ['conn_uri', 'conn_extra'])
    if invalid_args:
        msg = ('\n\tThe following args are not compatible with the ' +
               '--delete flag: {invalid!r}\n')
        print(msg.format(invalid=invalid_args))
        return

    if args.conn_id is None:
        print('\n\tTo delete a connection, you must provide a value for ' +
              'the --conn_id flag.\n')
        return

    session = settings.Session()
    try:
        to_delete = (session
                     .query(Connection)
                     .filter(Connection.conn_id == args.conn_id)
                     .one())
    except exc.NoResultFound:
        msg = '\n\tDid not find a connection with `conn_id`={conn_id}\n'
        msg = msg.format(conn_id=args.conn_id)
    except exc.MultipleResultsFound:
        msg = ('\n\tFound more than one connection with ' +
               '`conn_id`={conn_id}\n')
        msg = msg.format(conn_id=args.conn_id)
    else:
        # Read the id before the delete/commit, while the row is still loaded.
        deleted_conn_id = to_delete.conn_id
        session.delete(to_delete)
        session.commit()
        msg = '\n\tSuccessfully deleted `conn_id`={conn_id}\n'
        msg = msg.format(conn_id=deleted_conn_id)
    finally:
        session.close()
    print(msg)


def _add_connection(args):
    """Create a connection from ``args.conn_id`` / ``args.conn_uri`` (``--add``)."""
    missing_args = [name for name in ['conn_id', 'conn_uri']
                    if getattr(args, name) is None]
    if missing_args:
        msg = ('\n\tThe following args are required to add a connection:' +
               ' {missing!r}\n')
        print(msg.format(missing=missing_args))
        return

    new_conn = Connection(conn_id=args.conn_id, uri=args.conn_uri)
    if args.conn_extra is not None:
        new_conn.set_extra(args.conn_extra)
    # Captured up front so the messages below never touch the ORM object
    # after the session has committed or closed.
    conn_id = new_conn.conn_id

    session = settings.Session()
    try:
        existing = (session
                    .query(Connection)
                    .filter(Connection.conn_id == conn_id).first())
        if existing is None:
            session.add(new_conn)
            session.commit()
            # NOTE(review): this echoes the URI verbatim, which may include a
            # password. Left unchanged because the exact output may be relied
            # on by callers/tests -- consider masking credentials.
            msg = '\n\tSuccessfully added `conn_id`={conn_id} : {uri}\n'
            msg = msg.format(conn_id=conn_id, uri=args.conn_uri)
        else:
            msg = '\n\tA connection with `conn_id`={conn_id} already exists\n'
            msg = msg.format(conn_id=conn_id)
    finally:
        session.close()
    print(msg)


def connections(args):
    """List, delete or add Airflow connections from the command line.

    The ``list``, ``delete`` and ``add`` flags are checked in that order and
    the first one set is acted on. Every outcome, including validation
    failures, is reported on stdout; the function always returns None.

    :param args: parsed CLI namespace carrying the boolean ``list``,
        ``delete`` and ``add`` flags plus the ``conn_id``, ``conn_uri`` and
        ``conn_extra`` values (each None when not supplied).
    """
    if args.list:
        _list_connections(args)
    elif args.delete:
        _delete_connection(args)
    elif args.add:
        _add_connection(args)
    else:
        # Previously this case returned silently, giving the user no hint.
        print('\n\tNo action specified: use one of the --list, --add or ' +
              '--delete flags.\n')
+
+
def flower(args):
broka = conf.get('celery', 'BROKER_URL')
address = '--address={}'.format(args.hostname)
@@ -1240,6 +1347,31 @@ class CLIFactory(object):
'task_params': Arg(
("-tp", "--task_params"),
help="Sends a JSON params dict to the task"),
+ # connections
+ 'list_connections': Arg(
+ ('-l', '--list'),
+ help='List all connections',
+ action='store_true'),
+ 'add_connection': Arg(
+ ('-a', '--add'),
+ help='Add a connection',
+ action='store_true'),
+ 'delete_connection': Arg(
+ ('-d', '--delete'),
+ help='Delete a connection',
+ action='store_true'),
+ 'conn_id': Arg(
+ ('--conn_id',),
+ help='Connection id, required to add/delete a connection',
+ type=str),
+ 'conn_uri': Arg(
+ ('--conn_uri',),
+ help='Connection URI, required to add a connection',
+ type=str),
+ 'conn_extra': Arg(
+ ('--conn_extra',),
+ help='Connection `Extra` field, optional when adding a connection',
+ type=str),
}
subparsers = (
{
@@ -1348,7 +1480,7 @@ class CLIFactory(object):
'func': upgradedb,
'help': "Upgrade the metadata database to latest version",
'args': tuple(),
- }, {
+ },{
'func': scheduler,
'help': "Start a scheduler instance",
'args': ('dag_id_opt', 'subdir', 'run_duration', 'num_runs',
@@ -1368,6 +1500,11 @@ class CLIFactory(object):
'func': version,
'help': "Show the version",
'args': tuple(),
+ }, {
+ 'func': connections,
+ 'help': "List/Add/Delete connections",
+ 'args': ('list_connections', 'add_connection', 'delete_connection',
+ 'conn_id', 'conn_uri', 'conn_extra'),
},
)
subparsers_dict = {sp['func'].__name__: sp for sp in subparsers}
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5092642/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index 15bbc30..c6aa56b 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -513,6 +513,30 @@ class Connection(Base):
is_extra_encrypted = Column(Boolean, unique=False, default=False)
_extra = Column('extra', String(5000))
+ _types = [
+ ('fs', 'File (path)'),
+ ('ftp', 'FTP',),
+ ('google_cloud_platform', 'Google Cloud Platform'),
+ ('hdfs', 'HDFS',),
+ ('http', 'HTTP',),
+ ('hive_cli', 'Hive Client Wrapper',),
+ ('hive_metastore', 'Hive Metastore Thrift',),
+ ('hiveserver2', 'Hive Server 2 Thrift',),
+ ('jdbc', 'Jdbc Connection',),
+ ('mysql', 'MySQL',),
+ ('postgres', 'Postgres',),
+ ('oracle', 'Oracle',),
+ ('vertica', 'Vertica',),
+ ('presto', 'Presto',),
+ ('s3', 'S3',),
+ ('samba', 'Samba',),
+ ('sqlite', 'Sqlite',),
+ ('ssh', 'SSH',),
+ ('cloudant', 'IBM Cloudant',),
+ ('mssql', 'Microsoft SQL Server'),
+ ('mesos_framework-id', 'Mesos Framework ID'),
+ ]
+
def __init__(
self, conn_id=None, conn_type=None,
host=None, login=None, password=None,
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5092642/airflow/www/views.py
----------------------------------------------------------------------
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 1370a06..3614e45 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -2422,29 +2422,7 @@ class ConnectionModelView(wwwutils.SuperUserMixin, AirflowModelView):
}
form_choices = {
- 'conn_type': [
- ('fs', 'File (path)'),
- ('ftp', 'FTP',),
- ('google_cloud_platform', 'Google Cloud Platform'),
- ('hdfs', 'HDFS',),
- ('http', 'HTTP',),
- ('hive_cli', 'Hive Client Wrapper',),
- ('hive_metastore', 'Hive Metastore Thrift',),
- ('hiveserver2', 'Hive Server 2 Thrift',),
- ('jdbc', 'Jdbc Connection',),
- ('mysql', 'MySQL',),
- ('postgres', 'Postgres',),
- ('oracle', 'Oracle',),
- ('vertica', 'Vertica',),
- ('presto', 'Presto',),
- ('s3', 'S3',),
- ('samba', 'Samba',),
- ('sqlite', 'Sqlite',),
- ('ssh', 'SSH',),
- ('cloudant', 'IBM Cloudant',),
- ('mssql', 'Microsoft SQL Server'),
- ('mesos_framework-id', 'Mesos Framework ID'),
- ]
+ 'conn_type': models.Connection._types
}
def on_model_change(self, form, model, is_created):
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5092642/tests/core.py
----------------------------------------------------------------------
diff --git a/tests/core.py b/tests/core.py
index cffdc1f..bafcd7e 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -917,6 +917,194 @@ class CliTests(unittest.TestCase):
def test_cli_initdb(self):
    """Run the `initdb` subcommand through the CLI parser.

    There are no assertions; the test passes when the call does not raise.
    """
    parsed_args = self.parser.parse_args(['initdb'])
    cli.initdb(parsed_args)
def test_cli_connections_list(self):
    """Check `connections --list` output and its rejection of extra flags."""
    with mock.patch('sys.stdout',
                    new_callable=six.StringIO) as mock_stdout:
        cli.connections(self.parser.parse_args(['connections', '--list']))
        stdout = mock_stdout.getvalue()

    # Raw string: "\w" in a normal string literal is an invalid escape.
    quoted_word = re.compile(r"'\w+'")
    # Only odd-indexed output lines are inspected (presumably the content
    # rows of the fancy_grid table, with border lines in between); the first
    # two quoted words on each are taken as [conn_id, conn_type].
    conns = [[word.strip("'") for word in quoted_word.findall(line)[:2]]
             for index, line in enumerate(stdout.split('\n'))
             if index % 2 == 1]
    conns = [conn for conn in conns if len(conn) > 0]

    # Assert that some of the connections are present in the output as
    # expected:
    self.assertIn(['aws_default', 'aws'], conns)
    self.assertIn(['beeline_default', 'beeline'], conns)
    self.assertIn(['bigquery_default', 'bigquery'], conns)
    self.assertIn(['emr_default', 'emr'], conns)
    self.assertIn(['mssql_default', 'mssql'], conns)
    self.assertIn(['mysql_default', 'mysql'], conns)
    self.assertIn(['postgres_default', 'postgres'], conns)

    # Attempt to list connections with invalid cli args
    with mock.patch('sys.stdout',
                    new_callable=six.StringIO) as mock_stdout:
        cli.connections(self.parser.parse_args(
            ['connections', '--list', '--conn_id=fake',
             '--conn_uri=fake-uri']))
        stdout = mock_stdout.getvalue()

    # Check list attempt stdout
    lines = [line for line in stdout.split('\n') if len(line) > 0]
    self.assertListEqual(lines, [
        ("\tThe following args are not compatible with the " +
         "--list flag: ['conn_id', 'conn_uri']"),
    ])
+
def test_cli_connections_add_delete(self):
    """Exercise `connections --add` and `--delete`, including error paths.

    Adds four connections, checks the duplicate and missing-argument
    messages, verifies the stored rows, deletes them again and finally
    checks the messages for deleting an unknown id and for passing an
    incompatible flag to --delete.
    """
    uri = 'postgresql://airflow:airflow@host:5432/airflow'

    def run_cli(*commands):
        """Run each argv through cli.connections with stdout captured.

        Returns the non-empty lines that were printed, in order.
        """
        with mock.patch('sys.stdout',
                        new_callable=six.StringIO) as mock_stdout:
            for argv in commands:
                cli.connections(self.parser.parse_args(argv))
            stdout = mock_stdout.getvalue()
        return [line for line in stdout.split('\n') if len(line) > 0]

    # Add connections
    lines = run_cli(
        ['connections', '--add', '--conn_id=new1',
         '--conn_uri=%s' % uri],
        ['connections', '-a', '--conn_id=new2',
         '--conn_uri=%s' % uri],
        ['connections', '--add', '--conn_id=new3',
         '--conn_uri=%s' % uri, '--conn_extra', "{'extra': 'yes'}"],
        ['connections', '-a', '--conn_id=new4',
         '--conn_uri=%s' % uri, '--conn_extra', "{'extra': 'yes'}"],
    )

    # Check addition stdout
    self.assertListEqual(lines, [
        ("\tSuccessfully added `conn_id`=new1 : " +
         "postgresql://airflow:airflow@host:5432/airflow"),
        ("\tSuccessfully added `conn_id`=new2 : " +
         "postgresql://airflow:airflow@host:5432/airflow"),
        ("\tSuccessfully added `conn_id`=new3 : " +
         "postgresql://airflow:airflow@host:5432/airflow"),
        ("\tSuccessfully added `conn_id`=new4 : " +
         "postgresql://airflow:airflow@host:5432/airflow"),
    ])

    # Attempt to add duplicate
    lines = run_cli(
        ['connections', '--add', '--conn_id=new1',
         '--conn_uri=%s' % uri])
    self.assertListEqual(lines, [
        "\tA connection with `conn_id`=new1 already exists",
    ])

    # Attempt to add without providing conn_id
    lines = run_cli(
        ['connections', '--add', '--conn_uri=%s' % uri])
    self.assertListEqual(lines, [
        ("\tThe following args are required to add a connection:" +
         " ['conn_id']"),
    ])

    # Attempt to add without providing conn_uri
    lines = run_cli(
        ['connections', '--add', '--conn_id=new'])
    self.assertListEqual(lines, [
        ("\tThe following args are required to add a connection:" +
         " ['conn_uri']"),
    ])

    # Expected `extra` value for each connection added above
    extra = {'new1': None,
             'new2': None,
             'new3': "{'extra': 'yes'}",
             'new4': "{'extra': 'yes'}"}

    session = settings.Session()
    # try/finally so the session is released even when an assertion fails.
    try:
        # Verify the stored rows for the added connections
        for conn_id in ['new1', 'new2', 'new3', 'new4']:
            result = (session
                      .query(models.Connection)
                      .filter(models.Connection.conn_id == conn_id)
                      .first())
            result = (result.conn_id, result.conn_type, result.host,
                      result.port, result.get_extra())
            self.assertEqual(result, (conn_id, 'postgres', 'host', 5432,
                                      extra[conn_id]))

        # Delete connections
        lines = run_cli(
            ['connections', '--delete', '--conn_id=new1'],
            ['connections', '--delete', '--conn_id=new2'],
            ['connections', '--delete', '--conn_id=new3'],
            ['connections', '--delete', '--conn_id=new4'],
        )

        # Check deletion stdout
        self.assertListEqual(lines, [
            "\tSuccessfully deleted `conn_id`=new1",
            "\tSuccessfully deleted `conn_id`=new2",
            "\tSuccessfully deleted `conn_id`=new3",
            "\tSuccessfully deleted `conn_id`=new4"
        ])

        # Check deletions
        for conn_id in ['new1', 'new2', 'new3', 'new4']:
            result = (session
                      .query(models.Connection)
                      .filter(models.Connection.conn_id == conn_id)
                      .first())

            self.assertIsNone(result)

        # Attempt to delete a non-existing connection
        lines = run_cli(
            ['connections', '--delete', '--conn_id=fake'])
        self.assertListEqual(lines, [
            "\tDid not find a connection with `conn_id`=fake",
        ])

        # Attempt to delete with invalid cli args
        lines = run_cli(
            ['connections', '--delete', '--conn_id=fake',
             '--conn_uri=%s' % uri])
        self.assertListEqual(lines, [
            ("\tThe following args are not compatible with the " +
             "--delete flag: ['conn_uri']"),
        ])
    finally:
        session.close()
+
def test_cli_test(self):
cli.test(self.parser.parse_args([
'test', 'example_bash_operator', 'runme_0',