You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original page) is not preserved in this copy.
Posted to dev@airflow.apache.org by GitBox <gi...@apache.org> on 2018/08/02 14:27:21 UTC

[GitHub] caddac closed pull request #3681: [AIRFLOW-2840] - add update connections cli option

caddac closed pull request #3681: [AIRFLOW-2840] - add update connections cli option
URL: https://github.com/apache/incubator-airflow/pull/3681
 
 
   

This pull request comes from a forked repository.
Because GitHub hides the original diff once such a pull request is merged,
it is reproduced below for the sake of provenance:

Because this is a pull request from a fork, the diff is supplied
below; GitHub would not otherwise display it here:

diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py
index e2001789d9..6b7c68efaf 100644
--- a/airflow/bin/cli.py
+++ b/airflow/bin/cli.py
@@ -1029,7 +1029,7 @@ def upgradedb(args):  # noqa
     if not ds_rows:
         qry = (
             session.query(DagRun.dag_id, DagRun.state, func.count('*'))
-                   .group_by(DagRun.dag_id, DagRun.state)
+            .group_by(DagRun.dag_id, DagRun.state)
         )
         for dag_id, state, count in qry:
             session.add(DagStat(dag_id=dag_id, state=state, count=count))
@@ -1095,7 +1095,7 @@ def connections(args):
 
         session = settings.Session()
         try:
-            to_delete = (session
+            to_update = (session
                          .query(Connection)
                          .filter(Connection.conn_id == args.conn_id)
                          .one())
@@ -1111,8 +1111,8 @@ def connections(args):
             print(msg)
             return
         else:
-            deleted_conn_id = to_delete.conn_id
-            session.delete(to_delete)
+            deleted_conn_id = to_update.conn_id
+            session.delete(to_update)
             session.commit()
             msg = '\n\tSuccessfully deleted `conn_id`={conn_id}\n'
             msg = msg.format(conn_id=deleted_conn_id)
@@ -1158,18 +1158,18 @@ def connections(args):
 
         session = settings.Session()
         if not (session.query(Connection)
-                       .filter(Connection.conn_id == new_conn.conn_id).first()):
+                .filter(Connection.conn_id == new_conn.conn_id).first()):
             session.add(new_conn)
             session.commit()
             msg = '\n\tSuccessfully added `conn_id`={conn_id} : {uri}\n'
             msg = msg.format(conn_id=new_conn.conn_id,
                              uri=args.conn_uri or
                              urlunparse((args.conn_type,
-                                        '{login}:{password}@{host}:{port}'
-                                         .format(login=args.conn_login or '',
-                                                 password=args.conn_password or '',
-                                                 host=args.conn_host or '',
-                                                 port=args.conn_port or ''),
+                                         '{login}:{password}@{host}:{port}'
+                                             .format(login=args.conn_login or '',
+                                                     password=args.conn_password or '',
+                                                     host=args.conn_host or '',
+                                                     port=args.conn_port or ''),
                                          args.conn_schema or '', '', '', '')))
             print(msg)
         else:
@@ -1179,6 +1179,111 @@ def connections(args):
 
         return
 
+    if args.update:
+        # Check that the conn_id and conn_uri args were passed to the command:
+        missing_args = list()
+        invalid_args = list()
+        if not args.conn_id:
+            missing_args.append('conn_id')
+        if args.conn_uri:
+            for arg in alternative_conn_specs:
+                if getattr(args, arg) is not None:
+                    invalid_args.append(arg)
+        elif not args.conn_type:
+            missing_args.append('conn_uri or conn_type')
+        if missing_args:
+            msg = ('\n\tThe following args are required to add a connection:' +
+                   ' {missing!r}\n'.format(missing=missing_args))
+            print(msg)
+        if invalid_args:
+            msg = ('\n\tThe following args are not compatible with the ' +
+                   '--add flag and --conn_uri flag: {invalid!r}\n')
+            msg = msg.format(invalid=invalid_args)
+            print(msg)
+        if missing_args or invalid_args:
+            return
+
+        # Delete....
+        session = settings.Session()
+        try:
+            to_update = (session
+                         .query(Connection)
+                         .filter(Connection.conn_id == args.conn_id)
+                         .one())
+
+            to_update.conn_type = args.conn_type or to_update.conn_type
+            to_update.host = args.conn_host or to_update.host
+            to_update.login = args.conn_login or to_update.login
+            to_update.password = args.conn_password or to_update.password
+            to_update.schema = args.conn_schema or to_update.schema
+            to_update.port = args.conn_port or to_update.port
+
+            if args.conn_extra is not None:
+                to_update.set_extra(args.conn_extra)
+
+        except exc.NoResultFound:
+            msg = '\n\tDid not find a connection with `conn_id`={conn_id}\n'
+            msg = msg.format(conn_id=args.conn_id)
+            print(msg)
+            return
+        except exc.MultipleResultsFound:
+            msg = ('\n\tFound more than one connection with ' +
+                   '`conn_id`={conn_id}\n')
+            msg = msg.format(conn_id=args.conn_id)
+            print(msg)
+            return
+        else:
+            # session.save(to_update)
+            session.commit()
+            msg = '\n\tSuccessfully updated `conn_id`={conn_id} : {uri}\n'
+            msg = msg.format(conn_id=to_update.conn_id,
+                             uri=args.conn_uri or
+                             urlunparse((args.conn_type,
+                                         '{login}:{password}@{host}:{port}'
+                                         .format(login=args.conn_login or '',
+                                                 password=args.conn_password or '',
+                                                 host=args.conn_host or '',
+                                                 port=args.conn_port or ''),
+                                         args.conn_schema or '', '', '', '')))
+            print(msg)
+        return
+        #
+        # if args.conn_uri:
+        #     new_conn = Connection(conn_id=args.conn_id, uri=args.conn_uri)
+        # else:
+        #     new_conn = Connection(conn_id=args.conn_id,
+        #                           conn_type=args.conn_type,
+        #                           host=args.conn_host,
+        #                           login=args.conn_login,
+        #                           password=args.conn_password,
+        #                           schema=args.conn_schema,
+        #                           port=args.conn_port)
+        # if args.conn_extra is not None:
+        #     new_conn.set_extra(args.conn_extra)
+        #
+        # session = settings.Session()
+        # if not (session.query(Connection)
+        #     .filter(Connection.conn_id == new_conn.conn_id).first()):
+        #     session.add(new_conn)
+        #     session.commit()
+        #     msg = '\n\tSuccessfully added `conn_id`={conn_id} : {uri}\n'
+        #     msg = msg.format(conn_id=new_conn.conn_id,
+        #                      uri=args.conn_uri or
+        #                          urlunparse((args.conn_type,
+        #                                      '{login}:{password}@{host}:{port}'
+        #                                      .format(login=args.conn_login or '',
+        #                                              password=args.conn_password or '',
+        #                                              host=args.conn_host or '',
+        #                                              port=args.conn_port or ''),
+        #                                      args.conn_schema or '', '', '', '')))
+        #     print(msg)
+        # else:
+        #     msg = '\n\tA connection with `conn_id`={conn_id} already exists\n'
+        #     msg = msg.format(conn_id=new_conn.conn_id)
+        #     print(msg)
+        #
+        # return
+
 
 @cli_utils.action_logging
 def flower(args):
@@ -1379,7 +1484,7 @@ def list_dag_runs(args, dag=None):
 
 
 @cli_utils.action_logging
-def sync_perm(args): # noqa
+def sync_perm(args):  # noqa
     if settings.RBAC:
         appbuilder = cached_appbuilder()
         print('Update permission, view-menu for all existing roles')
@@ -1745,6 +1850,10 @@ class CLIFactory(object):
             ('-d', '--delete'),
             help='Delete a connection',
             action='store_true'),
+        'update_connection': Arg(
+            ('-u', '--update'),
+            help='Update a connection',
+            action='store_true'),
         'conn_id': Arg(
             ('--conn_id',),
             help='Connection id, required to add/delete a connection',
@@ -1961,9 +2070,10 @@ class CLIFactory(object):
             'args': tuple(),
         }, {
             'func': connections,
-            'help': "List/Add/Delete connections",
-            'args': ('list_connections', 'add_connection', 'delete_connection',
-                     'conn_id', 'conn_uri', 'conn_extra') + tuple(alternative_conn_specs),
+            'help': "List/Add/Update/Delete connections",
+            'args': ('list_connections', 'add_connection', 'update_connection',
+                     'delete_connection', 'conn_id', 'conn_uri',
+                     'conn_extra') + tuple(alternative_conn_specs),
         }, {
             'func': create_user,
             'help': "Create an account for the Web UI",
diff --git a/tests/core.py b/tests/core.py
index 6cfd10b02a..a09cbbe8db 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -147,7 +147,7 @@ def test_schedule_dag_no_previous_runs(self):
             datetime(2015, 1, 2, 0, 0),
             dag_run.execution_date,
             msg='dag_run.execution_date did not match expectation: {0}'
-            .format(dag_run.execution_date)
+                .format(dag_run.execution_date)
         )
         self.assertEqual(State.RUNNING, dag_run.state)
         self.assertFalse(dag_run.external_trigger)
@@ -1242,6 +1242,70 @@ def test_cli_connections_add_delete(self):
                 self.assertEqual(result, (conn_id, 'google_cloud_platform',
                                           None, None, "{'extra': 'yes'}"))
 
+        new_uri = 'postgresql://airflow:different_password@host:5432/airflow'
+
+        # Update Connections
+        with mock.patch('sys.stdout',
+                        new_callable=six.StringIO) as mock_stdout:
+            cli.connections(self.parser.parse_args(
+                ['connections', '--update', '--conn_id=new1',
+                 '--conn_uri=%s' % new_uri]))
+
+            cli.connections(self.parser.parse_args(
+                ['connections', '-u', '--conn_id=new2',
+                 '--conn_uri=%s' % new_uri]))
+
+            cli.connections(self.parser.parse_args(
+                ['connections', '--update', '--conn_id=new3',
+                 '--conn_uri=%s' % new_uri, '--conn_extra', "{'extra': 'yes'}"]))
+
+            cli.connections(self.parser.parse_args(
+                ['connections', '-u', '--conn_id=new4',
+                 '--conn_uri=%s' % new_uri, '--conn_extra', "{'extra': 'yes'}"]))
+
+            cli.connections(self.parser.parse_args(
+                ['connections', '--update', '--conn_id=new5',
+                 '--conn_type=hive_metastore', '--conn_login=airflow',
+                 '--conn_password=different_password', '--conn_host=host',
+                 '--conn_port=9083', '--conn_schema=airflow']))
+
+            cli.connections(self.parser.parse_args(
+                ['connections', '-u', '--conn_id=new6',
+                 '--conn_uri', "", '--conn_type=google_cloud_platform',
+                 '--conn_extra', "{'extra': 'yes'}"]))
+            stdout = mock_stdout.getvalue()
+
+        # Check addition stdout
+        lines = [l for l in stdout.split('\n') if len(l) > 0]
+        self.assertListEqual(lines, [
+            ("\tSuccessfully updated `conn_id`=new1 : " +
+             "postgresql://airflow:different_password@host:5432/airflow"),
+            ("\tSuccessfully updated `conn_id`=new2 : " +
+             "postgresql://airflow:different_password@host:5432/airflow"),
+            ("\tSuccessfully updated `conn_id`=new3 : " +
+             "postgresql://airflow:different_password@host:5432/airflow"),
+            ("\tSuccessfully updated `conn_id`=new4 : " +
+             "postgresql://airflow:different_password@host:5432/airflow"),
+            ("\tSuccessfully updated `conn_id`=new5 : " +
+             "hive_metastore://airflow:different_password@host:9083/airflow"),
+            ("\tSuccessfully updated `conn_id`=new6 : " +
+             "google_cloud_platform://:@:")
+        ])
+
+        # Attempt to udpate without providing conn_uri
+        with mock.patch('sys.stdout',
+                        new_callable=six.StringIO) as mock_stdout:
+            cli.connections(self.parser.parse_args(
+                ['connections', '--update', '--conn_id=new']))
+            stdout = mock_stdout.getvalue()
+
+        # Check stdout for addition attempt
+        lines = [l for l in stdout.split('\n') if len(l) > 0]
+        self.assertListEqual(lines, [
+            ("\tThe following args are required to add a connection:" +
+             " ['conn_uri or conn_type']"),
+        ])
+
         # Delete connections
         with mock.patch('sys.stdout',
                         new_callable=six.StringIO) as mock_stdout:
@@ -2446,6 +2510,7 @@ def test_get_ha_client(self, mock_get_connections):
         client = HDFSHook().get_conn()
         self.assertIsInstance(client, snakebite.client.HAClient)
 
+
 send_email_test = mock.Mock()
 
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services