You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2021/03/19 15:06:09 UTC

[airflow] 07/42: [AIRFLOW-7044] Host key can be specified via SSH connection extras. (#12944)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 472077ec5df36f83490739cc102c2b709ad7db37
Author: Andreas Franzén <an...@devil.se>
AuthorDate: Fri Jan 8 12:02:53 2021 +0100

    [AIRFLOW-7044] Host key can be specified via SSH connection extras. (#12944)
    
    (cherry picked from commit 52339a55c054bddd1d46253575274a3d5d141ebe)
---
 ...2da_increase_size_of_connection_extra_field_.py | 56 +++++++++++++
 airflow/models/connection.py                       |  2 +-
 airflow/providers/sftp/hooks/sftp.py               |  6 ++
 airflow/providers/ssh/hooks/ssh.py                 | 18 ++++-
 .../connections/ssh.rst                            |  6 +-
 docs/spelling_wordlist.txt                         |  1 +
 tests/providers/sftp/hooks/test_sftp.py            | 41 +++++++++-
 tests/providers/ssh/hooks/test_ssh.py              | 93 ++++++++++++++++++++++
 8 files changed, 216 insertions(+), 7 deletions(-)

diff --git a/airflow/migrations/versions/449b4072c2da_increase_size_of_connection_extra_field_.py b/airflow/migrations/versions/449b4072c2da_increase_size_of_connection_extra_field_.py
new file mode 100644
index 0000000..d3d9432
--- /dev/null
+++ b/airflow/migrations/versions/449b4072c2da_increase_size_of_connection_extra_field_.py
@@ -0,0 +1,56 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Increase size of connection.extra field to handle multiple RSA keys
+
+Revision ID: 449b4072c2da
+Revises: e959f08ac86c
+Create Date: 2020-03-16 19:02:55.337710
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = '449b4072c2da'
+down_revision = 'e959f08ac86c'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    """Apply increase_size_of_connection_extra_field: change connection.extra to TEXT"""
+    with op.batch_alter_table('connection', schema=None) as batch_op:
+        batch_op.alter_column(
+            'extra',
+            existing_type=sa.VARCHAR(length=5000),
+            type_=sa.TEXT(),
+            existing_nullable=True,
+        )
+
+
+def downgrade():
+    """Unapply increase_size_of_connection_extra_field: revert connection.extra to VARCHAR(5000)"""
+    with op.batch_alter_table('connection', schema=None) as batch_op:
+        batch_op.alter_column(
+            'extra',
+            existing_type=sa.TEXT(),
+            type_=sa.VARCHAR(length=5000),
+            existing_nullable=True,
+        )
diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index 1159a44..c030571 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -102,7 +102,7 @@ class Connection(Base, LoggingMixin):  # pylint: disable=too-many-instance-attri
     port = Column(Integer())
     is_encrypted = Column(Boolean, unique=False, default=False)
     is_extra_encrypted = Column(Boolean, unique=False, default=False)
-    _extra = Column('extra', String(5000))
+    _extra = Column('extra', Text())
 
     def __init__(  # pylint: disable=too-many-arguments
         self,
diff --git a/airflow/providers/sftp/hooks/sftp.py b/airflow/providers/sftp/hooks/sftp.py
index 498f362..e2a991e 100644
--- a/airflow/providers/sftp/hooks/sftp.py
+++ b/airflow/providers/sftp/hooks/sftp.py
@@ -115,6 +115,12 @@ class SFTPHook(SSHHook):
             cnopts = pysftp.CnOpts()
             if self.no_host_key_check:
                 cnopts.hostkeys = None
+            elif self.host_key is not None:
+                # Trust the host key pinned in the connection extra for this host; the key
+                # type comes from the key object itself rather than a hard-coded "ssh-rsa".
+                cnopts.hostkeys.add(self.remote_host, self.host_key.get_name(), self.host_key)
+            # else: fall back to system host keys if none explicitly specified in conn extra
+
             cnopts.compression = self.compress
             cnopts.ciphers = self.ciphers
             conn_params = {
diff --git a/airflow/providers/ssh/hooks/ssh.py b/airflow/providers/ssh/hooks/ssh.py
index d420b1b..1b35db3 100644
--- a/airflow/providers/ssh/hooks/ssh.py
+++ b/airflow/providers/ssh/hooks/ssh.py
@@ -19,6 +19,7 @@
 import getpass
 import os
 import warnings
+from base64 import decodebytes
 from io import StringIO
 from typing import Dict, Optional, Tuple, Union
 
@@ -30,7 +31,7 @@ from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 
 
-class SSHHook(BaseHook):
+class SSHHook(BaseHook):  # pylint: disable=too-many-instance-attributes
     """
     Hook for ssh remote execution using Paramiko.
     ref: https://github.com/paramiko/paramiko
@@ -72,7 +73,7 @@ class SSHHook(BaseHook):
             },
         }
 
-    def __init__(
+    def __init__(  # pylint: disable=too-many-statements
         self,
         ssh_conn_id: Optional[str] = None,
         remote_host: Optional[str] = None,
@@ -99,6 +100,7 @@ class SSHHook(BaseHook):
         self.no_host_key_check = True
         self.allow_host_key_change = False
         self.host_proxy = None
+        self.host_key = None
         self.look_for_keys = True
 
         # Placeholder for deprecated __enter__
@@ -149,7 +151,9 @@ class SSHHook(BaseHook):
                     and str(extra_options["look_for_keys"]).lower() == 'false'
                 ):
                     self.look_for_keys = False
-
+                if "host_key" in extra_options and not self.no_host_key_check:
+                    decoded_host_key = decodebytes(extra_options["host_key"].encode('utf-8'))
+                    self.host_key = paramiko.RSAKey(data=decoded_host_key)
         if self.pkey and self.key_file:
             raise AirflowException(
                 "Params key_file and private_key both provided.  Must provide no more than one."
@@ -198,10 +202,18 @@ class SSHHook(BaseHook):
                 'This wont protect against Man-In-The-Middle attacks'
             )
             client.load_system_host_keys()
+
         if self.no_host_key_check:
             self.log.warning('No Host Key Verification. This wont protect against Man-In-The-Middle attacks')
             # Default is RejectPolicy
             client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+        elif self.host_key is not None:  # without a pinned key, fall back to system host keys
+            # paramiko looks up known hosts as "[host]:port" when port is not the default SSH port.
+            host_key_name = self.remote_host
+            if self.port not in (None, paramiko.config.SSH_PORT):
+                host_key_name = f"[{self.remote_host}]:{self.port}"
+            client.get_host_keys().add(host_key_name, self.host_key.get_name(), self.host_key)
+
         connect_kwargs = dict(
             hostname=self.remote_host,
             username=self.username,
diff --git a/docs/apache-airflow-providers-ssh/connections/ssh.rst b/docs/apache-airflow-providers-ssh/connections/ssh.rst
index 54e902e..f320381 100644
--- a/docs/apache-airflow-providers-ssh/connections/ssh.rst
+++ b/docs/apache-airflow-providers-ssh/connections/ssh.rst
@@ -47,9 +47,10 @@ Extra (optional)
     * ``private_key_passphrase`` - Content of the private key passphrase used to decrypt the private key.
     * ``timeout`` - An optional timeout (in seconds) for the TCP connect. Default is ``10``.
     * ``compress`` - ``true`` to ask the remote client/server to compress traffic; ``false`` to refuse compression. Default is ``true``.
-    * ``no_host_key_check`` - Set to ``false`` to restrict connecting to hosts with no entries in ``~/.ssh/known_hosts`` (Hosts file). This provides maximum protection against trojan horse attacks, but can be troublesome when the ``/etc/ssh/ssh_known_hosts`` file is poorly maintained or connections to new hosts are frequently made. This option forces the user to manually add all new hosts. Default is ``true``, ssh will automatically add new host keys to the user known hosts files.
+    * ``no_host_key_check`` - Set to ``false`` to restrict connecting to hosts with either no entries in ``~/.ssh/known_hosts`` (Hosts file) or not present in the ``host_key`` extra. This provides maximum protection against trojan horse attacks, but can be troublesome when the ``/etc/ssh/ssh_known_hosts`` file is poorly maintained or connections to new hosts are frequently made. This option forces the user to manually add all new hosts. Default is ``true``, ssh will automatically add new [...]
     * ``allow_host_key_change`` - Set to ``true`` if you want to allow connecting to hosts that has host key changed or when you get 'REMOTE HOST IDENTIFICATION HAS CHANGED' error.  This wont protect against Man-In-The-Middle attacks. Other possible solution is to remove the host entry from ``~/.ssh/known_hosts`` file. Default is ``false``.
     * ``look_for_keys`` - Set to ``false`` if you want to disable searching for discoverable private key files in ``~/.ssh/``
+    * ``host_key`` - The base64-encoded ssh-rsa public key of the host, as you would find in the ``known_hosts`` file. Specifying this, along with ``no_host_key_check=False``, allows the connection to be made only if the public key of the endpoint matches this value.
 
     Example "extras" field:
 
@@ -59,9 +60,10 @@ Extra (optional)
           "key_file": "/home/airflow/.ssh/id_rsa",
           "timeout": "10",
           "compress": "false",
+          "look_for_keys": "false",
           "no_host_key_check": "false",
           "allow_host_key_change": "false",
-          "look_for_keys": "false"
+          "host_key": "AAAHD...YDWwq=="
        }
 
     When specifying the connection as URI (in :envvar:`AIRFLOW_CONN_{CONN_ID}` variable) you should specify it
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 238021e..84f1860 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1157,6 +1157,7 @@ rootcss
 rowid
 rpc
 rshift
+rsa
 rst
 rtype
 ru
diff --git a/tests/providers/sftp/hooks/test_sftp.py b/tests/providers/sftp/hooks/test_sftp.py
index 45097e6..9211c30 100644
--- a/tests/providers/sftp/hooks/test_sftp.py
+++ b/tests/providers/sftp/hooks/test_sftp.py
@@ -15,12 +15,14 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+import json
 import os
 import shutil
 import unittest
+from io import StringIO
 from unittest import mock
 
+import paramiko
 import pysftp
 from parameterized import parameterized
 
@@ -28,6 +30,15 @@ from airflow.models import Connection
 from airflow.providers.sftp.hooks.sftp import SFTPHook
 from airflow.utils.session import provide_session
 
+
+def generate_host_key(pkey: paramiko.PKey):
+    key_fh = StringIO()
+    pkey.write_private_key(key_fh)
+    key_fh.seek(0)
+    key_obj = paramiko.RSAKey(file_obj=key_fh)
+    return key_obj.get_base64()
+
+
 TMP_PATH = '/tmp'
 TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir'
 SUB_DIR = "sub_dir"
@@ -35,6 +46,9 @@ TMP_FILE_FOR_TESTS = 'test_file.txt'
 
 SFTP_CONNECTION_USER = "root"
 
+TEST_PKEY = paramiko.RSAKey.generate(4096)
+TEST_HOST_KEY = generate_host_key(pkey=TEST_PKEY)
+
 
 class TestSFTPHook(unittest.TestCase):
     @provide_session
@@ -178,6 +192,31 @@ class TestSFTPHook(unittest.TestCase):
         hook = SFTPHook()
         self.assertEqual(hook.no_host_key_check, False)
 
+    @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
+    def test_host_key_default(self, get_connection):
+        connection = Connection(login='login', host='host')
+        get_connection.return_value = connection
+        hook = SFTPHook()
+        self.assertIsNone(hook.host_key)
+
+    @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
+    def test_host_key(self, get_connection):
+        connection = Connection(
+            login='login',
+            host='host',
+            extra=json.dumps({"host_key": TEST_HOST_KEY, "no_host_key_check": False}),
+        )
+        get_connection.return_value = connection
+        hook = SFTPHook()
+        self.assertEqual(hook.host_key.get_base64(), TEST_HOST_KEY)
+
+    @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
+    def test_host_key_with_no_host_key_check(self, get_connection):
+        connection = Connection(login='login', host='host', extra=json.dumps({"host_key": TEST_HOST_KEY}))
+        get_connection.return_value = connection
+        hook = SFTPHook()
+        self.assertIsNone(hook.host_key)
+
     @parameterized.expand(
         [
             (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True),
diff --git a/tests/providers/ssh/hooks/test_ssh.py b/tests/providers/ssh/hooks/test_ssh.py
index 027de40..fea52bc 100644
--- a/tests/providers/ssh/hooks/test_ssh.py
+++ b/tests/providers/ssh/hooks/test_ssh.py
@@ -51,8 +51,17 @@ def generate_key_string(pkey: paramiko.PKey, passphrase: Optional[str] = None):
     return key_str
 
 
+def generate_host_key(pkey: paramiko.PKey):
+    key_fh = StringIO()
+    pkey.write_private_key(key_fh)
+    key_fh.seek(0)
+    key_obj = paramiko.RSAKey(file_obj=key_fh)
+    return key_obj.get_base64()
+
+
 TEST_PKEY = paramiko.RSAKey.generate(4096)
 TEST_PRIVATE_KEY = generate_key_string(pkey=TEST_PKEY)
+TEST_HOST_KEY = generate_host_key(pkey=TEST_PKEY)
 
 PASSPHRASE = ''.join(random.choice(string.ascii_letters) for i in range(10))
 TEST_ENCRYPTED_PRIVATE_KEY = generate_key_string(pkey=TEST_PKEY, passphrase=PASSPHRASE)
@@ -63,6 +72,10 @@ class TestSSHHook(unittest.TestCase):
     CONN_SSH_WITH_PRIVATE_KEY_PASSPHRASE_EXTRA = 'ssh_with_private_key_passphrase_extra'
     CONN_SSH_WITH_EXTRA = 'ssh_with_extra'
     CONN_SSH_WITH_EXTRA_FALSE_LOOK_FOR_KEYS = 'ssh_with_extra_false_look_for_keys'
+    CONN_SSH_WITH_HOST_KEY_EXTRA = 'ssh_with_host_key_extra'
+    CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE = 'ssh_with_host_key_and_no_host_key_check_false'
+    CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE = 'ssh_with_host_key_and_no_host_key_check_true'
+    CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE = 'ssh_with_no_host_key_and_no_host_key_check_false'
 
     @classmethod
     def tearDownClass(cls) -> None:
@@ -70,6 +83,11 @@ class TestSSHHook(unittest.TestCase):
             conns_to_reset = [
                 cls.CONN_SSH_WITH_PRIVATE_KEY_EXTRA,
                 cls.CONN_SSH_WITH_PRIVATE_KEY_PASSPHRASE_EXTRA,
+                cls.CONN_SSH_WITH_EXTRA,
+                cls.CONN_SSH_WITH_HOST_KEY_EXTRA,
+                cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE,
+                cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE,
+                cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE,
             ]
             connections = session.query(Connection).filter(Connection.conn_id.in_(conns_to_reset))
             connections.delete(synchronize_session=False)
@@ -116,6 +134,42 @@ class TestSSHHook(unittest.TestCase):
                 ),
             )
         )
+        db.merge_conn(
+            Connection(
+                conn_id=cls.CONN_SSH_WITH_HOST_KEY_EXTRA,
+                host='localhost',
+                conn_type='ssh',
+                extra=json.dumps({"private_key": TEST_PRIVATE_KEY, "host_key": TEST_HOST_KEY}),
+            )
+        )
+        db.merge_conn(
+            Connection(
+                conn_id=cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE,
+                host='remote_host',
+                conn_type='ssh',
+                extra=json.dumps(
+                    {"private_key": TEST_PRIVATE_KEY, "host_key": TEST_HOST_KEY, "no_host_key_check": False}
+                ),
+            )
+        )
+        db.merge_conn(
+            Connection(
+                conn_id=cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE,
+                host='remote_host',
+                conn_type='ssh',
+                extra=json.dumps(
+                    {"private_key": TEST_PRIVATE_KEY, "host_key": TEST_HOST_KEY, "no_host_key_check": True}
+                ),
+            )
+        )
+        db.merge_conn(
+            Connection(
+                conn_id=cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE,
+                host='remote_host',
+                conn_type='ssh',
+                extra=json.dumps({"private_key": TEST_PRIVATE_KEY, "no_host_key_check": False}),
+            )
+        )
 
     @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
     def test_ssh_connection_with_password(self, ssh_mock):
@@ -344,3 +398,42 @@ class TestSSHHook(unittest.TestCase):
                 sock=None,
                 look_for_keys=True,
             )
+
+    @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
+    def test_ssh_connection_with_host_key_extra(self, ssh_client):
+        hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_EXTRA)
+        assert hook.host_key is None  # Since default no_host_key_check = True unless explicit override
+        with hook.get_conn():
+            ssh_client.return_value.connect.assert_called_once()
+            ssh_client.return_value.get_host_keys.return_value.add.assert_not_called()
+
+    @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
+    def test_ssh_connection_with_host_key_where_no_host_key_check_is_true(self, ssh_client):
+        hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE)
+        assert hook.host_key is None
+        with hook.get_conn():
+            ssh_client.return_value.connect.assert_called_once()
+            ssh_client.return_value.get_host_keys.return_value.add.assert_not_called()
+
+    @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
+    def test_ssh_connection_with_host_key_where_no_host_key_check_is_false(self, ssh_client):
+        hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE)
+        assert hook.host_key.get_base64() == TEST_HOST_KEY
+        with hook.get_conn():
+            ssh_client.return_value.connect.assert_called_once()
+            ssh_client.return_value.get_host_keys.assert_called_once()
+            ssh_client.return_value.get_host_keys.return_value.add.assert_called_once_with(
+                hook.remote_host, 'ssh-rsa', hook.host_key
+            )
+
+    @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
+    def test_ssh_connection_with_no_host_key_where_no_host_key_check_is_false(self, ssh_client):
+        hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE)
+        assert hook.host_key is None
+        with hook.get_conn():
+            ssh_client.return_value.connect.assert_called_once()
+            ssh_client.return_value.get_host_keys.return_value.add.assert_not_called()
+
+
+if __name__ == '__main__':
+    unittest.main()