You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/08/15 18:45:06 UTC

[airflow] 30/45: Allow wildcarded CORS origins (#25553)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-3-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 6821fe12f880696d5219057278b6d2a6c425bd86
Author: Mark Norman Francis <no...@201created.com>
AuthorDate: Fri Aug 5 18:41:05 2022 +0100

    Allow wildcarded CORS origins (#25553)
    
    '*' is a valid 'Access-Control-Allow-Origin' response, but was being
    dropped as it failed to match the Origin header sent in requests.
    
    (cherry picked from commit e81b27e713e9ef6f7104c7038f0c37cc55d96593)
---
 airflow/www/extensions/init_views.py |   8 +-
 tests/api_connexion/test_cors.py     | 140 +++++++++++++++++++++++++++++++++++
 2 files changed, 145 insertions(+), 3 deletions(-)

diff --git a/airflow/www/extensions/init_views.py b/airflow/www/extensions/init_views.py
index 83dbc50eaa..4a2d4a5119 100644
--- a/airflow/www/extensions/init_views.py
+++ b/airflow/www/extensions/init_views.py
@@ -159,11 +159,13 @@ def set_cors_headers_on_response(response):
     allow_headers = conf.get('api', 'access_control_allow_headers')
     allow_methods = conf.get('api', 'access_control_allow_methods')
     allow_origins = conf.get('api', 'access_control_allow_origins')
-    if allow_headers is not None:
+    if allow_headers:
         response.headers['Access-Control-Allow-Headers'] = allow_headers
-    if allow_methods is not None:
+    if allow_methods:
         response.headers['Access-Control-Allow-Methods'] = allow_methods
-    if allow_origins is not None:
+    if allow_origins == '*':
+        response.headers['Access-Control-Allow-Origin'] = '*'
+    elif allow_origins:
         allowed_origins = allow_origins.split(' ')
         origin = request.environ.get('HTTP_ORIGIN', allowed_origins[0])
         if origin in allowed_origins:
diff --git a/tests/api_connexion/test_cors.py b/tests/api_connexion/test_cors.py
new file mode 100644
index 0000000000..30ae19236d
--- /dev/null
+++ b/tests/api_connexion/test_cors.py
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from base64 import b64encode
+
+import pytest
+
+from tests.test_utils.config import conf_vars
+from tests.test_utils.db import clear_db_pools
+
+
+class BaseTestAuth:
+    """Expose the API app as ``self.app`` and ensure Admin user ``test`` (password ``test``) exists."""
+
+    @pytest.fixture(autouse=True)
+    def set_attrs(self, minimal_app_for_api):
+        self.app = minimal_app_for_api
+        security_manager = self.app.appbuilder.sm
+        if security_manager.find_user(username="test"):
+            return  # user already exists; nothing to create
+        security_manager.add_user(
+            username="test",
+            first_name="test",
+            last_name="test",
+            email="test@fab.org",
+            role=security_manager.find_role("Admin"),
+            password="test",
+        )
+
+
+class TestEmptyCors(BaseTestAuth):
+    """Without CORS config overrides, no ``Access-Control-*`` headers are returned."""
+
+    @pytest.fixture(autouse=True, scope="class")
+    def with_basic_auth_backend(self, minimal_app_for_api):
+        """Enable basic auth for this class's tests, then restore the previous ``api_auth``."""
+        from airflow.www.extensions.init_security import init_api_experimental_auth
+        old_auth = minimal_app_for_api.api_auth
+        try:
+            with conf_vars({("api", "auth_backends"): "airflow.api.auth.backend.basic_auth"}):
+                init_api_experimental_auth(minimal_app_for_api)
+                yield
+        finally:
+            minimal_app_for_api.api_auth = old_auth
+
+    def test_empty_cors_headers(self):
+        token = "Basic " + b64encode(b"test:test").decode()
+        clear_db_pools()
+        with self.app.test_client() as test_client:
+            response = test_client.get("/api/v1/pools", headers={"Authorization": token})
+            assert response.status_code == 200
+            assert 'Access-Control-Allow-Headers' not in response.headers
+            assert 'Access-Control-Allow-Methods' not in response.headers
+            assert 'Access-Control-Allow-Origin' not in response.headers
+
+
+class TestCorsOrigin(BaseTestAuth):
+    """A space-separated origin list is matched against the request's Origin header."""
+
+    @pytest.fixture(autouse=True, scope="class")
+    def with_basic_auth_backend(self, minimal_app_for_api):
+        """Enable basic auth plus two allowed CORS origins for this class's tests."""
+        from airflow.www.extensions.init_security import init_api_experimental_auth
+        old_auth = minimal_app_for_api.api_auth
+        try:
+            with conf_vars(
+                {
+                    ("api", "auth_backends"): "airflow.api.auth.backend.basic_auth",
+                    ("api", "access_control_allow_origins"): "http://apache.org http://example.com",
+                }
+            ):
+                init_api_experimental_auth(minimal_app_for_api)
+                yield
+        finally:
+            minimal_app_for_api.api_auth = old_auth
+
+    def test_cors_origin_reflection(self):
+        token = "Basic " + b64encode(b"test:test").decode()
+        clear_db_pools()
+        with self.app.test_client() as test_client:
+            # Without an Origin header the first configured origin is returned.
+            response = test_client.get("/api/v1/pools", headers={"Authorization": token})
+            assert response.status_code == 200
+            assert response.headers['Access-Control-Allow-Origin'] == 'http://apache.org'
+            # An allowed Origin is echoed back unchanged.
+            response = test_client.get(
+                "/api/v1/pools", headers={"Authorization": token, "Origin": "http://apache.org"}
+            )
+            assert response.status_code == 200
+            assert response.headers['Access-Control-Allow-Origin'] == 'http://apache.org'
+            response = test_client.get(
+                "/api/v1/pools", headers={"Authorization": token, "Origin": "http://example.com"}
+            )
+            assert response.status_code == 200
+            assert response.headers['Access-Control-Allow-Origin'] == 'http://example.com'
+
+
+class TestCorsWildcard(BaseTestAuth):
+    """With ``access_control_allow_origins`` set to ``*``, the header is always ``*``."""
+
+    @pytest.fixture(autouse=True, scope="class")
+    def with_basic_auth_backend(self, minimal_app_for_api):
+        """Enable basic auth and a wildcard CORS origin for this class's tests."""
+        from airflow.www.extensions.init_security import init_api_experimental_auth
+        old_auth = minimal_app_for_api.api_auth
+        try:
+            with conf_vars(
+                {
+                    ("api", "auth_backends"): "airflow.api.auth.backend.basic_auth",
+                    ("api", "access_control_allow_origins"): "*",
+                }
+            ):
+                init_api_experimental_auth(minimal_app_for_api)
+                yield
+        finally:
+            minimal_app_for_api.api_auth = old_auth
+
+    def test_cors_origin_wildcard(self):
+        token = "Basic " + b64encode(b"test:test").decode()
+        clear_db_pools()
+        with self.app.test_client() as test_client:
+            response = test_client.get(
+                "/api/v1/pools", headers={"Authorization": token, "Origin": "http://example.com"}
+            )
+            assert response.status_code == 200
+            assert response.headers['Access-Control-Allow-Origin'] == '*'