This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e671074137 Richer Audit Log extra field (#38166)
e671074137 is described below

commit e67107413785fa7ff8e4f0dd89d759120bedfea6
Author: Brent Bovenzi <[email protected]>
AuthorDate: Mon Mar 18 17:58:00 2024 -0700

    Richer Audit Log extra field (#38166)
---
 .../api_connexion/endpoints/dataset_endpoint.py    |  9 +--
 airflow/www/decorators.py                          | 70 ++++++++++++++--------
 .../endpoints/test_connection_endpoint.py          |  4 +-
 tests/api_connexion/endpoints/test_dag_endpoint.py |  9 +--
 .../endpoints/test_dag_run_endpoint.py             | 10 +++-
 .../endpoints/test_dataset_endpoint.py             | 24 ++++++++
 .../endpoints/test_task_instance_endpoint.py       |  6 +-
 .../endpoints/test_variable_endpoint.py            | 39 ++++++++++--
 tests/test_utils/www.py                            | 27 +++++----
 tests/www/views/test_views_decorators.py           |  7 +++
 tests/www/views/test_views_paused.py               |  6 +-
 11 files changed, 150 insertions(+), 61 deletions(-)

diff --git a/airflow/api_connexion/endpoints/dataset_endpoint.py b/airflow/api_connexion/endpoints/dataset_endpoint.py
index 57c299cade..bfdb8d0a5e 100644
--- a/airflow/api_connexion/endpoints/dataset_endpoint.py
+++ b/airflow/api_connexion/endpoints/dataset_endpoint.py
@@ -46,10 +46,8 @@ from airflow.api_connexion.schemas.dataset_schema import (
 from airflow.datasets import Dataset
 from airflow.datasets.manager import dataset_manager
 from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel
-from airflow.security import permissions
 from airflow.utils import timezone
 from airflow.utils.db import get_query_count
-from airflow.utils.log.action_logger import action_event_from_permission
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.www.decorators import action_logging
 from airflow.www.extensions.init_auth_manager import get_auth_manager
@@ -330,12 +328,7 @@ def delete_dataset_queued_events(
 
 @security.requires_access_dataset("POST")
 @provide_session
-@action_logging(
-    event=action_event_from_permission(
-        prefix=RESOURCE_EVENT_PREFIX,
-        permission=permissions.ACTION_CAN_CREATE,
-    ),
-)
+@action_logging
 def create_dataset_event(session: Session = NEW_SESSION) -> APIResponse:
     """Create dataset event."""
     body = get_json_request_dict()
diff --git a/airflow/www/decorators.py b/airflow/www/decorators.py
index 91146d0fee..3eae5f6239 100644
--- a/airflow/www/decorators.py
+++ b/airflow/www/decorators.py
@@ -44,36 +44,36 @@ def _mask_variable_fields(extra_fields):
     Mask the 'val_content' field if 'key_content' is in the mask list.
 
     The variable requests values and args comes in this form:
-    [('key', 'key_content'),('val', 'val_content'), ('description', 'description_content')]
+    {'key': 'key_content', 'val': 'val_content', 'description': 'description_content'}
     """
-    result = []
+    result = {}
     keyname = None
-    for k, v in extra_fields:
+    for k, v in extra_fields.items():
         if k == "key":
             keyname = v
-            result.append((k, v))
-        elif keyname and k == "val":
+            result[k] = v
+        elif keyname and (k == "val" or k == "value"):
             x = secrets_masker.redact(v, keyname)
-            result.append((k, x))
+            result[k] = x
             keyname = None
         else:
-            result.append((k, v))
+            result[k] = v
     return result
 
 
 def _mask_connection_fields(extra_fields):
     """Mask connection fields."""
-    result = []
-    for k, v in extra_fields:
-        if k == "extra":
+    result = {}
+    for k, v in extra_fields.items():
+        if k == "extra" and v:
             try:
                 extra = json.loads(v)
-                extra = [(k, secrets_masker.redact(v, k)) for k, v in extra.items()]
-                result.append((k, json.dumps(dict(extra))))
+                extra = {k: secrets_masker.redact(v, k) for k, v in extra.items()}
+                result[k] = dict(extra)
             except json.JSONDecodeError:
-                result.append((k, "Encountered non-JSON in `extra` field"))
+                result[k] = "Encountered non-JSON in `extra` field"
         else:
-            result.append((k, secrets_masker.redact(v, k)))
+            result[k] = secrets_masker.redact(v, k)
     return result
 
 
@@ -94,35 +94,55 @@ def action_logging(func: T | None = None, event: str | None = None) -> T | Calla
                     user = get_auth_manager().get_user_name()
                     user_display = get_auth_manager().get_user_display_name()
 
-                fields_skip_logging = {"csrf_token", "_csrf_token", "is_paused"}
-                extra_fields = [
-                    (k, secrets_masker.redact(v, k))
+                isAPIRequest = request.blueprint == "/api/v1"
+                hasJsonBody = request.headers.get("content-type") == "application/json" and request.json
+
+                fields_skip_logging = {
+                    "csrf_token",
+                    "_csrf_token",
+                    "is_paused",
+                    "dag_id",
+                    "task_id",
+                    "dag_run_id",
+                    "run_id",
+                    "execution_date",
+                }
+                extra_fields = {
+                    k: secrets_masker.redact(v, k)
                     for k, v in itertools.chain(request.values.items(multi=True), request.view_args.items())
                     if k not in fields_skip_logging
-                ]
+                }
                 if event and event.startswith("variable."):
-                    extra_fields = _mask_variable_fields(extra_fields)
-                if event and event.startswith("connection."):
-                    extra_fields = _mask_connection_fields(extra_fields)
+                    extra_fields = _mask_variable_fields(
+                        request.json if isAPIRequest and hasJsonBody else extra_fields
+                    )
+                elif event and event.startswith("connection."):
+                    extra_fields = _mask_connection_fields(
+                        request.json if isAPIRequest and hasJsonBody else extra_fields
+                    )
+                elif hasJsonBody:
+                    masked_json = {k: secrets_masker.redact(v, k) for k, v in request.json.items()}
+                    extra_fields = {**extra_fields, **masked_json}
 
                 params = {**request.values, **request.view_args}
+                if params and "is_paused" in params:
+                    extra_fields["is_paused"] = params["is_paused"] == "false"
 
-                if request.blueprint == "/api/v1":
+                if isAPIRequest:
                     if f"{request.origin}/" == request.root_url:
                         event_name = f"ui.{event_name}"
                     else:
                         event_name = f"api.{event_name}"
 
-                if params and "is_paused" in params:
-                    extra_fields.append(("is_paused", params["is_paused"] == "false"))
                 log = Log(
                     event=event_name,
                     task_instance=None,
                     owner=user,
                     owner_display_name=user_display,
-                    extra=str(extra_fields),
+                    extra=json.dumps(extra_fields),
                     task_id=params.get("task_id"),
                     dag_id=params.get("dag_id"),
+                    run_id=params.get("run_id") or params.get("dag_run_id"),
                 )
 
                 if "execution_date" in request.values:
diff --git a/tests/api_connexion/endpoints/test_connection_endpoint.py b/tests/api_connexion/endpoints/test_connection_endpoint.py
index ca76ce327f..0d9debeb98 100644
--- a/tests/api_connexion/endpoints/test_connection_endpoint.py
+++ b/tests/api_connexion/endpoints/test_connection_endpoint.py
@@ -539,7 +539,9 @@ class TestPostConnection(TestConnectionEndpoint):
         connection = session.query(Connection).all()
         assert len(connection) == 1
         assert connection[0].conn_id == "test-connection-id"
-        _check_last_log(session, dag_id=None, event="api.connection.create", execution_date=None)
+        _check_last_log(
+            session, dag_id=None, event="api.connection.create", execution_date=None, expected_extra=payload
+        )
 
     def test_post_should_respond_200_extra_null(self, session):
         payload = {"connection_id": "test-connection-id", "conn_type": "test_type", "extra": None}
diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py
index 0655810e0e..30f87e1f68 100644
--- a/tests/api_connexion/endpoints/test_dag_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_endpoint.py
@@ -1247,11 +1247,10 @@ class TestPatchDag(TestDagEndpoint):
     def test_should_respond_200_on_patch_is_paused(self, url_safe_serializer, session):
         file_token = url_safe_serializer.dumps("/tmp/dag_1.py")
         dag_model = self._create_dag_model()
+        payload = {"is_paused": False}
         response = self.client.patch(
             f"/api/v1/dags/{dag_model.dag_id}",
-            json={
-                "is_paused": False,
-            },
+            json=payload,
             environ_overrides={"REMOTE_USER": "test"},
         )
         assert response.status_code == 200
@@ -1288,7 +1287,9 @@ class TestPatchDag(TestDagEndpoint):
             "pickle_id": None,
         }
         assert response.json == expected_response
-        _check_last_log(session, dag_id="TEST_DAG_1", event="api.patch_dag", execution_date=None)
+        _check_last_log(
+            session, dag_id="TEST_DAG_1", event="api.patch_dag", execution_date=None, expected_extra=payload
+        )
 
     def test_should_respond_200_on_patch_with_granular_dag_access(self, session):
         self._create_dag_models(1)
diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
index ef91ef78eb..f6ace16099 100644
--- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
@@ -1952,9 +1952,10 @@ class TestSetDagRunNote(TestDagRunEndpoint):
         assert dr.dag_run_note.user_id is not None
         # Update the note again
         new_note_value = "My super cool DagRun notes 2"
+        payload = {"note": new_note_value}
         response = self.client.patch(
             f"api/v1/dags/{created_dr.dag_id}/dagRuns/{created_dr.run_id}/setNote",
-            json={"note": new_note_value},
+            json=payload,
             environ_overrides={"REMOTE_USER": "test"},
         )
         assert response.status_code == 200
@@ -1975,6 +1976,13 @@ class TestSetDagRunNote(TestDagRunEndpoint):
             "note": new_note_value,
         }
         assert dr.dag_run_note.user_id is not None
+        _check_last_log(
+            session,
+            dag_id=dr.dag_id,
+            event="api.set_dag_run_note",
+            execution_date=None,
+            expected_extra=payload,
+        )
 
     def test_schema_validation_error_raises(self, dag_maker, session):
         dag_runs: list[DagRun] = self._create_test_dag_run(DagRunState.SUCCESS)
diff --git a/tests/api_connexion/endpoints/test_dataset_endpoint.py b/tests/api_connexion/endpoints/test_dataset_endpoint.py
index 0ba296604d..29192a3c65 100644
--- a/tests/api_connexion/endpoints/test_dataset_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dataset_endpoint.py
@@ -612,6 +612,30 @@ class TestPostDatasetEvents(TestDatasetEndpoint):
             "source_map_index": -1,
             "timestamp": self.default_time,
         }
+        _check_last_log(
+            session,
+            dag_id=None,
+            event="api.create_dataset_event",
+            execution_date=None,
+            expected_extra=event_payload,
+        )
+
+    def test_should_mask_sensitive_extra_logs(self, session):
+        self._create_dataset(session)
+        event_payload = {"dataset_uri": "s3://bucket/key", "extra": {"password": "bar"}}
+        response = self.client.post(
+            "/api/v1/datasets/events", json=event_payload, environ_overrides={"REMOTE_USER": "test"}
+        )
+
+        assert response.status_code == 200
+        expected_extra = {**event_payload, "extra": {"password": "***"}}
+        _check_last_log(
+            session,
+            dag_id=None,
+            event="api.create_dataset_event",
+            execution_date=None,
+            expected_extra=expected_extra,
+        )
 
     def test_order_by_raises_400_for_invalid_attr(self, session):
         self._create_dataset(session)
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index 59ee686c1f..e6a3e0dfbb 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -1249,7 +1249,11 @@ class TestPostClearTaskInstances(TestTaskInstanceEndpoint):
         assert response.status_code == 200
         assert len(response.json["task_instances"]) == expected_ti
         _check_last_log(
-            session, dag_id=request_dag, event="api.post_clear_task_instances", execution_date=None
+            session,
+            dag_id=request_dag,
+            event="api.post_clear_task_instances",
+            execution_date=None,
+            expected_extra=payload,
         )
 
     
@mock.patch("airflow.api_connexion.endpoints.task_instance_endpoint.clear_task_instances")
diff --git a/tests/api_connexion/endpoints/test_variable_endpoint.py b/tests/api_connexion/endpoints/test_variable_endpoint.py
index f533710aab..f0a1b7502b 100644
--- a/tests/api_connexion/endpoints/test_variable_endpoint.py
+++ b/tests/api_connexion/endpoints/test_variable_endpoint.py
@@ -250,17 +250,20 @@ class TestGetVariables(TestVariableEndpoint):
 class TestPatchVariable(TestVariableEndpoint):
     def test_should_update_variable(self, session):
         Variable.set("var1", "foo")
+        payload = {
+            "key": "var1",
+            "value": "updated",
+        }
         response = self.client.patch(
             "/api/v1/variables/var1",
-            json={
-                "key": "var1",
-                "value": "updated",
-            },
+            json=payload,
             environ_overrides={"REMOTE_USER": "test"},
         )
         assert response.status_code == 200
         assert response.json == {"key": "var1", "value": "updated", "description": None}
-        _check_last_log(session, dag_id=None, event="api.variable.edit", execution_date=None)
+        _check_last_log(
+            session, dag_id=None, event="api.variable.edit", execution_date=None, expected_extra=payload
+        )
 
     def test_should_update_variable_with_mask(self, session):
         Variable.set("var1", "foo", description="before update")
@@ -353,7 +356,9 @@ class TestPostVariables(TestVariableEndpoint):
             environ_overrides={"REMOTE_USER": "test"},
         )
         assert response.status_code == 200
-        _check_last_log(session, dag_id=None, event="api.variable.create", execution_date=None)
+        _check_last_log(
+            session, dag_id=None, event="api.variable.create", execution_date=None, expected_extra=payload
+        )
         response = self.client.get("/api/v1/variables/var_create", environ_overrides={"REMOTE_USER": "test"})
         assert response.json == {
             "key": "var_create",
@@ -361,6 +366,28 @@ class TestPostVariables(TestVariableEndpoint):
             "description": description,
         }
 
+    def test_should_create_masked_variable(self, session):
+        payload = {"key": "api_key", "value": "secret_key", "description": "secret"}
+        response = self.client.post(
+            "/api/v1/variables",
+            json=payload,
+            environ_overrides={"REMOTE_USER": "test"},
+        )
+        assert response.status_code == 200
+        expected_extra = {
+            **payload,
+            "value": "***",
+        }
+        _check_last_log(
+            session,
+            dag_id=None,
+            event="api.variable.create",
+            execution_date=None,
+            expected_extra=expected_extra,
+        )
+        response = self.client.get("/api/v1/variables/api_key", environ_overrides={"REMOTE_USER": "test"})
+        assert response.json == payload
+
     def test_should_reject_invalid_request(self, session):
         response = self.client.post(
             "/api/v1/variables",
diff --git a/tests/test_utils/www.py b/tests/test_utils/www.py
--- a/tests/test_utils/www.py
+++ b/tests/test_utils/www.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import ast
+import json
 from unittest import mock
 
 from airflow.models import Log
@@ -66,7 +67,7 @@ def check_content_not_in_response(text, resp, resp_code=200):
         assert text not in resp_html
 
 
-def _check_last_log(session, dag_id, event, execution_date):
+def _check_last_log(session, dag_id, event, execution_date, expected_extra=None):
     logs = (
         session.query(
             Log.dag_id,
@@ -87,6 +88,8 @@ def _check_last_log(session, dag_id, event, execution_date):
     )
     assert len(logs) >= 1
     assert logs[0].extra
+    if expected_extra:
+        assert json.loads(logs[0].extra) == expected_extra
     session.query(Log).delete()
 
 
@@ -111,16 +114,16 @@ def _check_last_log_masked_connection(session, dag_id, event, execution_date):
     )
     assert len(logs) >= 1
     extra = ast.literal_eval(logs[0].extra)
-    assert extra == [
-        ("conn_id", "test_conn"),
-        ("conn_type", "http"),
-        ("description", "description"),
-        ("host", "localhost"),
-        ("port", "8080"),
-        ("username", "root"),
-        ("password", "***"),
-        ("extra", '{"x_secret": "***", "y_secret": "***"}'),
-    ]
+    assert extra == {
+        "conn_id": "test_conn",
+        "conn_type": "http",
+        "description": "description",
+        "host": "localhost",
+        "port": "8080",
+        "username": "root",
+        "password": "***",
+        "extra": {"x_secret": "***", "y_secret": "***"},
+    }
 
 
 def _check_last_log_masked_variable(session, dag_id, event, execution_date):
@@ -144,4 +147,4 @@ def _check_last_log_masked_variable(session, dag_id, event, execution_date):
     )
     assert len(logs) >= 1
     extra_dict = ast.literal_eval(logs[0].extra)
-    assert extra_dict == [("key", "x_secret"), ("val", "***")]
+    assert extra_dict == {"key": "x_secret", "val": "***"}
diff --git a/tests/www/views/test_views_decorators.py b/tests/www/views/test_views_decorators.py
index 4cfab0926e..a8f199f595 100644
--- a/tests/www/views/test_views_decorators.py
+++ b/tests/www/views/test_views_decorators.py
@@ -124,6 +124,13 @@ def test_action_logging_post(session, admin_client):
         dag_id="example_bash_operator",
         event="clear",
         execution_date=EXAMPLE_DAG_DEFAULT_DATE,
+        expected_extra={
+            "upstream": "false",
+            "downstream": "false",
+            "future": "false",
+            "past": "false",
+            "only_failed": "false",
+        },
     )
 
 
diff --git a/tests/www/views/test_views_paused.py b/tests/www/views/test_views_paused.py
index 1f9ac1f4d8..46b0a3aa03 100644
--- a/tests/www/views/test_views_paused.py
+++ b/tests/www/views/test_views_paused.py
@@ -39,12 +39,12 @@ def test_logging_pause_dag(admin_client, dags, session):
     # is_paused=false mean pause the dag
     admin_client.post(f"/paused?is_paused=false&dag_id={dag.dag_id}", follow_redirects=True)
     dag_query = session.query(Log).filter(Log.dag_id == dag.dag_id)
-    assert "('is_paused', True)" in dag_query.first().extra
+    assert '{"is_paused": true}' in dag_query.first().extra
 
 
-def test_logging_unpuase_dag(admin_client, dags, session):
+def test_logging_unpause_dag(admin_client, dags, session):
     _, paused_dag = dags
     # is_paused=true mean unpause the dag
     admin_client.post(f"/paused?is_paused=true&dag_id={paused_dag.dag_id}", follow_redirects=True)
     dag_query = session.query(Log).filter(Log.dag_id == paused_dag.dag_id)
-    assert "('is_paused', False)" in dag_query.first().extra
+    assert '{"is_paused": false}' in dag_query.first().extra

Reply via email to