You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by di...@apache.org on 2021/04/05 17:55:26 UTC

[airflow] branch master updated: Import connections from a file (#15177)

This is an automated email from the ASF dual-hosted git repository.

dimberman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 7cadb63  Import connections from a file (#15177)
7cadb63 is described below

commit 7cadb63d38900f581b5d81011a1de534fe713c3a
Author: natanweinberger <na...@gmail.com>
AuthorDate: Mon Apr 5 13:55:14 2021 -0400

    Import connections from a file (#15177)
    
    * Add connections import CLI command
    
    * Add tests for CLI connections import
    
    * Add connections import overwrite test
    
    When a connections file contains entries that collide with existing
    connections, skip those entries and print a message to stdout for each
    one indicating that the connection was not imported.
    
    * Resolve lint errors
---
 airflow/cli/cli_parser.py                     |  11 ++
 airflow/cli/commands/connection_command.py    |  41 ++++++
 tests/cli/commands/test_connection_command.py | 173 ++++++++++++++++++++++++++
 3 files changed, 225 insertions(+)

diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py
index b5c709f..0bf92cc 100644
--- a/airflow/cli/cli_parser.py
+++ b/airflow/cli/cli_parser.py
@@ -602,6 +602,7 @@ ARG_CONN_EXPORT = Arg(
 ARG_CONN_EXPORT_FORMAT = Arg(
     ('--format',), help='Format of the connections data in file', type=str, choices=['json', 'yaml', 'env']
 )
+ARG_CONN_IMPORT = Arg(("file",), help="Import connections from a file")
 
 # providers
 ARG_PROVIDER_NAME = Arg(
@@ -1200,6 +1201,16 @@ CONNECTIONS_COMMANDS = (
             ARG_CONN_EXPORT_FORMAT,
         ),
     ),
+    ActionCommand(
+        name='import',
+        help='Import connections from a file',
+        description=(
+            "Connections can be imported from the output of the export command.\n"
+            "The filetype must be json, yaml or env and will be automatically inferred."
+        ),
+        func=lazy_load_command('airflow.cli.commands.connection_command.connections_import'),
+        args=(ARG_CONN_IMPORT,),
+    ),
 )
 PROVIDERS_COMMANDS = (
     ActionCommand(
diff --git a/airflow/cli/commands/connection_command.py b/airflow/cli/commands/connection_command.py
index f35ed36..6e45e2a 100644
--- a/airflow/cli/commands/connection_command.py
+++ b/airflow/cli/commands/connection_command.py
@@ -29,6 +29,7 @@ from airflow.cli.simple_table import AirflowConsole
 from airflow.exceptions import AirflowNotFoundException
 from airflow.hooks.base import BaseHook
 from airflow.models import Connection
+from airflow.secrets.local_filesystem import _create_connection, load_connections_dict
 from airflow.utils import cli as cli_utils
 from airflow.utils.cli import suppress_logs_and_warning
 from airflow.utils.session import create_session
@@ -234,3 +235,43 @@ def connections_delete(args):
         else:
             session.delete(to_delete)
             print(f"Successfully deleted connection with `conn_id`={to_delete.conn_id}")
+
+
+@cli_utils.action_logging
+def connections_import(args):
+    """Imports connections from a given file"""
+    if os.path.exists(args.file):
+        _import_helper(args.file)
+    else:
+        raise SystemExit("Missing connections file.")
+
+
+def _import_helper(file_path):
+    """Helps import connections from a file"""
+    connections_dict = load_connections_dict(file_path)
+    with create_session() as session:
+        for conn_id, conn_values in connections_dict.items():
+            if session.query(Connection).filter(Connection.conn_id == conn_id).first():
+                print(f'Could not import connection {conn_id}: connection already exists.')
+                continue
+
+            allowed_fields = [
+                'extra',
+                'description',
+                'conn_id',
+                'login',
+                'conn_type',
+                'host',
+                'password',
+                'schema',
+                'port',
+                'uri',
+                'extra_dejson',
+            ]
+            filtered_connection_values = {
+                key: value for key, value in conn_values.items() if key in allowed_fields
+            }
+            connection = _create_connection(conn_id, filtered_connection_values)
+            session.add(connection)
+            session.commit()
+            print(f'Imported connection {conn_id}')
diff --git a/tests/cli/commands/test_connection_command.py b/tests/cli/commands/test_connection_command.py
index c81ff81..cf27941 100644
--- a/tests/cli/commands/test_connection_command.py
+++ b/tests/cli/commands/test_connection_command.py
@@ -27,6 +27,7 @@ from parameterized import parameterized
 
 from airflow.cli import cli_parser
 from airflow.cli.commands import connection_command
+from airflow.exceptions import AirflowException
 from airflow.models import Connection
 from airflow.utils.db import merge_conn
 from airflow.utils.session import create_session, provide_session
@@ -716,3 +717,175 @@ class TestCliDeleteConnections(unittest.TestCase):
         # Attempt to delete a non-existing connection
         with pytest.raises(SystemExit, match=r"Did not find a connection with `conn_id`=fake"):
             connection_command.connections_delete(self.parser.parse_args(["connections", "delete", "fake"]))
+
+
+class TestCliImportConnections(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.parser = cli_parser.get_parser()
+        clear_db_connections(add_default_connections_back=False)
+
+    @classmethod
+    def tearDownClass(cls):
+        clear_db_connections()
+
+    @mock.patch('os.path.exists')
+    def test_cli_connections_import_should_return_error_if_file_does_not_exist(self, mock_exists):
+        mock_exists.return_value = False
+        filepath = '/does/not/exist.json'
+        with pytest.raises(SystemExit, match=r"Missing connections file."):
+            connection_command.connections_import(self.parser.parse_args(["connections", "import", filepath]))
+
+    @parameterized.expand(
+        [
+            ("sample.jso",),
+            ("sample.yml",),
+            ("sample.environ",),
+        ]
+    )
+    @mock.patch('os.path.exists')
+    def test_cli_connections_import_should_return_error_if_file_format_is_invalid(
+        self, filepath, mock_exists
+    ):
+        mock_exists.return_value = True
+        with pytest.raises(
+            AirflowException,
+            match=r"Unsupported file format. The file must have the extension .env or .json or .yaml",
+        ):
+            connection_command.connections_import(self.parser.parse_args(["connections", "import", filepath]))
+
+    @mock.patch('airflow.cli.commands.connection_command.load_connections_dict')
+    @mock.patch('os.path.exists')
+    def test_cli_connections_import_should_load_connections(self, mock_exists, mock_load_connections_dict):
+        mock_exists.return_value = True
+
+        # Sample connections to import
+        expected_connections = {
+            "new0": {
+                "conn_type": "postgres",
+                "description": "new0 description",
+                "host": "host",
+                "is_encrypted": False,
+                "is_extra_encrypted": False,
+                "login": "airflow",
+                "port": 5432,
+                "schema": "airflow",
+            },
+            "new1": {
+                "conn_type": "mysql",
+                "description": "new1 description",
+                "host": "host",
+                "is_encrypted": False,
+                "is_extra_encrypted": False,
+                "login": "airflow",
+                "port": 3306,
+                "schema": "airflow",
+            },
+        }
+
+        # We're not testing the behavior of load_connections_dict, assume successfully reads JSON, YAML or env
+        mock_load_connections_dict.return_value = expected_connections
+
+        connection_command.connections_import(
+            self.parser.parse_args(["connections", "import", 'sample.json'])
+        )
+
+        # Verify that the imported connections match the expected, sample connections
+        with create_session() as session:
+            current_conns = session.query(Connection).all()
+
+            comparable_attrs = [
+                "conn_type",
+                "description",
+                "host",
+                "is_encrypted",
+                "is_extra_encrypted",
+                "login",
+                "port",
+                "schema",
+            ]
+
+            current_conns_as_dicts = {
+                current_conn.conn_id: {attr: getattr(current_conn, attr) for attr in comparable_attrs}
+                for current_conn in current_conns
+            }
+            assert expected_connections == current_conns_as_dicts
+
+    @provide_session
+    @mock.patch('airflow.cli.commands.connection_command.load_connections_dict')
+    @mock.patch('os.path.exists')
+    def test_cli_connections_import_should_not_overwrite_existing_connections(
+        self, mock_exists, mock_load_connections_dict, session=None
+    ):
+        mock_exists.return_value = True
+
+        # Add a pre-existing connection "new1"
+        merge_conn(
+            Connection(
+                conn_id="new1",
+                conn_type="mysql",
+                description="mysql description",
+                host="mysql",
+                login="root",
+                password="",
+                schema="airflow",
+            ),
+            session=session,
+        )
+
+        # Sample connections to import, including a collision with "new1"
+        expected_connections = {
+            "new0": {
+                "conn_type": "postgres",
+                "description": "new0 description",
+                "host": "host",
+                "is_encrypted": False,
+                "is_extra_encrypted": False,
+                "login": "airflow",
+                "port": 5432,
+                "schema": "airflow",
+            },
+            "new1": {
+                "conn_type": "mysql",
+                "description": "new1 description",
+                "host": "host",
+                "is_encrypted": False,
+                "is_extra_encrypted": False,
+                "login": "airflow",
+                "port": 3306,
+                "schema": "airflow",
+            },
+        }
+
+        # We're not testing the behavior of load_connections_dict, assume successfully reads JSON, YAML or env
+        mock_load_connections_dict.return_value = expected_connections
+
+        with redirect_stdout(io.StringIO()) as stdout:
+            connection_command.connections_import(
+                self.parser.parse_args(["connections", "import", 'sample.json'])
+            )
+
+            assert 'Could not import connection new1: connection already exists.' in stdout.getvalue()
+
+        # Verify that the imported connections match the expected, sample connections
+        current_conns = session.query(Connection).all()
+
+        comparable_attrs = [
+            "conn_type",
+            "description",
+            "host",
+            "is_encrypted",
+            "is_extra_encrypted",
+            "login",
+            "port",
+            "schema",
+        ]
+
+        current_conns_as_dicts = {
+            current_conn.conn_id: {attr: getattr(current_conn, attr) for attr in comparable_attrs}
+            for current_conn in current_conns
+        }
+        assert current_conns_as_dicts['new0'] == expected_connections['new0']
+
+        # The existing connection's description should not have changed
+        assert current_conns_as_dicts['new1']['description'] == 'new1 description'