This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e671074137 Richer Audit Log extra field (#38166)
e671074137 is described below
commit e67107413785fa7ff8e4f0dd89d759120bedfea6
Author: Brent Bovenzi <[email protected]>
AuthorDate: Mon Mar 18 17:58:00 2024 -0700
Richer Audit Log extra field (#38166)
---
.../api_connexion/endpoints/dataset_endpoint.py | 9 +--
airflow/www/decorators.py | 70 ++++++++++++++--------
.../endpoints/test_connection_endpoint.py | 4 +-
tests/api_connexion/endpoints/test_dag_endpoint.py | 9 +--
.../endpoints/test_dag_run_endpoint.py | 10 +++-
.../endpoints/test_dataset_endpoint.py | 24 ++++++++
.../endpoints/test_task_instance_endpoint.py | 6 +-
.../endpoints/test_variable_endpoint.py | 39 ++++++++++--
tests/test_utils/www.py | 27 +++++----
tests/www/views/test_views_decorators.py | 7 +++
tests/www/views/test_views_paused.py | 6 +-
11 files changed, 150 insertions(+), 61 deletions(-)
diff --git a/airflow/api_connexion/endpoints/dataset_endpoint.py b/airflow/api_connexion/endpoints/dataset_endpoint.py
index 57c299cade..bfdb8d0a5e 100644
--- a/airflow/api_connexion/endpoints/dataset_endpoint.py
+++ b/airflow/api_connexion/endpoints/dataset_endpoint.py
@@ -46,10 +46,8 @@ from airflow.api_connexion.schemas.dataset_schema import (
from airflow.datasets import Dataset
from airflow.datasets.manager import dataset_manager
from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent,
DatasetModel
-from airflow.security import permissions
from airflow.utils import timezone
from airflow.utils.db import get_query_count
-from airflow.utils.log.action_logger import action_event_from_permission
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.www.decorators import action_logging
from airflow.www.extensions.init_auth_manager import get_auth_manager
@@ -330,12 +328,7 @@ def delete_dataset_queued_events(
@security.requires_access_dataset("POST")
@provide_session
-@action_logging(
- event=action_event_from_permission(
- prefix=RESOURCE_EVENT_PREFIX,
- permission=permissions.ACTION_CAN_CREATE,
- ),
-)
+@action_logging
def create_dataset_event(session: Session = NEW_SESSION) -> APIResponse:
"""Create dataset event."""
body = get_json_request_dict()
diff --git a/airflow/www/decorators.py b/airflow/www/decorators.py
index 91146d0fee..3eae5f6239 100644
--- a/airflow/www/decorators.py
+++ b/airflow/www/decorators.py
@@ -44,36 +44,36 @@ def _mask_variable_fields(extra_fields):
Mask the 'val_content' field if 'key_content' is in the mask list.
The variable requests values and args comes in this form:
- [('key', 'key_content'),('val', 'val_content'), ('description',
'description_content')]
+ {'key': 'key_content', 'val': 'val_content', 'description':
'description_content'}
"""
- result = []
+ result = {}
keyname = None
- for k, v in extra_fields:
+ for k, v in extra_fields.items():
if k == "key":
keyname = v
- result.append((k, v))
- elif keyname and k == "val":
+ result[k] = v
+ elif keyname and (k == "val" or k == "value"):
x = secrets_masker.redact(v, keyname)
- result.append((k, x))
+ result[k] = x
keyname = None
else:
- result.append((k, v))
+ result[k] = v
return result
def _mask_connection_fields(extra_fields):
"""Mask connection fields."""
- result = []
- for k, v in extra_fields:
- if k == "extra":
+ result = {}
+ for k, v in extra_fields.items():
+ if k == "extra" and v:
try:
extra = json.loads(v)
- extra = [(k, secrets_masker.redact(v, k)) for k, v in
extra.items()]
- result.append((k, json.dumps(dict(extra))))
+ extra = {k: secrets_masker.redact(v, k) for k, v in
extra.items()}
+ result[k] = dict(extra)
except json.JSONDecodeError:
- result.append((k, "Encountered non-JSON in `extra` field"))
+ result[k] = "Encountered non-JSON in `extra` field"
else:
- result.append((k, secrets_masker.redact(v, k)))
+ result[k] = secrets_masker.redact(v, k)
return result
@@ -94,35 +94,55 @@ def action_logging(func: T | None = None, event: str | None = None) -> T | Calla
user = get_auth_manager().get_user_name()
user_display = get_auth_manager().get_user_display_name()
- fields_skip_logging = {"csrf_token", "_csrf_token",
"is_paused"}
- extra_fields = [
- (k, secrets_masker.redact(v, k))
+ isAPIRequest = request.blueprint == "/api/v1"
+ hasJsonBody = request.headers.get("content-type") ==
"application/json" and request.json
+
+ fields_skip_logging = {
+ "csrf_token",
+ "_csrf_token",
+ "is_paused",
+ "dag_id",
+ "task_id",
+ "dag_run_id",
+ "run_id",
+ "execution_date",
+ }
+ extra_fields = {
+ k: secrets_masker.redact(v, k)
for k, v in
itertools.chain(request.values.items(multi=True), request.view_args.items())
if k not in fields_skip_logging
- ]
+ }
if event and event.startswith("variable."):
- extra_fields = _mask_variable_fields(extra_fields)
- if event and event.startswith("connection."):
- extra_fields = _mask_connection_fields(extra_fields)
+ extra_fields = _mask_variable_fields(
+ request.json if isAPIRequest and hasJsonBody else
extra_fields
+ )
+ elif event and event.startswith("connection."):
+ extra_fields = _mask_connection_fields(
+ request.json if isAPIRequest and hasJsonBody else
extra_fields
+ )
+ elif hasJsonBody:
+ masked_json = {k: secrets_masker.redact(v, k) for k, v in
request.json.items()}
+ extra_fields = {**extra_fields, **masked_json}
params = {**request.values, **request.view_args}
+ if params and "is_paused" in params:
+ extra_fields["is_paused"] = params["is_paused"] == "false"
- if request.blueprint == "/api/v1":
+ if isAPIRequest:
if f"{request.origin}/" == request.root_url:
event_name = f"ui.{event_name}"
else:
event_name = f"api.{event_name}"
- if params and "is_paused" in params:
- extra_fields.append(("is_paused", params["is_paused"] ==
"false"))
log = Log(
event=event_name,
task_instance=None,
owner=user,
owner_display_name=user_display,
- extra=str(extra_fields),
+ extra=json.dumps(extra_fields),
task_id=params.get("task_id"),
dag_id=params.get("dag_id"),
+ run_id=params.get("run_id") or params.get("dag_run_id"),
)
if "execution_date" in request.values:
diff --git a/tests/api_connexion/endpoints/test_connection_endpoint.py b/tests/api_connexion/endpoints/test_connection_endpoint.py
index ca76ce327f..0d9debeb98 100644
--- a/tests/api_connexion/endpoints/test_connection_endpoint.py
+++ b/tests/api_connexion/endpoints/test_connection_endpoint.py
@@ -539,7 +539,9 @@ class TestPostConnection(TestConnectionEndpoint):
connection = session.query(Connection).all()
assert len(connection) == 1
assert connection[0].conn_id == "test-connection-id"
- _check_last_log(session, dag_id=None, event="api.connection.create",
execution_date=None)
+ _check_last_log(
+ session, dag_id=None, event="api.connection.create",
execution_date=None, expected_extra=payload
+ )
def test_post_should_respond_200_extra_null(self, session):
payload = {"connection_id": "test-connection-id", "conn_type":
"test_type", "extra": None}
diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py
index 0655810e0e..30f87e1f68 100644
--- a/tests/api_connexion/endpoints/test_dag_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_endpoint.py
@@ -1247,11 +1247,10 @@ class TestPatchDag(TestDagEndpoint):
def test_should_respond_200_on_patch_is_paused(self, url_safe_serializer,
session):
file_token = url_safe_serializer.dumps("/tmp/dag_1.py")
dag_model = self._create_dag_model()
+ payload = {"is_paused": False}
response = self.client.patch(
f"/api/v1/dags/{dag_model.dag_id}",
- json={
- "is_paused": False,
- },
+ json=payload,
environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 200
@@ -1288,7 +1287,9 @@ class TestPatchDag(TestDagEndpoint):
"pickle_id": None,
}
assert response.json == expected_response
- _check_last_log(session, dag_id="TEST_DAG_1", event="api.patch_dag",
execution_date=None)
+ _check_last_log(
+ session, dag_id="TEST_DAG_1", event="api.patch_dag",
execution_date=None, expected_extra=payload
+ )
def test_should_respond_200_on_patch_with_granular_dag_access(self,
session):
self._create_dag_models(1)
diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
index ef91ef78eb..f6ace16099 100644
--- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
@@ -1952,9 +1952,10 @@ class TestSetDagRunNote(TestDagRunEndpoint):
assert dr.dag_run_note.user_id is not None
# Update the note again
new_note_value = "My super cool DagRun notes 2"
+ payload = {"note": new_note_value}
response = self.client.patch(
f"api/v1/dags/{created_dr.dag_id}/dagRuns/{created_dr.run_id}/setNote",
- json={"note": new_note_value},
+ json=payload,
environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 200
@@ -1975,6 +1976,13 @@ class TestSetDagRunNote(TestDagRunEndpoint):
"note": new_note_value,
}
assert dr.dag_run_note.user_id is not None
+ _check_last_log(
+ session,
+ dag_id=dr.dag_id,
+ event="api.set_dag_run_note",
+ execution_date=None,
+ expected_extra=payload,
+ )
def test_schema_validation_error_raises(self, dag_maker, session):
dag_runs: list[DagRun] = self._create_test_dag_run(DagRunState.SUCCESS)
diff --git a/tests/api_connexion/endpoints/test_dataset_endpoint.py b/tests/api_connexion/endpoints/test_dataset_endpoint.py
index 0ba296604d..29192a3c65 100644
--- a/tests/api_connexion/endpoints/test_dataset_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dataset_endpoint.py
@@ -612,6 +612,30 @@ class TestPostDatasetEvents(TestDatasetEndpoint):
"source_map_index": -1,
"timestamp": self.default_time,
}
+ _check_last_log(
+ session,
+ dag_id=None,
+ event="api.create_dataset_event",
+ execution_date=None,
+ expected_extra=event_payload,
+ )
+
+ def test_should_mask_sensitive_extra_logs(self, session):
+ self._create_dataset(session)
+ event_payload = {"dataset_uri": "s3://bucket/key", "extra":
{"password": "bar"}}
+ response = self.client.post(
+ "/api/v1/datasets/events", json=event_payload,
environ_overrides={"REMOTE_USER": "test"}
+ )
+
+ assert response.status_code == 200
+ expected_extra = {**event_payload, "extra": {"password": "***"}}
+ _check_last_log(
+ session,
+ dag_id=None,
+ event="api.create_dataset_event",
+ execution_date=None,
+ expected_extra=expected_extra,
+ )
def test_order_by_raises_400_for_invalid_attr(self, session):
self._create_dataset(session)
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index 59ee686c1f..e6a3e0dfbb 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -1249,7 +1249,11 @@ class TestPostClearTaskInstances(TestTaskInstanceEndpoint):
assert response.status_code == 200
assert len(response.json["task_instances"]) == expected_ti
_check_last_log(
- session, dag_id=request_dag,
event="api.post_clear_task_instances", execution_date=None
+ session,
+ dag_id=request_dag,
+ event="api.post_clear_task_instances",
+ execution_date=None,
+ expected_extra=payload,
)
@mock.patch("airflow.api_connexion.endpoints.task_instance_endpoint.clear_task_instances")
diff --git a/tests/api_connexion/endpoints/test_variable_endpoint.py b/tests/api_connexion/endpoints/test_variable_endpoint.py
index f533710aab..f0a1b7502b 100644
--- a/tests/api_connexion/endpoints/test_variable_endpoint.py
+++ b/tests/api_connexion/endpoints/test_variable_endpoint.py
@@ -250,17 +250,20 @@ class TestGetVariables(TestVariableEndpoint):
class TestPatchVariable(TestVariableEndpoint):
def test_should_update_variable(self, session):
Variable.set("var1", "foo")
+ payload = {
+ "key": "var1",
+ "value": "updated",
+ }
response = self.client.patch(
"/api/v1/variables/var1",
- json={
- "key": "var1",
- "value": "updated",
- },
+ json=payload,
environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 200
assert response.json == {"key": "var1", "value": "updated",
"description": None}
- _check_last_log(session, dag_id=None, event="api.variable.edit",
execution_date=None)
+ _check_last_log(
+ session, dag_id=None, event="api.variable.edit",
execution_date=None, expected_extra=payload
+ )
def test_should_update_variable_with_mask(self, session):
Variable.set("var1", "foo", description="before update")
@@ -353,7 +356,9 @@ class TestPostVariables(TestVariableEndpoint):
environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 200
- _check_last_log(session, dag_id=None, event="api.variable.create",
execution_date=None)
+ _check_last_log(
+ session, dag_id=None, event="api.variable.create",
execution_date=None, expected_extra=payload
+ )
response = self.client.get("/api/v1/variables/var_create",
environ_overrides={"REMOTE_USER": "test"})
assert response.json == {
"key": "var_create",
@@ -361,6 +366,28 @@ class TestPostVariables(TestVariableEndpoint):
"description": description,
}
+ def test_should_create_masked_variable(self, session):
+ payload = {"key": "api_key", "value": "secret_key", "description":
"secret"}
+ response = self.client.post(
+ "/api/v1/variables",
+ json=payload,
+ environ_overrides={"REMOTE_USER": "test"},
+ )
+ assert response.status_code == 200
+ expected_extra = {
+ **payload,
+ "value": "***",
+ }
+ _check_last_log(
+ session,
+ dag_id=None,
+ event="api.variable.create",
+ execution_date=None,
+ expected_extra=expected_extra,
+ )
+ response = self.client.get("/api/v1/variables/api_key",
environ_overrides={"REMOTE_USER": "test"})
+ assert response.json == payload
+
def test_should_reject_invalid_request(self, session):
response = self.client.post(
"/api/v1/variables",
diff --git a/tests/test_utils/www.py b/tests/test_utils/www.py
index 6bee2a3be6..0a19c312fb 100644
--- a/tests/test_utils/www.py
+++ b/tests/test_utils/www.py
@@ -17,6 +17,7 @@
from __future__ import annotations
import ast
+import json
from unittest import mock
from airflow.models import Log
@@ -66,7 +67,7 @@ def check_content_not_in_response(text, resp, resp_code=200):
assert text not in resp_html
-def _check_last_log(session, dag_id, event, execution_date):
+def _check_last_log(session, dag_id, event, execution_date,
expected_extra=None):
logs = (
session.query(
Log.dag_id,
@@ -87,6 +88,8 @@ def _check_last_log(session, dag_id, event, execution_date):
)
assert len(logs) >= 1
assert logs[0].extra
+ if expected_extra:
+ assert json.loads(logs[0].extra) == expected_extra
session.query(Log).delete()
@@ -111,16 +114,16 @@ def _check_last_log_masked_connection(session, dag_id, event, execution_date):
)
assert len(logs) >= 1
extra = ast.literal_eval(logs[0].extra)
- assert extra == [
- ("conn_id", "test_conn"),
- ("conn_type", "http"),
- ("description", "description"),
- ("host", "localhost"),
- ("port", "8080"),
- ("username", "root"),
- ("password", "***"),
- ("extra", '{"x_secret": "***", "y_secret": "***"}'),
- ]
+ assert extra == {
+ "conn_id": "test_conn",
+ "conn_type": "http",
+ "description": "description",
+ "host": "localhost",
+ "port": "8080",
+ "username": "root",
+ "password": "***",
+ "extra": {"x_secret": "***", "y_secret": "***"},
+ }
def _check_last_log_masked_variable(session, dag_id, event, execution_date):
@@ -144,4 +147,4 @@ def _check_last_log_masked_variable(session, dag_id, event, execution_date):
)
assert len(logs) >= 1
extra_dict = ast.literal_eval(logs[0].extra)
- assert extra_dict == [("key", "x_secret"), ("val", "***")]
+ assert extra_dict == {"key": "x_secret", "val": "***"}
diff --git a/tests/www/views/test_views_decorators.py b/tests/www/views/test_views_decorators.py
index 4cfab0926e..a8f199f595 100644
--- a/tests/www/views/test_views_decorators.py
+++ b/tests/www/views/test_views_decorators.py
@@ -124,6 +124,13 @@ def test_action_logging_post(session, admin_client):
dag_id="example_bash_operator",
event="clear",
execution_date=EXAMPLE_DAG_DEFAULT_DATE,
+ expected_extra={
+ "upstream": "false",
+ "downstream": "false",
+ "future": "false",
+ "past": "false",
+ "only_failed": "false",
+ },
)
diff --git a/tests/www/views/test_views_paused.py b/tests/www/views/test_views_paused.py
index 1f9ac1f4d8..46b0a3aa03 100644
--- a/tests/www/views/test_views_paused.py
+++ b/tests/www/views/test_views_paused.py
@@ -39,12 +39,12 @@ def test_logging_pause_dag(admin_client, dags, session):
# is_paused=false mean pause the dag
admin_client.post(f"/paused?is_paused=false&dag_id={dag.dag_id}",
follow_redirects=True)
dag_query = session.query(Log).filter(Log.dag_id == dag.dag_id)
- assert "('is_paused', True)" in dag_query.first().extra
+ assert '{"is_paused": true}' in dag_query.first().extra
-def test_logging_unpuase_dag(admin_client, dags, session):
+def test_logging_unpause_dag(admin_client, dags, session):
_, paused_dag = dags
# is_paused=true mean unpause the dag
admin_client.post(f"/paused?is_paused=true&dag_id={paused_dag.dag_id}",
follow_redirects=True)
dag_query = session.query(Log).filter(Log.dag_id == paused_dag.dag_id)
- assert "('is_paused', False)" in dag_query.first().extra
+ assert '{"is_paused": false}' in dag_query.first().extra