You are viewing a plain-text version of this content. The canonical link for it appeared here as a hyperlink, which is not preserved in plain text.
Posted to commits@airflow.apache.org by ka...@apache.org on 2022/01/07 21:00:57 UTC
[airflow] branch main updated: Bugfix: ``SFTPHook`` does not respect ``ssh_conn_id`` arg (#20756)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c2fc760 Bugfix: ``SFTPHook`` does not respect ``ssh_conn_id`` arg (#20756)
c2fc760 is described below
commit c2fc760c9024ce9f0bec287679d9981f6ab1fd98
Author: Kaxil Naik <ka...@gmail.com>
AuthorDate: Sat Jan 8 02:30:17 2022 +0530
Bugfix: ``SFTPHook`` does not respect ``ssh_conn_id`` arg (#20756)
Closes https://github.com/apache/airflow/issues/20735
---
airflow/providers/sftp/hooks/sftp.py | 9 +-
tests/providers/sftp/hooks/test_sftp.py | 19 ++
tests/providers/sftp/hooks/test_sftp_outdated.py | 340 -----------------------
3 files changed, 23 insertions(+), 345 deletions(-)
diff --git a/airflow/providers/sftp/hooks/sftp.py b/airflow/providers/sftp/hooks/sftp.py
index b7e667b..f3d3ee0 100644
--- a/airflow/providers/sftp/hooks/sftp.py
+++ b/airflow/providers/sftp/hooks/sftp.py
@@ -74,19 +74,18 @@ class SFTPHook(SSHHook):
def __init__(
self,
ssh_conn_id: Optional[str] = 'sftp_default',
- ftp_conn_id: Optional[str] = 'sftp_default',
*args,
**kwargs,
) -> None:
-
+ ftp_conn_id = kwargs.pop('ftp_conn_id', None)
if ftp_conn_id:
warnings.warn(
- 'Parameter `ftp_conn_id` is deprecated.' 'Please use `ssh_conn_id` instead.',
+ 'Parameter `ftp_conn_id` is deprecated. Please use `ssh_conn_id` instead.',
DeprecationWarning,
stacklevel=2,
)
- kwargs['ssh_conn_id'] = ftp_conn_id
- self.ssh_conn_id = ssh_conn_id
+ ssh_conn_id = ftp_conn_id
+ kwargs['ssh_conn_id'] = ssh_conn_id
super().__init__(*args, **kwargs)
self.conn = None
diff --git a/tests/providers/sftp/hooks/test_sftp.py b/tests/providers/sftp/hooks/test_sftp.py
index 38e0e73..9c63402 100644
--- a/tests/providers/sftp/hooks/test_sftp.py
+++ b/tests/providers/sftp/hooks/test_sftp.py
@@ -334,6 +334,25 @@ class TestSFTPHook(unittest.TestCase):
assert status is True
assert msg == 'Connection successfully tested'
+ @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
+ def test_deprecation_ftp_conn_id(self, mock_get_connection):
+ connection = Connection(conn_id='ftp_default', login='login', host='host')
+ mock_get_connection.return_value = connection
+ # If `ftp_conn_id` is provided, it will be used but would show a deprecation warning.
+ with self.assertWarnsRegex(DeprecationWarning, "Parameter `ftp_conn_id` is deprecated"):
+ assert SFTPHook(ftp_conn_id='ftp_default').ssh_conn_id == 'ftp_default'
+
+ # If both are provided, ftp_conn_id will be used but would show a deprecation warning.
+ with self.assertWarnsRegex(DeprecationWarning, "Parameter `ftp_conn_id` is deprecated"):
+ assert (
+ SFTPHook(ftp_conn_id='ftp_default', ssh_conn_id='sftp_default').ssh_conn_id == 'ftp_default'
+ )
+
+ # If `ssh_conn_id` is provided, it should use it for ssh_conn_id
+ assert SFTPHook(ssh_conn_id='sftp_default').ssh_conn_id == 'sftp_default'
+ # Default is 'sftp_default
+ assert SFTPHook().ssh_conn_id == 'sftp_default'
+
def tearDown(self):
shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
diff --git a/tests/providers/sftp/hooks/test_sftp_outdated.py b/tests/providers/sftp/hooks/test_sftp_outdated.py
deleted file mode 100644
index 72ba9de..0000000
--- a/tests/providers/sftp/hooks/test_sftp_outdated.py
+++ /dev/null
@@ -1,340 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import json
-import os
-import shutil
-import unittest
-from io import StringIO
-from unittest import mock
-
-import paramiko
-import pysftp
-from parameterized import parameterized
-
-from airflow.models import Connection
-from airflow.providers.sftp.hooks.sftp import SFTPHook
-from airflow.utils.session import provide_session
-
-
-def generate_host_key(pkey: paramiko.PKey):
- key_fh = StringIO()
- pkey.write_private_key(key_fh)
- key_fh.seek(0)
- key_obj = paramiko.RSAKey(file_obj=key_fh)
- return key_obj.get_base64()
-
-
-TMP_PATH = '/tmp'
-TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir'
-SUB_DIR = "sub_dir"
-TMP_FILE_FOR_TESTS = 'test_file.txt'
-
-SFTP_CONNECTION_USER = "root"
-
-TEST_PKEY = paramiko.RSAKey.generate(4096)
-TEST_HOST_KEY = generate_host_key(pkey=TEST_PKEY)
-TEST_KEY_FILE = "~/.ssh/id_rsa"
-
-
-class TestSFTPHook(unittest.TestCase):
- @provide_session
- def update_connection(self, login, session=None):
- connection = session.query(Connection).filter(Connection.conn_id == "sftp_default").first()
- old_login = connection.login
- connection.login = login
- session.commit()
- return old_login
-
- def setUp(self):
- self.old_login = self.update_connection(SFTP_CONNECTION_USER)
- self.hook = SFTPHook(ftp_conn_id='sftp_default')
- os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR))
-
- with open(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), 'a') as file:
- file.write('Test file')
- with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), 'a') as file:
- file.write('Test file')
-
- def test_get_conn(self):
- output = self.hook.get_conn()
- assert isinstance(output, pysftp.Connection)
-
- def test_close_conn(self):
- self.hook.conn = self.hook.get_conn()
- assert self.hook.conn is not None
- self.hook.close_conn()
- assert self.hook.conn is None
-
- def test_describe_directory(self):
- output = self.hook.describe_directory(TMP_PATH)
- assert TMP_DIR_FOR_TESTS in output
-
- def test_list_directory(self):
- output = self.hook.list_directory(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- assert output == [SUB_DIR]
-
- def test_create_and_delete_directory(self):
- new_dir_name = 'new_dir'
- self.hook.create_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
- output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- assert new_dir_name in output
- self.hook.delete_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
- output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- assert new_dir_name not in output
-
- def test_create_and_delete_directories(self):
- base_dir = "base_dir"
- sub_dir = "sub_dir"
- new_dir_path = os.path.join(base_dir, sub_dir)
- self.hook.create_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
- output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- assert base_dir in output
- output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
- assert sub_dir in output
- self.hook.delete_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
- self.hook.delete_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
- output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- assert new_dir_path not in output
- assert base_dir not in output
-
- def test_store_retrieve_and_delete_file(self):
- self.hook.store_file(
- remote_full_path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
- local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS),
- )
- output = self.hook.list_directory(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- assert output == [SUB_DIR, TMP_FILE_FOR_TESTS]
- retrieved_file_name = 'retrieved.txt'
- self.hook.retrieve_file(
- remote_full_path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
- local_full_path=os.path.join(TMP_PATH, retrieved_file_name),
- )
- assert retrieved_file_name in os.listdir(TMP_PATH)
- os.remove(os.path.join(TMP_PATH, retrieved_file_name))
- self.hook.delete_file(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
- output = self.hook.list_directory(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- assert output == [SUB_DIR]
-
- def test_get_mod_time(self):
- self.hook.store_file(
- remote_full_path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
- local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS),
- )
- output = self.hook.get_mod_time(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
- assert len(output) == 14
-
- @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
- def test_no_host_key_check_default(self, get_connection):
- connection = Connection(login='login', host='host')
- get_connection.return_value = connection
- hook = SFTPHook()
- assert hook.no_host_key_check is False
-
- @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
- def test_no_host_key_check_enabled(self, get_connection):
- connection = Connection(login='login', host='host', extra='{"no_host_key_check": true}')
-
- get_connection.return_value = connection
- hook = SFTPHook()
- assert hook.no_host_key_check is True
-
- @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
- def test_no_host_key_check_disabled(self, get_connection):
- connection = Connection(login='login', host='host', extra='{"no_host_key_check": false}')
-
- get_connection.return_value = connection
- hook = SFTPHook()
- assert hook.no_host_key_check is False
-
- @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
- def test_ciphers(self, get_connection):
- connection = Connection(login='login', host='host', extra='{"ciphers": ["A", "B", "C"]}')
-
- get_connection.return_value = connection
- hook = SFTPHook()
- assert hook.ciphers == ["A", "B", "C"]
-
- @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
- def test_no_host_key_check_disabled_for_all_but_true(self, get_connection):
- connection = Connection(login='login', host='host', extra='{"no_host_key_check": "foo"}')
-
- get_connection.return_value = connection
- hook = SFTPHook()
- assert hook.no_host_key_check is False
-
- @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
- def test_no_host_key_check_ignore(self, get_connection):
- connection = Connection(login='login', host='host', extra='{"ignore_hostkey_verification": true}')
-
- get_connection.return_value = connection
- hook = SFTPHook()
- assert hook.no_host_key_check is True
-
- @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
- def test_no_host_key_check_no_ignore(self, get_connection):
- connection = Connection(login='login', host='host', extra='{"ignore_hostkey_verification": false}')
-
- get_connection.return_value = connection
- hook = SFTPHook()
- assert hook.no_host_key_check is False
-
- @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
- def test_host_key_default(self, get_connection):
- connection = Connection(login='login', host='host')
- get_connection.return_value = connection
- hook = SFTPHook()
- assert hook.host_key is None
-
- @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
- def test_host_key(self, get_connection):
- connection = Connection(
- login='login',
- host='host',
- extra=json.dumps({"host_key": TEST_HOST_KEY}),
- )
- get_connection.return_value = connection
- hook = SFTPHook()
- assert hook.host_key.get_base64() == TEST_HOST_KEY
-
- @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
- def test_host_key_with_type(self, get_connection):
- connection = Connection(
- login='login',
- host='host',
- extra=json.dumps({"host_key": "ssh-rsa " + TEST_HOST_KEY}),
- )
- get_connection.return_value = connection
- hook = SFTPHook()
- assert hook.host_key.get_base64() == TEST_HOST_KEY
-
- @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
- def test_host_key_with_no_host_key_check(self, get_connection):
- connection = Connection(login='login', host='host', extra=json.dumps({"host_key": TEST_HOST_KEY}))
- get_connection.return_value = connection
- hook = SFTPHook()
- assert hook.host_key is not None
-
- @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
- def test_key_content_as_str(self, get_connection):
- file_obj = StringIO()
- TEST_PKEY.write_private_key(file_obj)
- file_obj.seek(0)
- key_content_str = file_obj.read()
-
- connection = Connection(
- login='login',
- host='host',
- extra=json.dumps(
- {
- "private_key": key_content_str,
- }
- ),
- )
- get_connection.return_value = connection
- hook = SFTPHook()
- assert hook.pkey == TEST_PKEY
- assert hook.key_file is None
-
- @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
- def test_key_file(self, get_connection):
- connection = Connection(
- login='login',
- host='host',
- extra=json.dumps(
- {
- "key_file": TEST_KEY_FILE,
- }
- ),
- )
- get_connection.return_value = connection
- hook = SFTPHook()
- assert hook.key_file == TEST_KEY_FILE
-
- @parameterized.expand(
- [
- (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True),
- (os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), True),
- (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS + "abc"), False),
- (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, "abc"), False),
- ]
- )
- def test_path_exists(self, path, exists):
- result = self.hook.path_exists(path)
- assert result == exists
-
- @parameterized.expand(
- [
- ("test/path/file.bin", None, None, True),
- ("test/path/file.bin", "test", None, True),
- ("test/path/file.bin", "test/", None, True),
- ("test/path/file.bin", None, "bin", True),
- ("test/path/file.bin", "test", "bin", True),
- ("test/path/file.bin", "test/", "file.bin", True),
- ("test/path/file.bin", None, "file.bin", True),
- ("test/path/file.bin", "diff", None, False),
- ("test/path/file.bin", "test//", None, False),
- ("test/path/file.bin", None, ".txt", False),
- ("test/path/file.bin", "diff", ".txt", False),
- ]
- )
- def test_path_match(self, path, prefix, delimiter, match):
- result = self.hook._is_path_match(path=path, prefix=prefix, delimiter=delimiter)
- assert result == match
-
- def test_get_tree_map(self):
- tree_map = self.hook.get_tree_map(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- files, dirs, unknowns = tree_map
-
- assert files == [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS)]
- assert dirs == [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)]
- assert unknowns == []
-
- @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
- def test_connection_failure(self, mock_get_connection):
- connection = Connection(
- login='login',
- host='host',
- )
- mock_get_connection.return_value = connection
- with mock.patch.object(SFTPHook, 'get_conn') as get_conn:
- type(get_conn.return_value).pwd = mock.PropertyMock(side_effect=Exception('Connection Error'))
-
- hook = SFTPHook()
- status, msg = hook.test_connection()
- assert status is False
- assert msg == 'Connection Error'
-
- @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
- def test_connection_success(self, mock_get_connection):
- connection = Connection(
- login='login',
- host='host',
- )
- mock_get_connection.return_value = connection
-
- with mock.patch.object(SFTPHook, 'get_conn') as get_conn:
- get_conn.return_value.pwd = '/home/someuser'
- hook = SFTPHook()
- status, msg = hook.test_connection()
- assert status is True
- assert msg == 'Connection successfully tested'
-
- def tearDown(self):
- shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
- self.update_connection(self.old_login)