This is an automated email from the ASF dual-hosted git repository.
taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7b60825046 check whether AUTH_ROLE_PUBLIC is set in check_authentication (#38924)
7b60825046 is described below
commit 7b608250468740954c6b0af7a5f7f23dfa52b473
Author: Wei Lee <[email protected]>
AuthorDate: Sun Apr 14 18:06:43 2024 +0800
check whether AUTH_ROLE_PUBLIC is set in check_authentication (#38924)
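* fix(security): check whether AUTH_ROLE_PUBLIC is set in check_authentication. Previously, an anonymous request was always rejected with 401, even when a public role was configured that could grant it access; now such a request passes authentication, and authorization is left to the public role's permissions.
* test(api_connexion): ensure AUTH_ROLE_PUBLIC is not set in minimal_app_for_api
* test(endpoints): add a test case to each endpoint covering AUTH_ROLE_PUBLIC set to "Public" (expects 403) and to "Admin" (expects success)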
---
airflow/api_connexion/security.py | 6 +
tests/api_connexion/conftest.py | 15 +-
.../endpoints/test_config_endpoint.py | 22 +++
.../endpoints/test_connection_endpoint.py | 89 ++++++++++
tests/api_connexion/endpoints/test_dag_endpoint.py | 99 +++++++++++
.../endpoints/test_dag_run_endpoint.py | 185 ++++++++++++++++++++
.../endpoints/test_dag_source_endpoint.py | 16 ++
.../endpoints/test_dag_warning_endpoint.py | 12 ++
.../endpoints/test_dataset_endpoint.py | 186 ++++++++++++++++++++-
.../endpoints/test_event_log_endpoint.py | 44 +++++
10 files changed, 672 insertions(+), 2 deletions(-)
diff --git a/airflow/api_connexion/security.py
b/airflow/api_connexion/security.py
index 1cc044d9dd..660bc6cce2 100644
--- a/airflow/api_connexion/security.py
+++ b/airflow/api_connexion/security.py
@@ -49,6 +49,12 @@ def check_authentication() -> None:
response = auth.requires_authentication(Response)()
if response.status_code == 200:
return
+
+ # Even if the current_user is anonymous, the AUTH_ROLE_PUBLIC might still have permission.
+ appbuilder = get_airflow_app().appbuilder
+ if appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC", None):
+ return
+
# since this handler only checks authentication, not authorization,
# we should always return 401
raise Unauthenticated(headers=response.headers)
diff --git a/tests/api_connexion/conftest.py b/tests/api_connexion/conftest.py
index c860a78f27..481f07fe73 100644
--- a/tests/api_connexion/conftest.py
+++ b/tests/api_connexion/conftest.py
@@ -40,7 +40,9 @@ def minimal_app_for_api():
)
def factory():
with conf_vars({("api", "auth_backends"):
"tests.test_utils.remote_user_api_auth_backend"}):
- return app.create_app(testing=True, config={"WTF_CSRF_ENABLED":
False}) # type:ignore
+ _app = app.create_app(testing=True, config={"WTF_CSRF_ENABLED":
False}) # type:ignore
+ _app.config["AUTH_ROLE_PUBLIC"] = None
+ return _app
return factory()
@@ -67,3 +69,14 @@ def dagbag():
)
DagBag(include_examples=True, read_dags_from_db=False).sync_to_db()
return DagBag(include_examples=True, read_dags_from_db=True)
+
+
[email protected]
+def set_auto_role_public(request):
+ app = request.getfixturevalue("minimal_app_for_api")
+ auto_role_public = app.config["AUTH_ROLE_PUBLIC"]
+ app.config["AUTH_ROLE_PUBLIC"] = request.param
+
+ yield
+
+ app.config["AUTH_ROLE_PUBLIC"] = auto_role_public
diff --git a/tests/api_connexion/endpoints/test_config_endpoint.py
b/tests/api_connexion/endpoints/test_config_endpoint.py
index c091c4ef1c..3dd5814e5d 100644
--- a/tests/api_connexion/endpoints/test_config_endpoint.py
+++ b/tests/api_connexion/endpoints/test_config_endpoint.py
@@ -222,6 +222,16 @@ class TestGetConfig:
assert response.status_code == 403
assert "chose not to expose" in response.json["detail"]
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code):
+ response = self.client.get("/api/v1/config", headers={"Accept":
"application/json"})
+
+ assert response.status_code == expected_status_code
+
class TestGetValue:
@pytest.fixture(autouse=True)
@@ -339,3 +349,15 @@ class TestGetValue:
)
assert response.status_code == 403
assert "chose not to expose" in response.json["detail"]
+
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code):
+ response = self.client.get(
+ "/api/v1/config/section/smtp/option/smtp_mail_from",
headers={"Accept": "application/json"}
+ )
+
+ assert response.status_code == expected_status_code
diff --git a/tests/api_connexion/endpoints/test_connection_endpoint.py
b/tests/api_connexion/endpoints/test_connection_endpoint.py
index dc0f2893e0..c88b8a56de 100644
--- a/tests/api_connexion/endpoints/test_connection_endpoint.py
+++ b/tests/api_connexion/endpoints/test_connection_endpoint.py
@@ -112,6 +112,22 @@ class TestDeleteConnection(TestConnectionEndpoint):
)
assert response.status_code == 403
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 204)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, session):
+ connection_model = Connection(conn_id="test-connection",
conn_type="test_type")
+ session.add(connection_model)
+ session.commit()
+ conn = session.query(Connection).all()
+ assert len(conn) == 1
+
+ response = self.client.delete("/api/v1/connections/test-connection")
+
+ assert response.status_code == expected_status_code
+
class TestGetConnection(TestConnectionEndpoint):
def test_should_respond_200(self, session):
@@ -178,6 +194,31 @@ class TestGetConnection(TestConnectionEndpoint):
assert_401(response)
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, session):
+ connection_model = Connection(
+ conn_id="test-connection-id",
+ conn_type="mysql",
+ description="test description",
+ host="mysql",
+ login="login",
+ schema="testschema",
+ port=80,
+ extra='{"param": "value"}',
+ )
+ session.add(connection_model)
+ session.commit()
+ result = session.query(Connection).all()
+ assert len(result) == 1
+
+ response = self.client.get("/api/v1/connections/test-connection-id")
+
+ assert response.status_code == expected_status_code
+
class TestGetConnections(TestConnectionEndpoint):
def test_should_respond_200(self, session):
@@ -256,6 +297,16 @@ class TestGetConnections(TestConnectionEndpoint):
assert_401(response)
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code):
+ response = self.client.get("/api/v1/connections")
+
+ assert response.status_code == expected_status_code
+
class TestGetConnectionsPagination(TestConnectionEndpoint):
@pytest.mark.parametrize(
@@ -529,6 +580,21 @@ class TestPatchConnection(TestConnectionEndpoint):
assert_401(response)
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, session):
+ self._create_connection(session)
+
+ response = self.client.patch(
+ "/api/v1/connections/test-connection-id",
+ json={"connection_id": "test-connection-id", "conn_type":
"test_type", "extra": '{"key": "var"}'},
+ )
+
+ assert response.status_code == expected_status_code
+
class TestPostConnection(TestConnectionEndpoint):
def test_post_should_respond_200(self, session):
@@ -610,6 +676,18 @@ class TestPostConnection(TestConnectionEndpoint):
assert_401(response)
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code):
+ response = self.client.post(
+ "/api/v1/connections", json={"connection_id":
"test-connection-id", "conn_type": "test_type"}
+ )
+
+ assert response.status_code == expected_status_code
+
class TestConnection(TestConnectionEndpoint):
@mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
@@ -663,3 +741,14 @@ class TestConnection(TestConnectionEndpoint):
"Testing connections is disabled in Airflow configuration. "
"Contact your deployment admin to enable it."
)
+
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code):
+ payload = {"connection_id": "test-connection-id", "conn_type":
"sqlite"}
+ response = self.client.post("/api/v1/connections/test", json=payload)
+ assert response.status_code == expected_status_code
diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py
b/tests/api_connexion/endpoints/test_dag_endpoint.py
index 8578f633cf..b514faba27 100644
--- a/tests/api_connexion/endpoints/test_dag_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_endpoint.py
@@ -317,6 +317,24 @@ class TestGetDag(TestDagEndpoint):
)
assert response.status_code == 400, f"Current code:
{response.status_code}"
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, session):
+ dag_model = DagModel(
+ dag_id="TEST_DAG_1",
+ fileloc="/tmp/dag_1.py",
+ schedule_interval=None,
+ is_paused=False,
+ )
+ session.add(dag_model)
+ session.commit()
+
+ response = self.client.get("/api/v1/dags/TEST_DAG_1")
+ assert response.status_code == expected_status_code
+
class TestGetDagDetails(TestDagEndpoint):
def test_should_respond_200(self, url_safe_serializer):
@@ -728,6 +746,18 @@ class TestGetDagDetails(TestDagEndpoint):
)
assert response.status_code == 400, f"Current code:
{response.status_code}"
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, url_safe_serializer):
+ self._create_dag_model_for_details_endpoint(self.dag_id)
+ url_safe_serializer.dumps("/tmp/dag.py")
+ response = self.client.get(f"/api/v1/dags/{self.dag_id}/details")
+
+ assert response.status_code == expected_status_code
+
class TestGetDags(TestDagEndpoint):
@provide_session
@@ -1259,6 +1289,22 @@ class TestGetDags(TestDagEndpoint):
assert response.status_code == 400, f"Current code:
{response.status_code}"
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, session):
+ self._create_dag_models(2)
+ self._create_deactivated_dag()
+
+ dags_query = session.query(DagModel).filter(~DagModel.is_subdag)
+ assert len(dags_query.all()) == 3
+
+ response = self.client.get("api/v1/dags")
+
+ assert response.status_code == expected_status_code
+
class TestPatchDag(TestDagEndpoint):
def test_should_respond_200_on_patch_is_paused(self, url_safe_serializer,
session):
@@ -1485,6 +1531,24 @@ class TestPatchDag(TestDagEndpoint):
assert response.status_code == 403
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(
+ self, set_auto_role_public, expected_status_code, url_safe_serializer,
session
+ ):
+ url_safe_serializer.dumps("/tmp/dag_1.py")
+ dag_model = self._create_dag_model()
+ payload = {"is_paused": False}
+ response = self.client.patch(
+ f"/api/v1/dags/{dag_model.dag_id}",
+ json=payload,
+ )
+
+ assert response.status_code == expected_status_code
+
class TestPatchDags(TestDagEndpoint):
@provide_session
@@ -2291,6 +2355,29 @@ class TestPatchDags(TestDagEndpoint):
)
assert response.status_code == 400
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(
+ self, set_auto_role_public, expected_status_code, session,
url_safe_serializer
+ ):
+ url_safe_serializer.dumps("/tmp/dag_1.py")
+ url_safe_serializer.dumps("/tmp/dag_2.py")
+ self._create_dag_models(2)
+ self._create_deactivated_dag()
+
+ dags_query = session.query(DagModel).filter(~DagModel.is_subdag)
+ assert len(dags_query.all()) == 3
+
+ response = self.client.patch(
+ "/api/v1/dags?dag_id_pattern=~",
+ json={"is_paused": False},
+ )
+
+ assert response.status_code == expected_status_code
+
class TestDeleteDagEndpoint(TestDagEndpoint):
def test_that_dag_can_be_deleted(self, session):
@@ -2342,3 +2429,15 @@ class TestDeleteDagEndpoint(TestDagEndpoint):
environ_overrides={"REMOTE_USER": "test_no_permissions"},
)
assert response.status_code == 403
+
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 204)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code):
+ self._create_dag_models(1)
+
+ response = self.client.delete("/api/v1/dags/TEST_DAG_1")
+
+ assert response.status_code == expected_status_code
diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py
b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
index f6ace16099..5182ef427e 100644
--- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
@@ -214,6 +214,18 @@ class TestDeleteDagRun(TestDagRunEndpoint):
)
assert response.status_code == 403
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 204)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, session):
+ session.add_all(self._create_test_dag_run())
+ session.commit()
+ response =
self.client.delete("api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1")
+
+ assert response.status_code == expected_status_code
+
class TestGetDagRun(TestDagRunEndpoint):
def test_should_respond_200(self, session):
@@ -333,6 +345,29 @@ class TestGetDagRun(TestDagRunEndpoint):
)
assert response.status_code == 400, f"Current code:
{response.status_code}"
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, session):
+ dagrun_model = DagRun(
+ dag_id="TEST_DAG_ID",
+ run_id="TEST_DAG_RUN_ID",
+ run_type=DagRunType.MANUAL,
+ execution_date=timezone.parse(self.default_time),
+ start_date=timezone.parse(self.default_time),
+ external_trigger=True,
+ state="running",
+ )
+ session.add(dagrun_model)
+ session.commit()
+ result = session.query(DagRun).all()
+ assert len(result) == 1
+
+ response =
self.client.get("api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID")
+ assert response.status_code == expected_status_code
+
class TestGetDagRuns(TestDagRunEndpoint):
def test_should_respond_200(self, session):
@@ -508,6 +543,18 @@ class TestGetDagRuns(TestDagRunEndpoint):
)
assert response.status_code == 400, f"Current code:
{response.status_code}"
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, session):
+ self._create_test_dag_run()
+ result = session.query(DagRun).all()
+ assert len(result) == 2
+ response = self.client.get("api/v1/dags/TEST_DAG_ID/dagRuns")
+ assert response.status_code == expected_status_code
+
class TestGetDagRunsPagination(TestDagRunEndpoint):
@pytest.mark.parametrize(
@@ -931,6 +978,18 @@ class TestGetDagRunBatch(TestDagRunEndpoint):
assert_401(response)
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code):
+ self._create_test_dag_run()
+
+ response = self.client.post("api/v1/dags/~/dagRuns/list",
json={"dag_ids": ["TEST_DAG_ID"]})
+
+ assert response.status_code == expected_status_code
+
class TestGetDagRunBatchPagination(TestDagRunEndpoint):
@pytest.mark.parametrize(
@@ -1564,6 +1623,26 @@ class TestPostDagRun(TestDagRunEndpoint):
)
assert response.status_code == 403
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code):
+ execution_date = "2020-11-10T08:25:56.939143+00:00"
+ logical_date = "2020-11-10T08:25:56.939143+00:00"
+ self._create_dag("TEST_DAG_ID")
+
+ response = self.client.post(
+ "api/v1/dags/TEST_DAG_ID/dagRuns",
+ json={
+ "execution_date": execution_date,
+ "logical_date": logical_date,
+ },
+ )
+
+ assert response.status_code == expected_status_code
+
class TestPatchDagRunState(TestDagRunEndpoint):
@pytest.mark.parametrize("state", ["failed", "success", "queued"])
@@ -1687,6 +1766,31 @@ class TestPatchDagRunState(TestDagRunEndpoint):
)
assert response.status_code == 404
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, dag_maker, session):
+ dag_id = "TEST_DAG_ID"
+ dag_run_id = "TEST_DAG_RUN_ID"
+ with dag_maker(dag_id) as dag:
+ task = EmptyOperator(task_id="task_id", dag=dag)
+ self.app.dag_bag.bag_dag(dag, root_dag=dag)
+ dr = dag_maker.create_dagrun(run_id=dag_run_id,
run_type=DagRunType.SCHEDULED)
+ ti = dr.get_task_instance(task_id="task_id")
+ ti.task = task
+ ti.state = State.RUNNING
+ session.merge(ti)
+ session.commit()
+
+ response = self.client.patch(
+ f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}",
+ json={"state": "failed"},
+ )
+
+ assert response.status_code == expected_status_code
+
class TestClearDagRun(TestDagRunEndpoint):
def test_should_respond_200(self, dag_maker, session):
@@ -1822,6 +1926,31 @@ class TestClearDagRun(TestDagRunEndpoint):
)
assert response.status_code == 404
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, dag_maker, session):
+ dag_id = "TEST_DAG_ID"
+ dag_run_id = "TEST_DAG_RUN_ID"
+ with dag_maker(dag_id) as dag:
+ task = EmptyOperator(task_id="task_id", dag=dag)
+ self.app.dag_bag.bag_dag(dag, root_dag=dag)
+ dr = dag_maker.create_dagrun(run_id=dag_run_id,
run_type=DagRunType.SCHEDULED)
+ ti = dr.get_task_instance(task_id="task_id")
+ ti.task = task
+ ti.state = State.RUNNING
+ session.merge(ti)
+ session.commit()
+
+ response = self.client.patch(
+ f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}",
+ json={"state": "failed"},
+ )
+
+ assert response.status_code == expected_status_code
+
@pytest.mark.need_serialized_dag
class TestGetDagRunDatasetTriggerEvents(TestDagRunEndpoint):
@@ -1916,6 +2045,42 @@ class
TestGetDagRunDatasetTriggerEvents(TestDagRunEndpoint):
assert_401(response)
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, dag_maker, session):
+ dataset1 = Dataset(uri="ds1")
+
+ with dag_maker(dag_id="source_dag", start_date=timezone.utcnow(),
session=session):
+ EmptyOperator(task_id="task", outlets=[dataset1])
+ dr = dag_maker.create_dagrun()
+ ti = dr.task_instances[0]
+
+ ds1_id =
session.query(DatasetModel.id).filter_by(uri=dataset1.uri).scalar()
+ event = DatasetEvent(
+ dataset_id=ds1_id,
+ source_task_id=ti.task_id,
+ source_dag_id=ti.dag_id,
+ source_run_id=ti.run_id,
+ source_map_index=ti.map_index,
+ )
+ session.add(event)
+
+ with dag_maker(dag_id="TEST_DAG_ID", start_date=timezone.utcnow(),
session=session):
+ pass
+ dr = dag_maker.create_dagrun(run_id="TEST_DAG_RUN_ID",
run_type=DagRunType.DATASET_TRIGGERED)
+ dr.consumed_dataset_events.append(event)
+
+ session.commit()
+ assert event.timestamp
+
+ response = self.client.get(
+
"api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/upstreamDatasetEvents",
+ )
+ assert response.status_code == expected_status_code
+
class TestSetDagRunNote(TestDagRunEndpoint):
def test_should_respond_200(self, dag_maker, session):
@@ -2046,3 +2211,23 @@ class TestSetDagRunNote(TestDagRunEndpoint):
json={"note": "I am setting a note with anonymous user"},
)
assert response.status_code == 200
+
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, session):
+ dag_runs: list[DagRun] = self._create_test_dag_run(DagRunState.SUCCESS)
+ session.add_all(dag_runs)
+ session.commit()
+ created_dr: DagRun = dag_runs[0]
+ new_note_value = "My super cool DagRun notes"
+ response = self.client.patch(
+
f"api/v1/dags/{created_dr.dag_id}/dagRuns/{created_dr.run_id}/setNote",
+ json={"note": new_note_value},
+ )
+
+ session.query(DagRun).filter(DagRun.run_id ==
created_dr.run_id).first()
+
+ assert response.status_code == expected_status_code
diff --git a/tests/api_connexion/endpoints/test_dag_source_endpoint.py
b/tests/api_connexion/endpoints/test_dag_source_endpoint.py
index d48d7e1c02..14c7d1534d 100644
--- a/tests/api_connexion/endpoints/test_dag_source_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_source_endpoint.py
@@ -202,3 +202,19 @@ class TestGetSource:
)
assert response.status_code == 403
assert read_dag.status_code == 200
+
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, url_safe_serializer):
+ dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE)
+ dagbag.sync_to_db()
+ test_dag: DAG = dagbag.dags[TEST_DAG_ID]
+ self._get_dag_file_docstring(test_dag.fileloc)
+
+ url =
f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}"
+ response = self.client.get(url, headers={"Accept": "text/plain"})
+
+ assert response.status_code == expected_status_code
diff --git a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py
b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py
index 9310956d24..cc398329b9 100644
--- a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py
@@ -170,3 +170,15 @@ class TestGetDagWarningEndpoint(TestBaseDagWarning):
query_string={"dag_id": "dag1"},
)
assert response.status_code == 403
+
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code):
+ response = self.client.get(
+ "/api/v1/dagWarnings",
+ query_string={"dag_id": "dag1", "warning_type": "non-existent
pool"},
+ )
+ assert response.status_code == expected_status_code
diff --git a/tests/api_connexion/endpoints/test_dataset_endpoint.py
b/tests/api_connexion/endpoints/test_dataset_endpoint.py
index a2451fb30a..5b6e2f2414 100644
--- a/tests/api_connexion/endpoints/test_dataset_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dataset_endpoint.py
@@ -143,6 +143,22 @@ class TestGetDatasetEndpoint(TestDatasetEndpoint):
response =
self.client.get(f"/api/v1/datasets/{urllib.parse.quote('s3://bucket/key',
safe='')}")
assert_401(response)
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, session):
+ self._create_dataset(session)
+ assert session.query(DatasetModel).count() == 1
+
+ with assert_queries_count(5):
+ response = self.client.get(
+ f"/api/v1/datasets/{urllib.parse.quote('s3://bucket/key',
safe='')}",
+ )
+
+ assert response.status_code == expected_status_code
+
class TestGetDatasets(TestDatasetEndpoint):
def test_should_respond_200(self, session):
@@ -313,6 +329,31 @@ class TestGetDatasets(TestDatasetEndpoint):
response_data = response.json
assert len(response_data["datasets"]) == expected_num
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, session):
+ datasets = [
+ DatasetModel(
+ id=i,
+ uri=f"s3://bucket/key/{i}",
+ extra={"foo": "bar"},
+ created_at=timezone.parse(self.default_time),
+ updated_at=timezone.parse(self.default_time),
+ )
+ for i in [1, 2]
+ ]
+ session.add_all(datasets)
+ session.commit()
+ assert session.query(DatasetModel).count() == 2
+
+ with assert_queries_count(8):
+ response = self.client.get("/api/v1/datasets")
+
+ assert response.status_code == expected_status_code
+
class TestGetDatasetsEndpointPagination(TestDatasetEndpoint):
@pytest.mark.parametrize(
@@ -579,6 +620,32 @@ class TestGetDatasetEvents(TestDatasetEndpoint):
"total_entries": 1,
}
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, session):
+ self._create_dataset(session)
+ common = {
+ "dataset_id": 1,
+ "extra": {"foo": "bar"},
+ "source_dag_id": "foo",
+ "source_task_id": "bar",
+ "source_run_id": "custom",
+ "source_map_index": -1,
+ "created_dagruns": [],
+ }
+
+ events = [DatasetEvent(id=i,
timestamp=timezone.parse(self.default_time), **common) for i in [1, 2]]
+ session.add_all(events)
+ session.commit()
+ assert session.query(DatasetEvent).count() == 2
+
+ response = self.client.get("/api/v1/datasets/events")
+
+ assert response.status_code == expected_status_code
+
class TestPostDatasetEvents(TestDatasetEndpoint):
@pytest.fixture
@@ -651,6 +718,19 @@ class TestPostDatasetEvents(TestDatasetEndpoint):
response = self.client.post("/api/v1/datasets/events",
json={"dataset_uri": "TEST_DATASET_URI"})
assert_401(response)
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ @pytest.mark.usefixtures("time_freezer")
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, session):
+ self._create_dataset(session)
+ event_payload = {"dataset_uri": "s3://bucket/key", "extra": {"foo":
"bar"}}
+ response = self.client.post("/api/v1/datasets/events",
json=event_payload)
+
+ assert response.status_code == expected_status_code
+
class TestGetDatasetEventsEndpointPagination(TestDatasetEndpoint):
@pytest.mark.parametrize(
@@ -821,6 +901,27 @@ class
TestGetDagDatasetQueuedEvent(TestQueuedEventEndpoint):
assert response.status_code == 403
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ @pytest.mark.usefixtures("time_freezer")
+ def test_with_auth_role_public_set(
+ self, set_auto_role_public, expected_status_code, create_dummy_dag,
session
+ ):
+ dag, _ = create_dummy_dag()
+ dag_id = dag.dag_id
+ dataset_id = self._create_dataset(session).id
+ self._create_dataset_dag_run_queues(dag_id, dataset_id, session)
+ dataset_uri = "s3://bucket/key"
+
+ response = self.client.get(
+ f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}",
+ )
+
+ assert response.status_code == expected_status_code
+
class TestDeleteDagDatasetQueuedEvent(TestDatasetEndpoint):
def test_delete_should_respond_204(self, session, create_dummy_dag):
@@ -882,7 +983,7 @@ class TestDeleteDagDatasetQueuedEvent(TestDatasetEndpoint):
class TestGetDagDatasetQueuedEvents(TestQueuedEventEndpoint):
@pytest.mark.usefixtures("time_freezer")
- def test_should_respond_200(self, session, create_dummy_dag, time_freezer):
+ def test_should_respond_200(self, session, create_dummy_dag):
dag, _ = create_dummy_dag()
dag_id = dag.dag_id
dataset_id = self._create_dataset(session).id
@@ -938,6 +1039,24 @@ class
TestGetDagDatasetQueuedEvents(TestQueuedEventEndpoint):
assert response.status_code == 403
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(
+ self, set_auto_role_public, expected_status_code, session,
create_dummy_dag
+ ):
+ dag, _ = create_dummy_dag()
+ dag_id = dag.dag_id
+ dataset_id = self._create_dataset(session).id
+ self._create_dataset_dag_run_queues(dag_id, dataset_id, session)
+
+ response = self.client.get(
+ f"/api/v1/dags/{dag_id}/datasets/queuedEvent",
+ )
+ assert response.status_code == expected_status_code
+
class TestDeleteDagDatasetQueuedEvents(TestDatasetEndpoint):
def test_should_respond_404(self):
@@ -973,6 +1092,31 @@ class
TestDeleteDagDatasetQueuedEvents(TestDatasetEndpoint):
assert response.status_code == 403
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 204)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(
+ self, set_auto_role_public, expected_status_code, session,
create_dummy_dag
+ ):
+ dag, _ = create_dummy_dag()
+ dag_id = dag.dag_id
+ dataset_uri = "s3://bucket/key"
+ dataset_id = self._create_dataset(session).id
+
+ ddrq = DatasetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset_id)
+ session.add(ddrq)
+ session.commit()
+ conn = session.query(DatasetDagRunQueue).all()
+ assert len(conn) == 1
+
+ response = self.client.delete(
+ f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}",
+ )
+
+ assert response.status_code == expected_status_code
+
class TestGetDatasetQueuedEvents(TestQueuedEventEndpoint):
@pytest.mark.usefixtures("time_freezer")
@@ -1033,6 +1177,26 @@ class
TestGetDatasetQueuedEvents(TestQueuedEventEndpoint):
assert response.status_code == 403
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ @pytest.mark.usefixtures("time_freezer")
+ def test_with_auth_role_public_set(
+ self, set_auto_role_public, expected_status_code, session,
create_dummy_dag
+ ):
+ dag, _ = create_dummy_dag()
+ dag_id = dag.dag_id
+ dataset_id = self._create_dataset(session).id
+ self._create_dataset_dag_run_queues(dag_id, dataset_id, session)
+
+ response = self.client.get(
+ f"/api/v1/dags/{dag_id}/datasets/queuedEvent",
+ )
+
+ assert response.status_code == expected_status_code
+
class TestDeleteDatasetQueuedEvents(TestQueuedEventEndpoint):
def test_delete_should_respond_204(self, session, create_dummy_dag):
@@ -1084,3 +1248,23 @@ class
TestDeleteDatasetQueuedEvents(TestQueuedEventEndpoint):
)
assert response.status_code == 403
+
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 204)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(
+ self, set_auto_role_public, expected_status_code, session,
create_dummy_dag
+ ):
+ dag, _ = create_dummy_dag()
+ dag_id = dag.dag_id
+ dataset_id = self._create_dataset(session).id
+ self._create_dataset_dag_run_queues(dag_id, dataset_id, session)
+ dataset_uri = "s3://bucket/key"
+
+ response = self.client.delete(
+ f"/api/v1/datasets/queuedEvent/{dataset_uri}",
+ )
+
+ assert response.status_code == expected_status_code
diff --git a/tests/api_connexion/endpoints/test_event_log_endpoint.py
b/tests/api_connexion/endpoints/test_event_log_endpoint.py
index 6e71a86b94..6738858ddd 100644
--- a/tests/api_connexion/endpoints/test_event_log_endpoint.py
+++ b/tests/api_connexion/endpoints/test_event_log_endpoint.py
@@ -109,6 +109,21 @@ class TestEventLogEndpoint:
def teardown_method(self) -> None:
clear_db_logs()
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, log_model):
+ event_log_id = log_model.id
+ response = self.client.get(
+ f"/api/v1/eventLogs/{event_log_id}",
environ_overrides={"REMOTE_USER": "test"}
+ )
+
+ response = self.client.get("/api/v1/eventLogs")
+
+ assert response.status_code == expected_status_code
+
class TestGetEventLog(TestEventLogEndpoint):
def test_should_respond_200(self, log_model):
@@ -152,6 +167,18 @@ class TestGetEventLog(TestEventLogEndpoint):
)
assert response.status_code == 403
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(self, set_auto_role_public,
expected_status_code, log_model):
+ event_log_id = log_model.id
+
+ response = self.client.get(f"/api/v1/eventLogs/{event_log_id}")
+
+ assert response.status_code == expected_status_code
+
class TestGetEventLogs(TestEventLogEndpoint):
def test_should_respond_200(self, session, create_log_model):
@@ -349,6 +376,23 @@ class TestGetEventLogs(TestEventLogEndpoint):
assert response_data["total_entries"] == 1
assert {"cli_scheduler"} == {x["event"] for x in
response_data["event_logs"]}
+ @pytest.mark.parametrize(
+ "set_auto_role_public, expected_status_code",
+ (("Public", 403), ("Admin", 200)),
+ indirect=["set_auto_role_public"],
+ )
+ def test_with_auth_role_public_set(
+ self, set_auto_role_public, expected_status_code, create_log_model,
session
+ ):
+ log_model_3 = Log(event="cli_scheduler", owner="root",
extra='{"host_name": "e24b454f002a"}')
+ log_model_3.dttm = self.default_time_2
+
+ session.add(log_model_3)
+ session.flush()
+ response = self.client.get("/api/v1/eventLogs")
+
+ assert response.status_code == expected_status_code
+
class TestGetEventLogPagination(TestEventLogEndpoint):
@pytest.mark.parametrize(