You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2021/03/19 15:06:09 UTC
[airflow] 07/42: [AIRFLOW-7044] Host key can be specified via SSH
connection extras. (#12944)
This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 472077ec5df36f83490739cc102c2b709ad7db37
Author: Andreas Franzén <an...@devil.se>
AuthorDate: Fri Jan 8 12:02:53 2021 +0100
[AIRFLOW-7044] Host key can be specified via SSH connection extras. (#12944)
(cherry picked from commit 52339a55c054bddd1d46253575274a3d5d141ebe)
---
...2da_increase_size_of_connection_extra_field_.py | 56 +++++++++++++
airflow/models/connection.py | 2 +-
airflow/providers/sftp/hooks/sftp.py | 6 ++
airflow/providers/ssh/hooks/ssh.py | 18 ++++-
.../connections/ssh.rst | 6 +-
docs/spelling_wordlist.txt | 1 +
tests/providers/sftp/hooks/test_sftp.py | 41 +++++++++-
tests/providers/ssh/hooks/test_ssh.py | 93 ++++++++++++++++++++++
8 files changed, 216 insertions(+), 7 deletions(-)
diff --git a/airflow/migrations/versions/449b4072c2da_increase_size_of_connection_extra_field_.py b/airflow/migrations/versions/449b4072c2da_increase_size_of_connection_extra_field_.py
new file mode 100644
index 0000000..d3d9432
--- /dev/null
+++ b/airflow/migrations/versions/449b4072c2da_increase_size_of_connection_extra_field_.py
@@ -0,0 +1,56 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Increase size of connection.extra field to handle multiple RSA keys
+
+Revision ID: 449b4072c2da
+Revises: e959f08ac86c
+Create Date: 2020-03-16 19:02:55.337710
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = '449b4072c2da'
+down_revision = 'e959f08ac86c'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ """Increase the size of the connection.extra column (VARCHAR(5000) -> TEXT)"""
+ with op.batch_alter_table('connection', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'extra',
+ existing_type=sa.VARCHAR(length=5000),
+ type_=sa.TEXT(),
+ existing_nullable=True,
+ )
+
+
+def downgrade():
+ """Revert the connection.extra column from TEXT back to VARCHAR(5000)"""
+ with op.batch_alter_table('connection', schema=None) as batch_op:
+ batch_op.alter_column(
+ 'extra',
+ existing_type=sa.TEXT(),
+ type_=sa.VARCHAR(length=5000),
+ existing_nullable=True,
+ )
diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index 1159a44..c030571 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -102,7 +102,7 @@ class Connection(Base, LoggingMixin): # pylint: disable=too-many-instance-attri
port = Column(Integer())
is_encrypted = Column(Boolean, unique=False, default=False)
is_extra_encrypted = Column(Boolean, unique=False, default=False)
- _extra = Column('extra', String(5000))
+ _extra = Column('extra', Text())
def __init__( # pylint: disable=too-many-arguments
self,
diff --git a/airflow/providers/sftp/hooks/sftp.py b/airflow/providers/sftp/hooks/sftp.py
index 498f362..e2a991e 100644
--- a/airflow/providers/sftp/hooks/sftp.py
+++ b/airflow/providers/sftp/hooks/sftp.py
@@ -115,6 +115,12 @@ class SFTPHook(SSHHook):
cnopts = pysftp.CnOpts()
if self.no_host_key_check:
cnopts.hostkeys = None
+ else:
+ if self.host_key is not None:
+ cnopts.hostkeys.add(self.remote_host, 'ssh-rsa', self.host_key)
+ else:
+ pass # will fall back to system host keys if none explicitly specified in conn extra
+
cnopts.compression = self.compress
cnopts.ciphers = self.ciphers
conn_params = {
diff --git a/airflow/providers/ssh/hooks/ssh.py b/airflow/providers/ssh/hooks/ssh.py
index d420b1b..1b35db3 100644
--- a/airflow/providers/ssh/hooks/ssh.py
+++ b/airflow/providers/ssh/hooks/ssh.py
@@ -19,6 +19,7 @@
import getpass
import os
import warnings
+from base64 import decodebytes
from io import StringIO
from typing import Dict, Optional, Tuple, Union
@@ -30,7 +31,7 @@ from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
-class SSHHook(BaseHook):
+class SSHHook(BaseHook): # pylint: disable=too-many-instance-attributes
"""
Hook for ssh remote execution using Paramiko.
ref: https://github.com/paramiko/paramiko
@@ -72,7 +73,7 @@ class SSHHook(BaseHook):
},
}
- def __init__(
+ def __init__( # pylint: disable=too-many-statements
self,
ssh_conn_id: Optional[str] = None,
remote_host: Optional[str] = None,
@@ -99,6 +100,7 @@ class SSHHook(BaseHook):
self.no_host_key_check = True
self.allow_host_key_change = False
self.host_proxy = None
+ self.host_key = None
self.look_for_keys = True
# Placeholder for deprecated __enter__
@@ -149,7 +151,9 @@ class SSHHook(BaseHook):
and str(extra_options["look_for_keys"]).lower() == 'false'
):
self.look_for_keys = False
-
+ if "host_key" in extra_options and self.no_host_key_check is False:
+ decoded_host_key = decodebytes(extra_options["host_key"].encode('utf-8'))
+ self.host_key = paramiko.RSAKey(data=decoded_host_key)
if self.pkey and self.key_file:
raise AirflowException(
"Params key_file and private_key both provided. Must provide no more than one."
@@ -198,10 +202,18 @@ class SSHHook(BaseHook):
'This wont protect against Man-In-The-Middle attacks'
)
client.load_system_host_keys()
+
if self.no_host_key_check:
self.log.warning('No Host Key Verification. This wont protect against Man-In-The-Middle attacks')
# Default is RejectPolicy
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+ else:
+ if self.host_key is not None:
+ client_host_keys = client.get_host_keys()
+ client_host_keys.add(self.remote_host, 'ssh-rsa', self.host_key)
+ else:
+ pass # will fall back to system host keys if none explicitly specified in conn extra
+
connect_kwargs = dict(
hostname=self.remote_host,
username=self.username,
diff --git a/docs/apache-airflow-providers-ssh/connections/ssh.rst b/docs/apache-airflow-providers-ssh/connections/ssh.rst
index 54e902e..f320381 100644
--- a/docs/apache-airflow-providers-ssh/connections/ssh.rst
+++ b/docs/apache-airflow-providers-ssh/connections/ssh.rst
@@ -47,9 +47,10 @@ Extra (optional)
* ``private_key_passphrase`` - Content of the private key passphrase used to decrypt the private key.
* ``timeout`` - An optional timeout (in seconds) for the TCP connect. Default is ``10``.
* ``compress`` - ``true`` to ask the remote client/server to compress traffic; ``false`` to refuse compression. Default is ``true``.
- * ``no_host_key_check`` - Set to ``false`` to restrict connecting to hosts with no entries in ``~/.ssh/known_hosts`` (Hosts file). This provides maximum protection against trojan horse attacks, but can be troublesome when the ``/etc/ssh/ssh_known_hosts`` file is poorly maintained or connections to new hosts are frequently made. This option forces the user to manually add all new hosts. Default is ``true``, ssh will automatically add new host keys to the user known hosts files.
+ * ``no_host_key_check`` - Set to ``false`` to restrict connecting to hosts with either no entries in ``~/.ssh/known_hosts`` (Hosts file) or not present in the ``host_key`` extra. This provides maximum protection against trojan horse attacks, but can be troublesome when the ``/etc/ssh/ssh_known_hosts`` file is poorly maintained or connections to new hosts are frequently made. This option forces the user to manually add all new hosts. Default is ``true``, ssh will automatically add new host keys to the user known hosts files.
* ``allow_host_key_change`` - Set to ``true`` if you want to allow connecting to hosts that has host key changed or when you get 'REMOTE HOST IDENTIFICATION HAS CHANGED' error. This wont protect against Man-In-The-Middle attacks. Other possible solution is to remove the host entry from ``~/.ssh/known_hosts`` file. Default is ``false``.
* ``look_for_keys`` - Set to ``false`` if you want to disable searching for discoverable private key files in ``~/.ssh/``
+ * ``host_key`` - The base64 encoded ssh-rsa public key of the host, as you would find in the ``known_hosts`` file. Specifying this along with ``no_host_key_check=False`` allows the connection to be made only if the public key of the endpoint matches this value.
Example "extras" field:
@@ -59,9 +60,10 @@ Extra (optional)
"key_file": "/home/airflow/.ssh/id_rsa",
"timeout": "10",
"compress": "false",
+ "look_for_keys": "false",
"no_host_key_check": "false",
"allow_host_key_change": "false",
- "look_for_keys": "false"
+ "host_key": "AAAHD...YDWwq=="
}
When specifying the connection as URI (in :envvar:`AIRFLOW_CONN_{CONN_ID}` variable) you should specify it
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 238021e..84f1860 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1157,6 +1157,7 @@ rootcss
rowid
rpc
rshift
+rsa
rst
rtype
ru
diff --git a/tests/providers/sftp/hooks/test_sftp.py b/tests/providers/sftp/hooks/test_sftp.py
index 45097e6..9211c30 100644
--- a/tests/providers/sftp/hooks/test_sftp.py
+++ b/tests/providers/sftp/hooks/test_sftp.py
@@ -15,12 +15,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+import json
import os
import shutil
import unittest
+from io import StringIO
from unittest import mock
+import paramiko
import pysftp
from parameterized import parameterized
@@ -28,6 +30,15 @@ from airflow.models import Connection
from airflow.providers.sftp.hooks.sftp import SFTPHook
from airflow.utils.session import provide_session
+
+def generate_host_key(pkey: paramiko.PKey):
+ key_fh = StringIO()
+ pkey.write_private_key(key_fh)
+ key_fh.seek(0)
+ key_obj = paramiko.RSAKey(file_obj=key_fh)
+ return key_obj.get_base64()
+
+
TMP_PATH = '/tmp'
TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir'
SUB_DIR = "sub_dir"
@@ -35,6 +46,9 @@ TMP_FILE_FOR_TESTS = 'test_file.txt'
SFTP_CONNECTION_USER = "root"
+TEST_PKEY = paramiko.RSAKey.generate(4096)
+TEST_HOST_KEY = generate_host_key(pkey=TEST_PKEY)
+
class TestSFTPHook(unittest.TestCase):
@provide_session
@@ -178,6 +192,31 @@ class TestSFTPHook(unittest.TestCase):
hook = SFTPHook()
self.assertEqual(hook.no_host_key_check, False)
+ @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
+ def test_host_key_default(self, get_connection):
+ connection = Connection(login='login', host='host')
+ get_connection.return_value = connection
+ hook = SFTPHook()
+ self.assertEqual(hook.host_key, None)
+
+ @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
+ def test_host_key(self, get_connection):
+ connection = Connection(
+ login='login',
+ host='host',
+ extra=json.dumps({"host_key": TEST_HOST_KEY, "no_host_key_check": False}),
+ )
+ get_connection.return_value = connection
+ hook = SFTPHook()
+ self.assertEqual(hook.host_key.get_base64(), TEST_HOST_KEY)
+
+ @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
+ def test_host_key_with_no_host_key_check(self, get_connection):
+ connection = Connection(login='login', host='host', extra=json.dumps({"host_key": TEST_HOST_KEY}))
+ get_connection.return_value = connection
+ hook = SFTPHook()
+ self.assertEqual(hook.host_key, None)
+
@parameterized.expand(
[
(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True),
diff --git a/tests/providers/ssh/hooks/test_ssh.py b/tests/providers/ssh/hooks/test_ssh.py
index 027de40..fea52bc 100644
--- a/tests/providers/ssh/hooks/test_ssh.py
+++ b/tests/providers/ssh/hooks/test_ssh.py
@@ -51,8 +51,17 @@ def generate_key_string(pkey: paramiko.PKey, passphrase: Optional[str] = None):
return key_str
+def generate_host_key(pkey: paramiko.PKey):
+ key_fh = StringIO()
+ pkey.write_private_key(key_fh)
+ key_fh.seek(0)
+ key_obj = paramiko.RSAKey(file_obj=key_fh)
+ return key_obj.get_base64()
+
+
TEST_PKEY = paramiko.RSAKey.generate(4096)
TEST_PRIVATE_KEY = generate_key_string(pkey=TEST_PKEY)
+TEST_HOST_KEY = generate_host_key(pkey=TEST_PKEY)
PASSPHRASE = ''.join(random.choice(string.ascii_letters) for i in range(10))
TEST_ENCRYPTED_PRIVATE_KEY = generate_key_string(pkey=TEST_PKEY, passphrase=PASSPHRASE)
@@ -63,6 +72,10 @@ class TestSSHHook(unittest.TestCase):
CONN_SSH_WITH_PRIVATE_KEY_PASSPHRASE_EXTRA = 'ssh_with_private_key_passphrase_extra'
CONN_SSH_WITH_EXTRA = 'ssh_with_extra'
CONN_SSH_WITH_EXTRA_FALSE_LOOK_FOR_KEYS = 'ssh_with_extra_false_look_for_keys'
+ CONN_SSH_WITH_HOST_KEY_EXTRA = 'ssh_with_host_key_extra'
+ CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE = 'ssh_with_host_key_and_no_host_key_check_false'
+ CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE = 'ssh_with_host_key_and_no_host_key_check_true'
+ CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE = 'ssh_with_no_host_key_and_no_host_key_check_false'
@classmethod
def tearDownClass(cls) -> None:
@@ -70,6 +83,11 @@ class TestSSHHook(unittest.TestCase):
conns_to_reset = [
cls.CONN_SSH_WITH_PRIVATE_KEY_EXTRA,
cls.CONN_SSH_WITH_PRIVATE_KEY_PASSPHRASE_EXTRA,
+ cls.CONN_SSH_WITH_EXTRA,
+ cls.CONN_SSH_WITH_HOST_KEY_EXTRA,
+ cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE,
+ cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE,
+ cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE,
]
connections = session.query(Connection).filter(Connection.conn_id.in_(conns_to_reset))
connections.delete(synchronize_session=False)
@@ -116,6 +134,42 @@ class TestSSHHook(unittest.TestCase):
),
)
)
+ db.merge_conn(
+ Connection(
+ conn_id=cls.CONN_SSH_WITH_HOST_KEY_EXTRA,
+ host='localhost',
+ conn_type='ssh',
+ extra=json.dumps({"private_key": TEST_PRIVATE_KEY, "host_key": TEST_HOST_KEY}),
+ )
+ )
+ db.merge_conn(
+ Connection(
+ conn_id=cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE,
+ host='remote_host',
+ conn_type='ssh',
+ extra=json.dumps(
+ {"private_key": TEST_PRIVATE_KEY, "host_key": TEST_HOST_KEY, "no_host_key_check": False}
+ ),
+ )
+ )
+ db.merge_conn(
+ Connection(
+ conn_id=cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE,
+ host='remote_host',
+ conn_type='ssh',
+ extra=json.dumps(
+ {"private_key": TEST_PRIVATE_KEY, "host_key": TEST_HOST_KEY, "no_host_key_check": True}
+ ),
+ )
+ )
+ db.merge_conn(
+ Connection(
+ conn_id=cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE,
+ host='remote_host',
+ conn_type='ssh',
+ extra=json.dumps({"private_key": TEST_PRIVATE_KEY, "no_host_key_check": False}),
+ )
+ )
@mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
def test_ssh_connection_with_password(self, ssh_mock):
@@ -344,3 +398,42 @@ class TestSSHHook(unittest.TestCase):
sock=None,
look_for_keys=True,
)
+
+ @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
+ def test_ssh_connection_with_host_key_extra(self, ssh_client):
+ hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_EXTRA)
+ assert hook.host_key is None # Since default no_host_key_check = True unless explicit override
+ with hook.get_conn():
+ assert ssh_client.return_value.connect.called is True
+ assert ssh_client.return_value.get_host_keys.return_value.add.called is False
+
+ @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
+ def test_ssh_connection_with_host_key_where_no_host_key_check_is_true(self, ssh_client):
+ hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE)
+ assert hook.host_key is None
+ with hook.get_conn():
+ assert ssh_client.return_value.connect.called is True
+ assert ssh_client.return_value.get_host_keys.return_value.add.called is False
+
+ @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
+ def test_ssh_connection_with_host_key_where_no_host_key_check_is_false(self, ssh_client):
+ hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE)
+ assert hook.host_key.get_base64() == TEST_HOST_KEY
+ with hook.get_conn():
+ assert ssh_client.return_value.connect.called is True
+ assert ssh_client.return_value.get_host_keys.return_value.add.called is True
+ assert ssh_client.return_value.get_host_keys.return_value.add.call_args == mock.call(
+ hook.remote_host, 'ssh-rsa', hook.host_key
+ )
+
+ @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
+ def test_ssh_connection_with_no_host_key_where_no_host_key_check_is_false(self, ssh_client):
+ hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE)
+ assert hook.host_key is None
+ with hook.get_conn():
+ assert ssh_client.return_value.connect.called is True
+ assert ssh_client.return_value.get_host_keys.return_value.add.called is False
+
+
+if __name__ == '__main__':
+ unittest.main()