You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/08/15 18:45:06 UTC
[airflow] 30/45: Allow wildcarded CORS origins (#25553)
This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch v2-3-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 6821fe12f880696d5219057278b6d2a6c425bd86
Author: Mark Norman Francis <no...@201created.com>
AuthorDate: Fri Aug 5 18:41:05 2022 +0100
Allow wildcarded CORS origins (#25553)
'*' is a valid 'Access-Control-Allow-Origin' response value, but it was being
dropped because a literal '*' never matches the Origin header sent in requests.
Empty CORS settings now suppress the corresponding headers entirely.
(cherry picked from commit e81b27e713e9ef6f7104c7038f0c37cc55d96593)
---
airflow/www/extensions/init_views.py | 8 +-
tests/api_connexion/test_cors.py | 140 +++++++++++++++++++++++++++++++++++
2 files changed, 145 insertions(+), 3 deletions(-)
diff --git a/airflow/www/extensions/init_views.py b/airflow/www/extensions/init_views.py
index 83dbc50eaa..4a2d4a5119 100644
--- a/airflow/www/extensions/init_views.py
+++ b/airflow/www/extensions/init_views.py
@@ -159,11 +159,13 @@ def set_cors_headers_on_response(response):
allow_headers = conf.get('api', 'access_control_allow_headers')
allow_methods = conf.get('api', 'access_control_allow_methods')
allow_origins = conf.get('api', 'access_control_allow_origins')
- if allow_headers is not None:
+ if allow_headers:
response.headers['Access-Control-Allow-Headers'] = allow_headers
- if allow_methods is not None:
+ if allow_methods:
response.headers['Access-Control-Allow-Methods'] = allow_methods
- if allow_origins is not None:
+ if allow_origins == '*':
+ response.headers['Access-Control-Allow-Origin'] = '*'
+ elif allow_origins:
allowed_origins = allow_origins.split(' ')
origin = request.environ.get('HTTP_ORIGIN', allowed_origins[0])
if origin in allowed_origins:
diff --git a/tests/api_connexion/test_cors.py b/tests/api_connexion/test_cors.py
new file mode 100644
index 0000000000..30ae19236d
--- /dev/null
+++ b/tests/api_connexion/test_cors.py
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from base64 import b64encode
+
+import pytest
+
+from tests.test_utils.config import conf_vars
+from tests.test_utils.db import clear_db_pools
+
+
+class BaseTestAuth:
+ @pytest.fixture(autouse=True)
+ def set_attrs(self, minimal_app_for_api):
+ self.app = minimal_app_for_api
+
+ sm = self.app.appbuilder.sm
+ tester = sm.find_user(username="test")
+ if not tester:
+ role_admin = sm.find_role("Admin")
+ sm.add_user(
+ username="test",
+ first_name="test",
+ last_name="test",
+ email="test@fab.org",
+ role=role_admin,
+ password="test",
+ )
+
+
+class TestEmptyCors(BaseTestAuth):
+ @pytest.fixture(autouse=True, scope="class")
+ def with_basic_auth_backend(self, minimal_app_for_api):
+ from airflow.www.extensions.init_security import init_api_experimental_auth
+
+ old_auth = getattr(minimal_app_for_api, 'api_auth')
+
+ try:
+ with conf_vars({("api", "auth_backends"): "airflow.api.auth.backend.basic_auth"}):
+ init_api_experimental_auth(minimal_app_for_api)
+ yield
+ finally:
+ setattr(minimal_app_for_api, 'api_auth', old_auth)
+
+ def test_empty_cors_headers(self):
+ token = "Basic " + b64encode(b"test:test").decode()
+ clear_db_pools()
+
+ with self.app.test_client() as test_client:
+ response = test_client.get("/api/v1/pools", headers={"Authorization": token})
+ assert response.status_code == 200
+ assert 'Access-Control-Allow-Headers' not in response.headers
+ assert 'Access-Control-Allow-Methods' not in response.headers
+ assert 'Access-Control-Allow-Origin' not in response.headers
+
+
+class TestCorsOrigin(BaseTestAuth):
+ @pytest.fixture(autouse=True, scope="class")
+ def with_basic_auth_backend(self, minimal_app_for_api):
+ from airflow.www.extensions.init_security import init_api_experimental_auth
+
+ old_auth = getattr(minimal_app_for_api, 'api_auth')
+
+ try:
+ with conf_vars(
+ {
+ ("api", "auth_backends"): "airflow.api.auth.backend.basic_auth",
+ ("api", "access_control_allow_origins"): "http://apache.org http://example.com",
+ }
+ ):
+ init_api_experimental_auth(minimal_app_for_api)
+ yield
+ finally:
+ setattr(minimal_app_for_api, 'api_auth', old_auth)
+
+ def test_cors_origin_reflection(self):
+ token = "Basic " + b64encode(b"test:test").decode()
+ clear_db_pools()
+
+ with self.app.test_client() as test_client:
+ response = test_client.get("/api/v1/pools", headers={"Authorization": token})
+ assert response.status_code == 200
+ assert response.headers['Access-Control-Allow-Origin'] == 'http://apache.org'
+
+ response = test_client.get(
+ "/api/v1/pools", headers={"Authorization": token, "Origin": "http://apache.org"}
+ )
+ assert response.status_code == 200
+ assert response.headers['Access-Control-Allow-Origin'] == 'http://apache.org'
+
+ response = test_client.get(
+ "/api/v1/pools", headers={"Authorization": token, "Origin": "http://example.com"}
+ )
+ assert response.status_code == 200
+ assert response.headers['Access-Control-Allow-Origin'] == 'http://example.com'
+
+
+class TestCorsWildcard(BaseTestAuth):
+ @pytest.fixture(autouse=True, scope="class")
+ def with_basic_auth_backend(self, minimal_app_for_api):
+ from airflow.www.extensions.init_security import init_api_experimental_auth
+
+ old_auth = getattr(minimal_app_for_api, 'api_auth')
+
+ try:
+ with conf_vars(
+ {
+ ("api", "auth_backends"): "airflow.api.auth.backend.basic_auth",
+ ("api", "access_control_allow_origins"): "*",
+ }
+ ):
+ init_api_experimental_auth(minimal_app_for_api)
+ yield
+ finally:
+ setattr(minimal_app_for_api, 'api_auth', old_auth)
+
+ def test_cors_origin_reflection(self):
+ token = "Basic " + b64encode(b"test:test").decode()
+ clear_db_pools()
+
+ with self.app.test_client() as test_client:
+ response = test_client.get(
+ "/api/v1/pools", headers={"Authorization": token, "Origin": "http://example.com"}
+ )
+ assert response.status_code == 200
+ assert response.headers['Access-Control-Allow-Origin'] == '*'