You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by di...@apache.org on 2021/04/05 17:55:26 UTC
[airflow] branch master updated: Import connections from a file
(#15177)
This is an automated email from the ASF dual-hosted git repository.
dimberman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 7cadb63 Import connections from a file (#15177)
7cadb63 is described below
commit 7cadb63d38900f581b5d81011a1de534fe713c3a
Author: natanweinberger <na...@gmail.com>
AuthorDate: Mon Apr 5 13:55:14 2021 -0400
Import connections from a file (#15177)
* Add connections import CLI command
* Add tests for CLI connections import
* Add connections import overwrite test
When a connections file contains collisions with existing connections,
skip them and print a message to stdout indicating that each colliding
connection was not imported.
* Resolve lint errors
---
airflow/cli/cli_parser.py | 11 ++
airflow/cli/commands/connection_command.py | 41 ++++++
tests/cli/commands/test_connection_command.py | 173 ++++++++++++++++++++++++++
3 files changed, 225 insertions(+)
diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py
index b5c709f..0bf92cc 100644
--- a/airflow/cli/cli_parser.py
+++ b/airflow/cli/cli_parser.py
@@ -602,6 +602,7 @@ ARG_CONN_EXPORT = Arg(
ARG_CONN_EXPORT_FORMAT = Arg(
('--format',), help='Format of the connections data in file', type=str, choices=['json', 'yaml', 'env']
)
+# Positional argument of `airflow connections import`: the file the connections are read from.
+ARG_CONN_IMPORT = Arg(("file",), help="Path to the connections file to import (.json, .yaml or .env)")
# providers
ARG_PROVIDER_NAME = Arg(
@@ -1200,6 +1201,16 @@ CONNECTIONS_COMMANDS = (
ARG_CONN_EXPORT_FORMAT,
),
),
+    # `airflow connections import FILE`: adds the connections in FILE; conn_ids that already exist are skipped.
+    ActionCommand(
+        name='import',
+        help='Import connections from a file',
+        description=(
+            "Connections can be imported from the output of the export command.\n"
+            "The filetype must be json, yaml or env and will be automatically inferred.\n"
+            "Connections whose conn_id already exists are skipped, not overwritten."
+        ),
+        func=lazy_load_command('airflow.cli.commands.connection_command.connections_import'),
+        args=(ARG_CONN_IMPORT,),
+    ),
)
PROVIDERS_COMMANDS = (
ActionCommand(
diff --git a/airflow/cli/commands/connection_command.py b/airflow/cli/commands/connection_command.py
index f35ed36..6e45e2a 100644
--- a/airflow/cli/commands/connection_command.py
+++ b/airflow/cli/commands/connection_command.py
@@ -29,6 +29,7 @@ from airflow.cli.simple_table import AirflowConsole
from airflow.exceptions import AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.models import Connection
+from airflow.secrets.local_filesystem import _create_connection, load_connections_dict
from airflow.utils import cli as cli_utils
from airflow.utils.cli import suppress_logs_and_warning
from airflow.utils.session import create_session
@@ -234,3 +235,43 @@ def connections_delete(args):
else:
session.delete(to_delete)
print(f"Successfully deleted connection with `conn_id`={to_delete.conn_id}")
+
+
+@cli_utils.action_logging
+def connections_import(args):
+    """
+    Import connections from the file given on the command line into the metadata database.
+
+    :param args: parsed CLI arguments; only ``args.file`` is read.
+    :raises SystemExit: if ``args.file`` does not exist.
+    """
+    file_path = args.file
+    # Fail fast with a CLI-friendly exit instead of letting the loader raise on a missing path.
+    if not os.path.exists(file_path):
+        raise SystemExit("Missing connections file.")
+    _import_helper(file_path)
+
+
+# Keys accepted from a dict-shaped connection entry. Anything else (for example the
+# ``is_encrypted`` / ``is_extra_encrypted`` flags) is dropped before the Connection is built.
+_IMPORTABLE_CONNECTION_FIELDS = frozenset(
+    (
+        'extra',
+        'description',
+        'conn_id',
+        'login',
+        'conn_type',
+        'host',
+        'password',
+        'schema',
+        'port',
+        'uri',
+        'extra_dejson',
+    )
+)
+
+
+def _import_helper(file_path):
+    """
+    Load connections from ``file_path`` and add them to the metadata database.
+
+    A connection whose ``conn_id`` already exists in the database is skipped (and a message is
+    printed) rather than overwritten. Each new connection is committed as soon as it is added.
+
+    :param file_path: path of a connections file readable by ``load_connections_dict``.
+    """
+    connections_dict = load_connections_dict(file_path)
+    with create_session() as session:
+        for conn_id, conn_values in connections_dict.items():
+            if session.query(Connection).filter(Connection.conn_id == conn_id).first():
+                print(f'Could not import connection {conn_id}: connection already exists.')
+                continue
+
+            if isinstance(conn_values, Connection):
+                # load_connections_dict returns ready-built Connection objects (Dict[str, Connection]);
+                # calling ``.items()`` on one would raise AttributeError.
+                connection = conn_values
+            else:
+                # Plain mapping of connection attributes: keep only the keys a Connection can be built from.
+                filtered_connection_values = {
+                    key: value
+                    for key, value in conn_values.items()
+                    if key in _IMPORTABLE_CONNECTION_FIELDS
+                }
+                connection = _create_connection(conn_id, filtered_connection_values)
+            session.add(connection)
+            session.commit()
+            print(f'Imported connection {conn_id}')
diff --git a/tests/cli/commands/test_connection_command.py b/tests/cli/commands/test_connection_command.py
index c81ff81..cf27941 100644
--- a/tests/cli/commands/test_connection_command.py
+++ b/tests/cli/commands/test_connection_command.py
@@ -27,6 +27,7 @@ from parameterized import parameterized
from airflow.cli import cli_parser
from airflow.cli.commands import connection_command
+from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.utils.db import merge_conn
from airflow.utils.session import create_session, provide_session
@@ -716,3 +717,175 @@ class TestCliDeleteConnections(unittest.TestCase):
# Attempt to delete a non-existing connection
with pytest.raises(SystemExit, match=r"Did not find a connection with `conn_id`=fake"):
connection_command.connections_delete(self.parser.parse_args(["connections", "delete", "fake"]))
+
+
+class TestCliImportConnections(unittest.TestCase):
+    """
+    Tests for ``airflow connections import``.
+
+    File parsing is not under test here: the import tests mock ``load_connections_dict``.
+    """
+
+    # Connections handed to the import command by the mocked ``load_connections_dict``.
+    SAMPLE_CONNECTIONS = {
+        "new0": {
+            "conn_type": "postgres",
+            "description": "new0 description",
+            "host": "host",
+            "is_encrypted": False,
+            "is_extra_encrypted": False,
+            "login": "airflow",
+            "port": 5432,
+            "schema": "airflow",
+        },
+        "new1": {
+            "conn_type": "mysql",
+            "description": "new1 description",
+            "host": "host",
+            "is_encrypted": False,
+            "is_extra_encrypted": False,
+            "login": "airflow",
+            "port": 3306,
+            "schema": "airflow",
+        },
+    }
+
+    # Connection attributes compared between the sample data and the rows stored in the database.
+    COMPARABLE_ATTRS = (
+        "conn_type",
+        "description",
+        "host",
+        "is_encrypted",
+        "is_extra_encrypted",
+        "login",
+        "port",
+        "schema",
+    )
+
+    @classmethod
+    def setUpClass(cls):
+        cls.parser = cli_parser.get_parser()
+
+    def setUp(self):
+        # Start every test from an empty connection table. Clearing only once per class lets connections
+        # imported by one test leak into the next, making the collision test depend on test order.
+        clear_db_connections(add_default_connections_back=False)
+
+    @classmethod
+    def tearDownClass(cls):
+        clear_db_connections()
+
+    @classmethod
+    def _connections_as_dicts(cls, connections):
+        """Map each connection's conn_id to a dict of its ``COMPARABLE_ATTRS`` values."""
+        return {
+            conn.conn_id: {attr: getattr(conn, attr) for attr in cls.COMPARABLE_ATTRS}
+            for conn in connections
+        }
+
+    @mock.patch('os.path.exists')
+    def test_cli_connections_import_should_return_error_if_file_does_not_exist(self, mock_exists):
+        mock_exists.return_value = False
+        filepath = '/does/not/exist.json'
+        with pytest.raises(SystemExit, match=r"Missing connections file."):
+            connection_command.connections_import(self.parser.parse_args(["connections", "import", filepath]))
+
+    @parameterized.expand(
+        [
+            ("sample.jso",),
+            ("sample.yml",),
+            ("sample.environ",),
+        ]
+    )
+    @mock.patch('os.path.exists')
+    def test_cli_connections_import_should_return_error_if_file_format_is_invalid(
+        self, filepath, mock_exists
+    ):
+        mock_exists.return_value = True
+        with pytest.raises(
+            AirflowException,
+            match=r"Unsupported file format. The file must have the extension .env or .json or .yaml",
+        ):
+            connection_command.connections_import(self.parser.parse_args(["connections", "import", filepath]))
+
+    @mock.patch('airflow.cli.commands.connection_command.load_connections_dict')
+    @mock.patch('os.path.exists')
+    def test_cli_connections_import_should_load_connections(self, mock_exists, mock_load_connections_dict):
+        mock_exists.return_value = True
+        # Assume the JSON, YAML or env file was read successfully.
+        mock_load_connections_dict.return_value = self.SAMPLE_CONNECTIONS
+
+        connection_command.connections_import(
+            self.parser.parse_args(["connections", "import", 'sample.json'])
+        )
+
+        # Verify that the imported connections match the sample connections
+        with create_session() as session:
+            current_conns_as_dicts = self._connections_as_dicts(session.query(Connection).all())
+
+        assert self.SAMPLE_CONNECTIONS == current_conns_as_dicts
+
+    @provide_session
+    @mock.patch('airflow.cli.commands.connection_command.load_connections_dict')
+    @mock.patch('os.path.exists')
+    def test_cli_connections_import_should_not_overwrite_existing_connections(
+        self, mock_exists, mock_load_connections_dict, session=None
+    ):
+        mock_exists.return_value = True
+
+        # Add a pre-existing connection "new1" that collides with one of the sample connections
+        merge_conn(
+            Connection(
+                conn_id="new1",
+                conn_type="mysql",
+                description="mysql description",
+                host="mysql",
+                login="root",
+                password="",
+                schema="airflow",
+            ),
+            session=session,
+        )
+
+        # Assume the JSON, YAML or env file was read successfully.
+        mock_load_connections_dict.return_value = self.SAMPLE_CONNECTIONS
+
+        with redirect_stdout(io.StringIO()) as stdout:
+            connection_command.connections_import(
+                self.parser.parse_args(["connections", "import", 'sample.json'])
+            )
+
+        output = stdout.getvalue()
+        assert 'Could not import connection new1: connection already exists.' in output
+        assert 'Imported connection new0' in output
+
+        current_conns_as_dicts = self._connections_as_dicts(session.query(Connection).all())
+
+        # The non-colliding connection is imported unchanged
+        assert current_conns_as_dicts['new0'] == self.SAMPLE_CONNECTIONS['new0']
+
+        # The existing connection's description should not have changed
+        assert current_conns_as_dicts['new1']['description'] == 'mysql description'