You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/02/02 00:20:57 UTC

[airflow] branch main updated: Add banner_timeout feature to SSH Hook/Operator (#21262)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d353f02  Add banner_timeout feature to SSH Hook/Operator (#21262)
d353f02 is described below

commit d353f023ff8856c00b9f054526cb2e40ff0116ae
Author: Jarek Potiuk <ja...@potiuk.com>
AuthorDate: Wed Feb 2 01:20:17 2022 +0100

    Add banner_timeout feature to SSH Hook/Operator (#21262)
    
    Recently, SSH tests in CI started to fail intermittently with an
    "Error reading SSH protocol banner" error. This error is raised
    when the SSH server is slow to start (which might happen, for
    example, when there is not enough entropy to generate keys).
    
    This can be mitigated by adding a configurable banner_timeout
    parameter (default 30 seconds) to the SSH hook and operator.
---
 airflow/providers/ssh/hooks/ssh.py        |  4 ++++
 airflow/providers/ssh/operators/ssh.py    |  9 ++++++++-
 tests/providers/ssh/hooks/test_ssh.py     | 10 ++++++++++
 tests/providers/ssh/operators/test_ssh.py | 24 +++++++++++++++++++++---
 4 files changed, 43 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/ssh/hooks/ssh.py b/airflow/providers/ssh/hooks/ssh.py
index 8256827..ac41c91 100644
--- a/airflow/providers/ssh/hooks/ssh.py
+++ b/airflow/providers/ssh/hooks/ssh.py
@@ -66,6 +66,7 @@ class SSHHook(BaseHook):
         Use conn_timeout instead.
     :param keepalive_interval: send a keepalive packet to remote host every
         keepalive_interval seconds
+    :param banner_timeout: timeout to wait for banner from the server in seconds
     """
 
     # List of classes to try loading private keys as, ordered (roughly) by most common to least common
@@ -109,6 +110,7 @@ class SSHHook(BaseHook):
         timeout: Optional[int] = None,
         conn_timeout: Optional[int] = None,
         keepalive_interval: int = 30,
+        banner_timeout: float = 30.0,
     ) -> None:
         super().__init__()
         self.ssh_conn_id = ssh_conn_id
@@ -121,6 +123,7 @@ class SSHHook(BaseHook):
         self.timeout = timeout
         self.conn_timeout = conn_timeout
         self.keepalive_interval = keepalive_interval
+        self.banner_timeout = banner_timeout
         self.host_proxy_cmd = None
 
         # Default values, overridable from Connection
@@ -293,6 +296,7 @@ class SSHHook(BaseHook):
             port=self.port,
             sock=self.host_proxy,
             look_for_keys=self.look_for_keys,
+            banner_timeout=self.banner_timeout,
         )
 
         if self.password:
diff --git a/airflow/providers/ssh/operators/ssh.py b/airflow/providers/ssh/operators/ssh.py
index 9c2e6cc..be0792d 100644
--- a/airflow/providers/ssh/operators/ssh.py
+++ b/airflow/providers/ssh/operators/ssh.py
@@ -56,6 +56,7 @@ class SSHOperator(BaseOperator):
         to have the remote process killed upon task timeout.
         The default is ``False`` but note that `get_pty` is forced to ``True``
         when the `command` starts with ``sudo``.
+    :param banner_timeout: timeout to wait for banner from the server in seconds
     """
 
     template_fields: Sequence[str] = ('command', 'remote_host')
@@ -74,6 +75,7 @@ class SSHOperator(BaseOperator):
         cmd_timeout: Optional[int] = None,
         environment: Optional[dict] = None,
         get_pty: bool = False,
+        banner_timeout: float = 30.0,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -90,6 +92,7 @@ class SSHOperator(BaseOperator):
             self.cmd_timeout = self.timeout if self.timeout else CMD_TIMEOUT
         self.environment = environment
         self.get_pty = get_pty
+        self.banner_timeout = banner_timeout
 
         if self.timeout:
             warnings.warn(
@@ -106,7 +109,11 @@ class SSHOperator(BaseOperator):
                 self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
             else:
                 self.log.info("ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook.")
-                self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, conn_timeout=self.conn_timeout)
+                self.ssh_hook = SSHHook(
+                    ssh_conn_id=self.ssh_conn_id,
+                    conn_timeout=self.conn_timeout,
+                    banner_timeout=self.banner_timeout,
+                )
 
         if not self.ssh_hook:
             raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")
diff --git a/tests/providers/ssh/hooks/test_ssh.py b/tests/providers/ssh/hooks/test_ssh.py
index 2351c4e..06e54ce 100644
--- a/tests/providers/ssh/hooks/test_ssh.py
+++ b/tests/providers/ssh/hooks/test_ssh.py
@@ -249,6 +249,7 @@ class TestSSHHook(unittest.TestCase):
 
         with hook.get_conn():
             ssh_mock.return_value.connect.assert_called_once_with(
+                banner_timeout=30.0,
                 hostname='remote_host',
                 username='username',
                 password='password',
@@ -268,6 +269,7 @@ class TestSSHHook(unittest.TestCase):
 
         with hook.get_conn():
             ssh_mock.return_value.connect.assert_called_once_with(
+                banner_timeout=30.0,
                 hostname='remote_host',
                 username='username',
                 key_filename='fake.file',
@@ -455,6 +457,7 @@ class TestSSHHook(unittest.TestCase):
 
         with hook.get_conn():
             ssh_mock.return_value.connect.assert_called_once_with(
+                banner_timeout=30.0,
                 hostname='remote_host',
                 username='username',
                 pkey=TEST_PKEY,
@@ -477,6 +480,7 @@ class TestSSHHook(unittest.TestCase):
 
         with hook.get_conn():
             ssh_mock.return_value.connect.assert_called_once_with(
+                banner_timeout=30.0,
                 hostname='remote_host',
                 username='username',
                 pkey=TEST_PKEY,
@@ -530,6 +534,7 @@ class TestSSHHook(unittest.TestCase):
 
         with hook.get_conn():
             ssh_mock.return_value.connect.assert_called_once_with(
+                banner_timeout=30.0,
                 hostname='remote_host',
                 username='username',
                 password='password',
@@ -555,6 +560,7 @@ class TestSSHHook(unittest.TestCase):
 
         with hook.get_conn():
             ssh_mock.return_value.connect.assert_called_once_with(
+                banner_timeout=30.0,
                 hostname='remote_host',
                 username='username',
                 password='password',
@@ -578,6 +584,7 @@ class TestSSHHook(unittest.TestCase):
 
         with hook.get_conn():
             ssh_mock.return_value.connect.assert_called_once_with(
+                banner_timeout=30.0,
                 hostname='remote_host',
                 username='username',
                 timeout=20,
@@ -601,6 +608,7 @@ class TestSSHHook(unittest.TestCase):
         # conn_timeout parameter wins over extra options
         with hook.get_conn():
             ssh_mock.return_value.connect.assert_called_once_with(
+                banner_timeout=30.0,
                 hostname='remote_host',
                 username='username',
                 timeout=15,
@@ -624,6 +632,7 @@ class TestSSHHook(unittest.TestCase):
         # conn_timeout parameter wins over extra options
         with hook.get_conn():
             ssh_mock.return_value.connect.assert_called_once_with(
+                banner_timeout=30.0,
                 hostname='remote_host',
                 username='username',
                 timeout=15,
@@ -679,6 +688,7 @@ class TestSSHHook(unittest.TestCase):
         # conn_timeout parameter wins over extra options
         with hook.get_conn():
             ssh_mock.return_value.connect.assert_called_once_with(
+                banner_timeout=30.0,
                 hostname='remote_host',
                 username='username',
                 timeout=expected_value,
diff --git a/tests/providers/ssh/operators/test_ssh.py b/tests/providers/ssh/operators/test_ssh.py
index b715dcf..a58a07c 100644
--- a/tests/providers/ssh/operators/test_ssh.py
+++ b/tests/providers/ssh/operators/test_ssh.py
@@ -51,7 +51,7 @@ class TestSSHOperator:
     def setup_method(self):
         from airflow.providers.ssh.hooks.ssh import SSHHook
 
-        hook = SSHHook(ssh_conn_id='ssh_default')
+        hook = SSHHook(ssh_conn_id='ssh_default', banner_timeout=100)
         hook.no_host_key_check = True
         self.dag = DAG('ssh_test', default_args={'start_date': DEFAULT_DATE})
         self.hook = hook
@@ -60,7 +60,13 @@ class TestSSHOperator:
         timeout = 20
         ssh_id = "ssh_default"
         with self.dag:
-            task = SSHOperator(task_id="test", command=COMMAND, timeout=timeout, ssh_conn_id="ssh_default")
+            task = SSHOperator(
+                task_id="test",
+                command=COMMAND,
+                timeout=timeout,
+                ssh_conn_id="ssh_default",
+                banner_timeout=100,
+            )
         task.execute(None)
         assert timeout == task.ssh_hook.conn_timeout
         assert ssh_id == task.ssh_hook.ssh_conn_id
@@ -76,6 +82,7 @@ class TestSSHOperator:
                 conn_timeout=conn_timeout,
                 cmd_timeout=cmd_timeout,
                 ssh_conn_id="ssh_default",
+                banner_timeout=100,
             )
         task.execute(None)
         assert conn_timeout == task.ssh_hook.conn_timeout
@@ -90,6 +97,7 @@ class TestSSHOperator:
             ssh_hook=self.hook,
             command=COMMAND,
             do_xcom_push=True,
+            banner_timeout=100,
         )
         ti.run()
         assert ti.duration is not None
@@ -104,6 +112,7 @@ class TestSSHOperator:
             ssh_hook=self.hook,
             command=COMMAND,
             do_xcom_push=True,
+            banner_timeout=100,
         )
         ti.run()
         assert ti.duration is not None
@@ -119,6 +128,7 @@ class TestSSHOperator:
             command=COMMAND,
             do_xcom_push=True,
             environment={'TEST': 'value'},
+            banner_timeout=100,
         )
         ti.run()
         assert ti.duration is not None
@@ -133,6 +143,7 @@ class TestSSHOperator:
             ssh_hook=self.hook,
             command="sleep 1",
             do_xcom_push=True,
+            banner_timeout=100,
         )
         ti.run()
         assert ti.duration is not None
@@ -153,6 +164,7 @@ class TestSSHOperator:
             command=COMMAND,
             timeout=TIMEOUT,
             dag=self.dag,
+            banner_timeout=100,
         )
         try:
             task_1.execute(None)
@@ -166,6 +178,7 @@ class TestSSHOperator:
             command=COMMAND,
             timeout=TIMEOUT,
             dag=self.dag,
+            banner_timeout=100,
         )
         try:
             task_2.execute(None)
@@ -181,6 +194,7 @@ class TestSSHOperator:
             command=COMMAND,
             timeout=TIMEOUT,
             dag=self.dag,
+            banner_timeout=100,
         )
         task_3.execute(None)
         assert task_3.ssh_hook.ssh_conn_id == self.hook.ssh_conn_id
@@ -193,6 +207,7 @@ class TestSSHOperator:
             timeout=TIMEOUT,
             dag=self.dag,
             remote_host='operator_remote_host',
+            banner_timeout=100,
         )
         try:
             task_4.execute(None)
@@ -220,6 +235,7 @@ class TestSSHOperator:
             cmd_timeout=TIMEOUT,
             get_pty=get_pty_in,
             dag=self.dag,
+            banner_timeout=100,
         )
         if command is None:
             with pytest.raises(AirflowException) as ctx:
@@ -237,6 +253,7 @@ class TestSSHOperator:
             ssh_hook=self.hook,
             command="ls",
             dag=self.dag,
+            banner_timeout=100,
         )
 
         se = SSHClientSideEffect(self.hook)
@@ -261,7 +278,7 @@ class TestSSHOperator:
                     success = True
                 return success
 
-        task = CustomSSHOperator(task_id="test", ssh_hook=self.hook, dag=self.dag)
+        task = CustomSSHOperator(task_id="test", ssh_hook=self.hook, dag=self.dag, banner_timeout=100)
         se = SSHClientSideEffect(self.hook)
         with unittest.mock.patch.object(task, 'get_ssh_client') as mock_get, unittest.mock.patch.object(
             task, 'run_ssh_client_command'
@@ -294,6 +311,7 @@ class TestSSHOperator:
             ssh_hook=self.hook,
             command=command,
             dag=self.dag,
+            banner_timeout=100,
         )
         with pytest.raises(AirflowException, match=f"error running cmd: {command}, error: .*"):
             task.execute(None)