You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/04/13 16:52:46 UTC

[airflow] branch main updated: Support importing connections from files with ".yml" extension (#22872)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3c0ad4af31 Support importing connections from files with ".yml" extension (#22872)
3c0ad4af31 is described below

commit 3c0ad4af310483cd051e94550a7d857653dcee6d
Author: Hank Ehly <he...@gmail.com>
AuthorDate: Thu Apr 14 01:52:39 2022 +0900

    Support importing connections from files with ".yml" extension (#22872)
    
    * Add .yml to the list of importable secret file extensions
---
 airflow/secrets/local_filesystem.py           |  5 +++--
 tests/cli/commands/test_connection_command.py |  7 +++++--
 tests/secrets/test_local_filesystem.py        | 20 ++++++++++++++++++--
 3 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/airflow/secrets/local_filesystem.py b/airflow/secrets/local_filesystem.py
index 2c399b30f5..25c4eed3db 100644
--- a/airflow/secrets/local_filesystem.py
+++ b/airflow/secrets/local_filesystem.py
@@ -109,7 +109,6 @@ def _parse_yaml_file(file_path: str) -> Tuple[Dict[str, List[str]], List[FileSyn
         return {}, [FileSyntaxError(line_no=1, message="The file is empty.")]
     try:
         secrets = yaml.safe_load(content)
-
     except yaml.MarkedYAMLError as e:
         err_line_no = e.problem_mark.line if e.problem_mark else -1
         return {}, [FileSyntaxError(line_no=err_line_no, message=str(e))]
@@ -145,6 +144,7 @@ FILE_PARSERS = {
     "env": _parse_env_file,
     "json": _parse_json_file,
     "yaml": _parse_yaml_file,
+    "yml": _parse_yaml_file,
 }
 
 
@@ -166,7 +166,8 @@ def _parse_secret_file(file_path: str) -> Dict[str, Any]:
 
     if ext not in FILE_PARSERS:
         raise AirflowException(
-            "Unsupported file format. The file must have the extension .env or .json or .yaml"
+            "Unsupported file format. The file must have one of the following extensions: "
+            ".env .json .yaml .yml"
         )
 
     secrets, parse_errors = FILE_PARSERS[ext](file_path)
diff --git a/tests/cli/commands/test_connection_command.py b/tests/cli/commands/test_connection_command.py
index 028731209c..a4e1148e83 100644
--- a/tests/cli/commands/test_connection_command.py
+++ b/tests/cli/commands/test_connection_command.py
@@ -630,7 +630,7 @@ class TestCliImportConnections:
         with pytest.raises(SystemExit, match=r"Missing connections file."):
             connection_command.connections_import(self.parser.parse_args(["connections", "import", filepath]))
 
-    @pytest.mark.parametrize('filepath', ["sample.jso", "sample.yml", "sample.environ"])
+    @pytest.mark.parametrize('filepath', ["sample.jso", "sample.environ"])
     @mock.patch('os.path.exists')
     def test_cli_connections_import_should_return_error_if_file_format_is_invalid(
         self, mock_exists, filepath
@@ -638,7 +638,10 @@ class TestCliImportConnections:
         mock_exists.return_value = True
         with pytest.raises(
             AirflowException,
-            match=r"Unsupported file format. The file must have the extension .env or .json or .yaml",
+            match=(
+                "Unsupported file format. The file must have one of the following extensions: "
+                ".env .json .yaml .yml"
+            ),
         ):
             connection_command.connections_import(self.parser.parse_args(["connections", "import", filepath]))
 
diff --git a/tests/secrets/test_local_filesystem.py b/tests/secrets/test_local_filesystem.py
index 5993eb3af2..a7bbd824f0 100644
--- a/tests/secrets/test_local_filesystem.py
+++ b/tests/secrets/test_local_filesystem.py
@@ -122,8 +122,9 @@ class TestLoadVariables(unittest.TestCase):
     )
     def test_yaml_file_should_load_variables(self, file_content, expected_variables):
         with mock_local_file(file_content):
-            variables = local_filesystem.load_variables('a.yaml')
-            assert expected_variables == variables
+            vars_yaml = local_filesystem.load_variables('a.yaml')
+            vars_yml = local_filesystem.load_variables('a.yml')
+            assert expected_variables == vars_yaml == vars_yml
 
 
 class TestLoadConnection(unittest.TestCase):
@@ -390,6 +391,21 @@ class TestLoadConnection(unittest.TestCase):
             with pytest.raises(ConnectionNotUnique):
                 local_filesystem.load_connections_dict("a.yaml")
 
+    @parameterized.expand(
+        (("conn_a: mysql://hosta"),),
+    )
+    def test_yaml_extension_parsers_return_same_result(self, file_content):
+        with mock_local_file(file_content):
+            conn_uri_by_conn_id_yaml = {
+                conn_id: conn.get_uri()
+                for conn_id, conn in local_filesystem.load_connections_dict("a.yaml").items()
+            }
+            conn_uri_by_conn_id_yml = {
+                conn_id: conn.get_uri()
+                for conn_id, conn in local_filesystem.load_connections_dict("a.yml").items()
+            }
+            assert conn_uri_by_conn_id_yaml == conn_uri_by_conn_id_yml
+
 
 class TestLocalFileBackend(unittest.TestCase):
     def test_should_read_variable(self):