You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2021/06/22 13:46:26 UTC

[airflow] 36/38: Fix CLI connections import and migrate logic from secrets to Connection model (#15425)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit fc30a4c80ffc19f40ab04cb502581d81b127c55f
Author: natanweinberger <na...@gmail.com>
AuthorDate: Fri Jun 11 07:34:06 2021 -0400

    Fix CLI connections import and migrate logic from secrets to Connection model (#15425)
    
    * Add field 'extra' to Connection init
    
    * Fix connections import CLI
    
    In connections_import, each connection was deserialized and stored into a
    Connection model instance rather than a dictionary, so an erroneous call to the
dictionary method .items() resulted in an AttributeError. With this fix,
    connection information is loaded from dictionaries directly into the
    Connection constructor and committed to the DB.
    
    * Apply suggestions from code review
    
    * Use load_connections_dict in connections import
    
    Co-authored-by: Ash Berlin-Taylor <as...@firemirror.com>
    (cherry picked from commit 002075af91965416e595880040e138b1d6ddec43)
---
 airflow/cli/commands/connection_command.py    | 29 +++---------
 airflow/models/connection.py                  |  6 ++-
 tests/cli/commands/test_connection_command.py | 66 ++++++++++++++-------------
 3 files changed, 44 insertions(+), 57 deletions(-)

diff --git a/airflow/cli/commands/connection_command.py b/airflow/cli/commands/connection_command.py
index 6e45e2a..8f60022 100644
--- a/airflow/cli/commands/connection_command.py
+++ b/airflow/cli/commands/connection_command.py
@@ -29,8 +29,8 @@ from airflow.cli.simple_table import AirflowConsole
 from airflow.exceptions import AirflowNotFoundException
 from airflow.hooks.base import BaseHook
 from airflow.models import Connection
-from airflow.secrets.local_filesystem import _create_connection, load_connections_dict
-from airflow.utils import cli as cli_utils
+from airflow.secrets.local_filesystem import load_connections_dict
+from airflow.utils import cli as cli_utils, yaml
 from airflow.utils.cli import suppress_logs_and_warning
 from airflow.utils.session import create_session
 
@@ -239,7 +239,7 @@ def connections_delete(args):
 
 @cli_utils.action_logging
 def connections_import(args):
-    """Imports connections from a given file"""
+    """Imports connections from a file"""
     if os.path.exists(args.file):
         _import_helper(args.file)
     else:
@@ -247,31 +247,14 @@ def connections_import(args):
 
 
 def _import_helper(file_path):
-    """Helps import connections from a file"""
+    """Load connections from a file and save them to the DB. On collision, skip."""
     connections_dict = load_connections_dict(file_path)
     with create_session() as session:
-        for conn_id, conn_values in connections_dict.items():
+        for conn_id, conn in connections_dict.items():
             if session.query(Connection).filter(Connection.conn_id == conn_id).first():
                 print(f'Could not import connection {conn_id}: connection already exists.')
                 continue
 
-            allowed_fields = [
-                'extra',
-                'description',
-                'conn_id',
-                'login',
-                'conn_type',
-                'host',
-                'password',
-                'schema',
-                'port',
-                'uri',
-                'extra_dejson',
-            ]
-            filtered_connection_values = {
-                key: value for key, value in conn_values.items() if key in allowed_fields
-            }
-            connection = _create_connection(conn_id, filtered_connection_values)
-            session.add(connection)
+            session.add(conn)
             session.commit()
             print(f'Imported connection {conn_id}')
diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index 9021edb..73d0d8d 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -19,7 +19,7 @@
 import json
 import warnings
 from json import JSONDecodeError
-from typing import Dict, Optional
+from typing import Dict, Optional, Union
 from urllib.parse import parse_qsl, quote, unquote, urlencode, urlparse
 
 from sqlalchemy import Boolean, Column, Integer, String, Text
@@ -117,12 +117,14 @@ class Connection(Base, LoggingMixin):  # pylint: disable=too-many-instance-attri
         password: Optional[str] = None,
         schema: Optional[str] = None,
         port: Optional[int] = None,
-        extra: Optional[str] = None,
+        extra: Optional[Union[str, dict]] = None,
         uri: Optional[str] = None,
     ):
         super().__init__()
         self.conn_id = conn_id
         self.description = description
+        if extra and not isinstance(extra, str):
+            extra = json.dumps(extra)
         if uri and (  # pylint: disable=too-many-boolean-expressions
             conn_type or host or login or password or schema or port or extra
         ):
diff --git a/tests/cli/commands/test_connection_command.py b/tests/cli/commands/test_connection_command.py
index 136811d..5339083 100644
--- a/tests/cli/commands/test_connection_command.py
+++ b/tests/cli/commands/test_connection_command.py
@@ -758,9 +758,9 @@ class TestCliImportConnections(unittest.TestCase):
         ):
             connection_command.connections_import(self.parser.parse_args(["connections", "import", filepath]))
 
-    @mock.patch('airflow.cli.commands.connection_command.load_connections_dict')
+    @mock.patch('airflow.secrets.local_filesystem._parse_secret_file')
     @mock.patch('os.path.exists')
-    def test_cli_connections_import_should_load_connections(self, mock_exists, mock_load_connections_dict):
+    def test_cli_connections_import_should_load_connections(self, mock_exists, mock_parse_secret_file):
         mock_exists.return_value = True
 
         # Sample connections to import
@@ -769,26 +769,26 @@ class TestCliImportConnections(unittest.TestCase):
                 "conn_type": "postgres",
                 "description": "new0 description",
                 "host": "host",
-                "is_encrypted": False,
-                "is_extra_encrypted": False,
                 "login": "airflow",
+                "password": "password",
                 "port": 5432,
                 "schema": "airflow",
+                "extra": "test",
             },
             "new1": {
                 "conn_type": "mysql",
                 "description": "new1 description",
                 "host": "host",
-                "is_encrypted": False,
-                "is_extra_encrypted": False,
                 "login": "airflow",
+                "password": "password",
                 "port": 3306,
                 "schema": "airflow",
+                "extra": "test",
             },
         }
 
-        # We're not testing the behavior of load_connections_dict, assume successfully reads JSON, YAML or env
-        mock_load_connections_dict.return_value = expected_connections
+        # We're not testing the behavior of _parse_secret_file, assume it successfully reads JSON, YAML or env
+        mock_parse_secret_file.return_value = expected_connections
 
         connection_command.connections_import(
             self.parser.parse_args(["connections", "import", 'sample.json'])
@@ -799,14 +799,15 @@ class TestCliImportConnections(unittest.TestCase):
             current_conns = session.query(Connection).all()
 
             comparable_attrs = [
+                "conn_id",
                 "conn_type",
                 "description",
                 "host",
-                "is_encrypted",
-                "is_extra_encrypted",
                 "login",
+                "password",
                 "port",
                 "schema",
+                "extra",
             ]
 
             current_conns_as_dicts = {
@@ -816,80 +817,81 @@ class TestCliImportConnections(unittest.TestCase):
             assert expected_connections == current_conns_as_dicts
 
     @provide_session
-    @mock.patch('airflow.cli.commands.connection_command.load_connections_dict')
+    @mock.patch('airflow.secrets.local_filesystem._parse_secret_file')
     @mock.patch('os.path.exists')
     def test_cli_connections_import_should_not_overwrite_existing_connections(
-        self, mock_exists, mock_load_connections_dict, session=None
+        self, mock_exists, mock_parse_secret_file, session=None
     ):
         mock_exists.return_value = True
 
-        # Add a pre-existing connection "new1"
+        # Add a pre-existing connection "new3"
         merge_conn(
             Connection(
-                conn_id="new1",
+                conn_id="new3",
                 conn_type="mysql",
-                description="mysql description",
+                description="original description",
                 host="mysql",
                 login="root",
-                password="",
+                password="password",
                 schema="airflow",
             ),
             session=session,
         )
 
-        # Sample connections to import, including a collision with "new1"
+        # Sample connections to import, including a collision with "new3"
         expected_connections = {
-            "new0": {
+            "new2": {
                 "conn_type": "postgres",
-                "description": "new0 description",
+                "description": "new2 description",
                 "host": "host",
-                "is_encrypted": False,
-                "is_extra_encrypted": False,
                 "login": "airflow",
+                "password": "password",
                 "port": 5432,
                 "schema": "airflow",
+                "extra": "test",
             },
-            "new1": {
+            "new3": {
                 "conn_type": "mysql",
-                "description": "new1 description",
+                "description": "updated description",
                 "host": "host",
-                "is_encrypted": False,
-                "is_extra_encrypted": False,
                 "login": "airflow",
+                "password": "new password",
                 "port": 3306,
                 "schema": "airflow",
+                "extra": "test",
             },
         }
 
-        # We're not testing the behavior of load_connections_dict, assume successfully reads JSON, YAML or env
-        mock_load_connections_dict.return_value = expected_connections
+        # We're not testing the behavior of _parse_secret_file, assume it successfully reads JSON, YAML or env
+        mock_parse_secret_file.return_value = expected_connections
 
         with redirect_stdout(io.StringIO()) as stdout:
             connection_command.connections_import(
                 self.parser.parse_args(["connections", "import", 'sample.json'])
             )
 
-            assert 'Could not import connection new1: connection already exists.' in stdout.getvalue()
+            assert 'Could not import connection new3: connection already exists.' in stdout.getvalue()
 
         # Verify that the imported connections match the expected, sample connections
         current_conns = session.query(Connection).all()
 
         comparable_attrs = [
+            "conn_id",
             "conn_type",
             "description",
             "host",
-            "is_encrypted",
-            "is_extra_encrypted",
             "login",
+            "password",
             "port",
             "schema",
+            "extra",
         ]
 
         current_conns_as_dicts = {
             current_conn.conn_id: {attr: getattr(current_conn, attr) for attr in comparable_attrs}
             for current_conn in current_conns
         }
-        assert current_conns_as_dicts['new0'] == expected_connections['new0']
+        assert current_conns_as_dicts['new2'] == expected_connections['new2']
 
         # The existing connection's description should not have changed
-        assert current_conns_as_dicts['new1']['description'] == 'new1 description'
+        assert current_conns_as_dicts['new3']['description'] == 'original description'