You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/06/18 21:09:09 UTC

[airflow] branch main updated: Get rid of TimedJSONWebSignatureSerializer (#24519)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1f8e4c9553 Get rid of TimedJSONWebSignatureSerializer (#24519)
1f8e4c9553 is described below

commit 1f8e4c955394b31462956501d9a6741b98892f3a
Author: Jarek Potiuk <ja...@polidea.com>
AuthorDate: Sat Jun 18 23:08:51 2022 +0200

    Get rid of TimedJSONWebSignatureSerializer (#24519)
    
    The TimedJSONWebSignatureSerializer has been deprecated in the
    itsdangerous library, and its maintainers recommend using a
    dedicated library for it.
    
    https://github.com/pallets/itsdangerous/issues/129
    
    Since we are going to move to FAB 4+ with #22397, where a newer version
    of itsdangerous is used, we need to switch to another library.
    
    We are already using PyJWT so the choice is obvious.
    
    In addition to the switch, the following improvements were made:
    
    * the use of JWT claims has been fixed to follow the JWT standard.
      We were using the "iat" header wrongly. The JWT specification only
      expects it to be present and to be a valid UTC timestamp; the
      claim does not affect the maturity of the signature - the signature
      is still valid even if "iat" is in the future.
      Instead, the "nbf" ("not before") claim should be used to verify that
      the request is not coming from the future. We now require all claims
      to be present in the request.
    
    * rather than using salt/signing_context we switched to standard
      JWT "audience" claim (same end result)
    
    * we now have much better diagnostics on the server side about the
      reason why a request is forbidden - explicit error messages
      and details of the exception are printed in the server logs. This
      is secure: we do not leak the reason to the client, it is only
      available in the server logs, so there is no risk that an
      attacker could use it.
    
    * the JWTSigner is "use-agnostic". We should be able to use the
      same class for any other signatures (Internal API from AIP-44)
      with just different audience
    
    * a short clock skew (5 seconds by default) is allowed, to account
      for systems that have "almost" synchronized time
    
    * more tests added, with proper time freezing, testing both
      expiry and immaturity of the request
    
    This change is not a breaking one because the JWT authentication
    details are not "public API" - but in case someone reverse engineered
    our claims and implemented their own log file retrieval, we
    should mention the change in our changelog - therefore a newsfragment
    is added.
---
 airflow/utils/jwt_signer.py            |  82 ++++++++++++++++++
 airflow/utils/log/file_task_handler.py |  18 ++--
 airflow/utils/serve_logs.py            |  84 +++++++++++++------
 newsfragments/24519.misc.rst           |   1 +
 tests/utils/test_serve_logs.py         | 147 +++++++++++++++++++++++++++------
 5 files changed, 272 insertions(+), 60 deletions(-)

diff --git a/airflow/utils/jwt_signer.py b/airflow/utils/jwt_signer.py
new file mode 100644
index 0000000000..941a3d0598
--- /dev/null
+++ b/airflow/utils/jwt_signer.py
@@ -0,0 +1,82 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from datetime import datetime, timedelta
+from typing import Any, Dict
+
+import jwt
+
+
+class JWTSigner:
+    """
+    Signs and verifies JWT Token. Used to authorise and verify requests.
+
+    :param secret_key: key used to sign the request
+    :param expiration_time_in_seconds: time after which the token becomes invalid (in seconds)
+    :param audience: audience that the request is expected to have
+    :param leeway_in_seconds: leeway that allows for a small clock skew between the two parties
+    :param algorithm: algorithm used for signing
+    """
+
+    def __init__(
+        self,
+        secret_key: str,
+        expiration_time_in_seconds: int,
+        audience: str,
+        leeway_in_seconds: int = 5,
+        algorithm: str = "HS512",
+    ):
+        self._secret_key = secret_key
+        self._expiration_time_in_seconds = expiration_time_in_seconds
+        self._audience = audience
+        self._leeway_in_seconds = leeway_in_seconds
+        self._algorithm = algorithm
+
+    def generate_signed_token(self, extra_payload: Dict[str, Any]) -> str:
+        """
+        Generate JWT with extra payload added.
+        :param extra_payload: extra payload that is added to the signed token
+        :return: signed token
+        """
+        jwt_dict = {
+            "aud": self._audience,
+            "iat": datetime.utcnow(),
+            "nbf": datetime.utcnow(),
+            "exp": datetime.utcnow() + timedelta(seconds=self._expiration_time_in_seconds),
+        }
+        jwt_dict.update(extra_payload)
+        token = jwt.encode(
+            jwt_dict,
+            self._secret_key,
+            algorithm=self._algorithm,
+        )
+        return token
+
+    def verify_token(self, token: str) -> Dict[str, Any]:
+        payload = jwt.decode(
+            token,
+            self._secret_key,
+            leeway=timedelta(seconds=self._leeway_in_seconds),
+            algorithms=[self._algorithm],
+            options={
+                "verify_signature": True,
+                "require_exp": True,
+                "require_iat": True,
+                "require_nbf": True,
+            },
+            audience=self._audience,
+        )
+        return payload
diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index db34ea5f6b..2c53529a72 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -23,11 +23,10 @@ from datetime import datetime
 from pathlib import Path
 from typing import TYPE_CHECKING, Optional, Tuple
 
-from itsdangerous import TimedJSONWebSignatureSerializer
-
 from airflow.configuration import AirflowConfigException, conf
 from airflow.utils.context import Context
 from airflow.utils.helpers import parse_template_string, render_template_to_string
+from airflow.utils.jwt_signer import JWTSigner
 from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler
 from airflow.utils.session import create_session
 
@@ -201,16 +200,17 @@ class FileTaskHandler(logging.Handler):
                 except (AirflowConfigException, ValueError):
                     pass
 
-                signer = TimedJSONWebSignatureSerializer(
+                signer = JWTSigner(
                     secret_key=conf.get('webserver', 'secret_key'),
-                    algorithm_name='HS512',
-                    expires_in=conf.getint('webserver', 'log_request_clock_grace', fallback=30),
-                    # This isn't really a "salt", more of a signing context
-                    salt='task-instance-logs',
+                    expiration_time_in_seconds=conf.getint(
+                        'webserver', 'log_request_clock_grace', fallback=30
+                    ),
+                    audience="task-instance-logs",
                 )
-
                 response = httpx.get(
-                    url, timeout=timeout, headers={'Authorization': signer.dumps(log_relative_path)}
+                    url,
+                    timeout=timeout,
+                    headers={b'Authorization': signer.generate_signed_token({"filename": log_relative_path})},
                 )
                 response.encoding = "utf-8"
 
diff --git a/airflow/utils/serve_logs.py b/airflow/utils/serve_logs.py
index 50fdb47a02..e14162178b 100644
--- a/airflow/utils/serve_logs.py
+++ b/airflow/utils/serve_logs.py
@@ -16,55 +16,89 @@
 # under the License.
 
 """Serve logs process"""
+import logging
 import os
-import time
 
 import gunicorn.app.base
 from flask import Flask, abort, request, send_from_directory
-from itsdangerous import TimedJSONWebSignatureSerializer
+from jwt.exceptions import (
+    ExpiredSignatureError,
+    ImmatureSignatureError,
+    InvalidAudienceError,
+    InvalidIssuedAtError,
+    InvalidSignatureError,
+)
 from setproctitle import setproctitle
 
 from airflow.configuration import conf
+from airflow.utils.docs import get_docs_url
+from airflow.utils.jwt_signer import JWTSigner
+
+logger = logging.getLogger(__name__)
 
 
 def create_app():
     flask_app = Flask(__name__, static_folder=None)
-    max_request_age = conf.getint('webserver', 'log_request_clock_grace', fallback=30)
+    expiration_time_in_seconds = conf.getint('webserver', 'log_request_clock_grace', fallback=30)
     log_directory = os.path.expanduser(conf.get('logging', 'BASE_LOG_FOLDER'))
 
-    signer = TimedJSONWebSignatureSerializer(
+    signer = JWTSigner(
         secret_key=conf.get('webserver', 'secret_key'),
-        algorithm_name='HS512',
-        expires_in=max_request_age,
-        # This isn't really a "salt", more of a signing context
-        salt='task-instance-logs',
+        expiration_time_in_seconds=expiration_time_in_seconds,
+        audience="task-instance-logs",
     )
 
     # Prevent direct access to the logs port
     @flask_app.before_request
     def validate_pre_signed_url():
         try:
-            auth = request.headers['Authorization']
-
-            # We don't actually care about the payload, just that the signature
-            # was valid and the `exp` claim is correct
-            filename, headers = signer.loads(auth, return_header=True)
-
-            issued_at = int(headers['iat'])
-            expires_at = int(headers['exp'])
-        except Exception:
+            auth = request.headers.get('Authorization')
+            if auth is None:
+                logger.warning("The Authorization header is missing: %s.", request.headers)
+                abort(403)
+            payload = signer.verify_token(auth)
+            token_filename = payload.get("filename")
+            request_filename = request.view_args['filename']
+            if token_filename is None:
+                logger.warning("The payload does not contain 'filename' key: %s.", payload)
+                abort(403)
+            if token_filename != request_filename:
+                logger.warning(
+                    "The payload log_relative_path key is different than the one in token:"
+                    "Request path: %s. Token path: %s.",
+                    request_filename,
+                    token_filename,
+                )
+                abort(403)
+        except InvalidAudienceError:
+            logger.warning("Invalid audience for the request", exc_info=True)
             abort(403)
-
-        if filename != request.view_args['filename']:
+        except InvalidSignatureError:
+            logger.warning("The signature of the request was wrong", exc_info=True)
             abort(403)
-
-        # Validate the `iat` and `exp` are within `max_request_age` of now.
-        now = int(time.time())
-        if abs(now - issued_at) > max_request_age:
+        except ImmatureSignatureError:
+            logger.warning("The signature of the request was sent from the future", exc_info=True)
             abort(403)
-        if abs(now - expires_at) > max_request_age:
+        except ExpiredSignatureError:
+            logger.warning(
+                "The signature of the request has expired. Make sure that all components "
+                "in your system have synchronized clocks. "
+                "See more at %s",
+                get_docs_url("configurations-ref.html#secret-key"),
+                exc_info=True,
+            )
             abort(403)
-        if issued_at > expires_at or expires_at - issued_at > max_request_age:
+        except InvalidIssuedAtError:
+            logger.warning(
+                "The request was issues in the future. Make sure that all components "
+                "in your system have synchronized clocks. "
+                "See more at %s",
+                get_docs_url("configurations-ref.html#secret-key"),
+                exc_info=True,
+            )
+            abort(403)
+        except Exception:
+            logger.warning("Unknown error", exc_info=True)
             abort(403)
 
     @flask_app.route('/log/<path:filename>')
diff --git a/newsfragments/24519.misc.rst b/newsfragments/24519.misc.rst
new file mode 100644
index 0000000000..799d9141d2
--- /dev/null
+++ b/newsfragments/24519.misc.rst
@@ -0,0 +1 @@
+The JWT claims in the request to retrieve logs have been standardized: we use "nbf" and "aud" claims for maturity and audience of the requests. Also "filename" payload field is used to keep log name.
diff --git a/tests/utils/test_serve_logs.py b/tests/utils/test_serve_logs.py
index 168a43a012..f8d3881759 100644
--- a/tests/utils/test_serve_logs.py
+++ b/tests/utils/test_serve_logs.py
@@ -14,12 +14,15 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import datetime
 from typing import TYPE_CHECKING
 
+import jwt
 import pytest
-from itsdangerous import TimedJSONWebSignatureSerializer
+from freezegun import freeze_time
 
 from airflow.configuration import conf
+from airflow.utils.jwt_signer import JWTSigner
 from airflow.utils.serve_logs import create_app
 from tests.test_utils.config import conf_vars
 
@@ -47,12 +50,19 @@ def sample_log(tmpdir):
 
 @pytest.fixture
 def signer():
-    return TimedJSONWebSignatureSerializer(
+    return JWTSigner(
         secret_key=conf.get('webserver', 'secret_key'),
-        algorithm_name='HS512',
-        expires_in=30,
-        # This isn't really a "salt", more of a signing context
-        salt='task-instance-logs',
+        expiration_time_in_seconds=30,
+        audience="task-instance-logs",
+    )
+
+
+@pytest.fixture
+def different_audience():
+    return JWTSigner(
+        secret_key=conf.get('webserver', 'secret_key'),
+        expiration_time_in_seconds=30,
+        audience="different-audience",
     )
 
 
@@ -62,49 +72,134 @@ class TestServeLogs:
         assert 403 == client.get('/log/sample.log').status_code
 
     def test_should_serve_file(self, client: "FlaskClient", signer):
+        response = client.get(
+            '/log/sample.log',
+            headers={
+                'Authorization': signer.generate_signed_token({"filename": 'sample.log'}),
+            },
+        )
+        assert response.data.decode() == LOG_DATA
+        assert response.status_code == 200
+
+    def test_forbidden_different_logname(self, client: "FlaskClient", signer):
+        response = client.get(
+            '/log/sample.log',
+            headers={
+                'Authorization': signer.generate_signed_token({"filename": 'different.log'}),
+            },
+        )
+        assert response.status_code == 403
+
+    def test_forbidden_expired(self, client: "FlaskClient", signer):
+        with freeze_time("2010-01-14"):
+            token = signer.generate_signed_token({"filename": 'sample.log'})
+        assert (
+            client.get(
+                '/log/sample.log',
+                headers={
+                    'Authorization': token,
+                },
+            ).status_code
+            == 403
+        )
+
+    def test_forbidden_future(self, client: "FlaskClient", signer):
+        with freeze_time(datetime.datetime.utcnow() + datetime.timedelta(seconds=3600)):
+            token = signer.generate_signed_token({"filename": 'sample.log'})
         assert (
-            LOG_DATA
-            == client.get(
+            client.get(
                 '/log/sample.log',
                 headers={
-                    'Authorization': signer.dumps('sample.log'),
+                    'Authorization': token,
                 },
-            ).data.decode()
+            ).status_code
+            == 403
         )
 
-    def test_forbidden_too_long_validity(self, client: "FlaskClient", signer):
-        signer.expires_in = 3600
+    def test_ok_with_short_future_skew(self, client: "FlaskClient", signer):
+        with freeze_time(datetime.datetime.utcnow() + datetime.timedelta(seconds=1)):
+            token = signer.generate_signed_token({"filename": 'sample.log'})
         assert (
-            403
-            == client.get(
+            client.get(
                 '/log/sample.log',
                 headers={
-                    'Authorization': signer.dumps('sample.log'),
+                    'Authorization': token,
                 },
             ).status_code
+            == 200
         )
 
-    def test_forbidden_expired(self, client: "FlaskClient", signer):
-        # Fake the time we think we are
-        signer.now = lambda: 0
+    def test_ok_with_short_past_skew(self, client: "FlaskClient", signer):
+        with freeze_time(datetime.datetime.utcnow() - datetime.timedelta(seconds=31)):
+            token = signer.generate_signed_token({"filename": 'sample.log'})
+        assert (
+            client.get(
+                '/log/sample.log',
+                headers={
+                    'Authorization': token,
+                },
+            ).status_code
+            == 200
+        )
+
+    def test_forbidden_with_long_future_skew(self, client: "FlaskClient", signer):
+        with freeze_time(datetime.datetime.utcnow() + datetime.timedelta(seconds=10)):
+            token = signer.generate_signed_token({"filename": 'sample.log'})
+        assert (
+            client.get(
+                '/log/sample.log',
+                headers={
+                    'Authorization': token,
+                },
+            ).status_code
+            == 403
+        )
+
+    def test_forbidden_with_long_past_skew(self, client: "FlaskClient", signer):
+        with freeze_time(datetime.datetime.utcnow() - datetime.timedelta(seconds=40)):
+            token = signer.generate_signed_token({"filename": 'sample.log'})
+        assert (
+            client.get(
+                '/log/sample.log',
+                headers={
+                    'Authorization': token,
+                },
+            ).status_code
+            == 403
+        )
+
+    def test_wrong_audience(self, client: "FlaskClient", different_audience):
         assert (
-            403
-            == client.get(
+            client.get(
                 '/log/sample.log',
                 headers={
-                    'Authorization': signer.dumps('sample.log'),
+                    'Authorization': different_audience.generate_signed_token({"filename": 'sample.log'}),
                 },
             ).status_code
+            == 403
         )
 
-    def test_wrong_context(self, client: "FlaskClient", signer):
-        signer.salt = None
+    @pytest.mark.parametrize("claim_to_remove", ["iat", "exp", "nbf", "aud"])
+    def test_missing_claims(self, claim_to_remove: str, client: "FlaskClient"):
+        jwt_dict = {
+            "aud": "task-instance-logs",
+            "iat": datetime.datetime.utcnow(),
+            "nbf": datetime.datetime.utcnow(),
+            "exp": datetime.datetime.utcnow() + datetime.timedelta(seconds=30),
+        }
+        del jwt_dict[claim_to_remove]
+        jwt_dict.update({"filename": 'sample.log'})
+        token = jwt.encode(
+            jwt_dict,
+            conf.get('webserver', 'secret_key'),
+            algorithm="HS512",
+        )
         assert (
-            403
-            == client.get(
+            client.get(
                 '/log/sample.log',
                 headers={
-                    'Authorization': signer.dumps('sample.log'),
+                    'Authorization': token,
                 },
             ).status_code
+            == 403
         )