This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-3-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 6821fe12f880696d5219057278b6d2a6c425bd86 Author: Mark Norman Francis <[email protected]> AuthorDate: Fri Aug 5 18:41:05 2022 +0100 Allow wildcarded CORS origins (#25553) '*' is a valid 'Access-Control-Allow-Origin' response, but was being dropped as it failed to match the Origin header sent in requests. (cherry picked from commit e81b27e713e9ef6f7104c7038f0c37cc55d96593) --- airflow/www/extensions/init_views.py | 8 +- tests/api_connexion/test_cors.py | 140 +++++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+), 3 deletions(-) diff --git a/airflow/www/extensions/init_views.py b/airflow/www/extensions/init_views.py index 83dbc50eaa..4a2d4a5119 100644 --- a/airflow/www/extensions/init_views.py +++ b/airflow/www/extensions/init_views.py @@ -159,11 +159,13 @@ def set_cors_headers_on_response(response): allow_headers = conf.get('api', 'access_control_allow_headers') allow_methods = conf.get('api', 'access_control_allow_methods') allow_origins = conf.get('api', 'access_control_allow_origins') - if allow_headers is not None: + if allow_headers: response.headers['Access-Control-Allow-Headers'] = allow_headers - if allow_methods is not None: + if allow_methods: response.headers['Access-Control-Allow-Methods'] = allow_methods - if allow_origins is not None: + if allow_origins == '*': + response.headers['Access-Control-Allow-Origin'] = '*' + elif allow_origins: allowed_origins = allow_origins.split(' ') origin = request.environ.get('HTTP_ORIGIN', allowed_origins[0]) if origin in allowed_origins: diff --git a/tests/api_connexion/test_cors.py b/tests/api_connexion/test_cors.py new file mode 100644 index 0000000000..30ae19236d --- /dev/null +++ b/tests/api_connexion/test_cors.py @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from base64 import b64encode + +import pytest + +from tests.test_utils.config import conf_vars +from tests.test_utils.db import clear_db_pools + + +class BaseTestAuth: + @pytest.fixture(autouse=True) + def set_attrs(self, minimal_app_for_api): + self.app = minimal_app_for_api + + sm = self.app.appbuilder.sm + tester = sm.find_user(username="test") + if not tester: + role_admin = sm.find_role("Admin") + sm.add_user( + username="test", + first_name="test", + last_name="test", + email="[email protected]", + role=role_admin, + password="test", + ) + + +class TestEmptyCors(BaseTestAuth): + @pytest.fixture(autouse=True, scope="class") + def with_basic_auth_backend(self, minimal_app_for_api): + from airflow.www.extensions.init_security import init_api_experimental_auth + + old_auth = getattr(minimal_app_for_api, 'api_auth') + + try: + with conf_vars({("api", "auth_backends"): "airflow.api.auth.backend.basic_auth"}): + init_api_experimental_auth(minimal_app_for_api) + yield + finally: + setattr(minimal_app_for_api, 'api_auth', old_auth) + + def test_empty_cors_headers(self): + token = "Basic " + b64encode(b"test:test").decode() + clear_db_pools() + + with self.app.test_client() as test_client: + response = test_client.get("/api/v1/pools", headers={"Authorization": token}) + assert response.status_code == 200 + assert 'Access-Control-Allow-Headers' not in response.headers + assert 'Access-Control-Allow-Methods' not in response.headers + assert 'Access-Control-Allow-Origin' not in response.headers + + +class TestCorsOrigin(BaseTestAuth): + @pytest.fixture(autouse=True, scope="class") + def with_basic_auth_backend(self, minimal_app_for_api): + from airflow.www.extensions.init_security import init_api_experimental_auth + + old_auth = getattr(minimal_app_for_api, 'api_auth') + + try: + with conf_vars( + { + ("api", "auth_backends"): "airflow.api.auth.backend.basic_auth", + ("api", "access_control_allow_origins"): "http://apache.org http://example.com", + } + ): + init_api_experimental_auth(minimal_app_for_api) + yield + finally: + setattr(minimal_app_for_api, 'api_auth', old_auth) + + def test_cors_origin_reflection(self): + token = "Basic " + b64encode(b"test:test").decode() + clear_db_pools() + + with self.app.test_client() as test_client: + response = test_client.get("/api/v1/pools", headers={"Authorization": token}) + assert response.status_code == 200 + assert response.headers['Access-Control-Allow-Origin'] == 'http://apache.org' + + response = test_client.get( + "/api/v1/pools", headers={"Authorization": token, "Origin": "http://apache.org"} + ) + assert response.status_code == 200 + assert response.headers['Access-Control-Allow-Origin'] == 'http://apache.org' + + response = test_client.get( + "/api/v1/pools", headers={"Authorization": token, "Origin": "http://example.com"} + ) + assert response.status_code == 200 + assert response.headers['Access-Control-Allow-Origin'] == 'http://example.com' + + +class TestCorsWildcard(BaseTestAuth): + @pytest.fixture(autouse=True, scope="class") + def with_basic_auth_backend(self, minimal_app_for_api): + from airflow.www.extensions.init_security import init_api_experimental_auth + + old_auth = getattr(minimal_app_for_api, 'api_auth') + + try: + with conf_vars( + { + ("api", "auth_backends"): "airflow.api.auth.backend.basic_auth", + ("api", "access_control_allow_origins"): "*", + } + ): + init_api_experimental_auth(minimal_app_for_api) + yield + finally: + setattr(minimal_app_for_api, 'api_auth', old_auth) + + def test_cors_origin_reflection(self): + token = "Basic " + b64encode(b"test:test").decode() + clear_db_pools() + + with self.app.test_client() as test_client: + response = test_client.get( + "/api/v1/pools", headers={"Authorization": token, "Origin": "http://example.com"} + ) + assert response.status_code == 200 + assert response.headers['Access-Control-Allow-Origin'] == '*'
