This is an automated email from the ASF dual-hosted git repository.

shahar1 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d6dee49db22 Rework StackdriverTaskHandler for the structlog era #65191 (#65198)
d6dee49db22 is described below

commit d6dee49db228ca14f829e12df066b1cac994ceee
Author: Haseeb Malik <[email protected]>
AuthorDate: Tue May 26 17:16:46 2026 -0400

    Rework StackdriverTaskHandler for the structlog era #65191 (#65198)
---
 .../google/cloud/log/stackdriver_task_handler.py   | 396 +++++++++++++--------
 .../cloud/log/test_stackdriver_task_handler.py     | 250 ++++++++++++-
 2 files changed, 484 insertions(+), 162 deletions(-)

diff --git a/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py b/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py
index b11a1cd6a47..145c2564539 100644
--- a/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py
+++ b/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py
@@ -18,13 +18,21 @@
 
 from __future__ import annotations
 
+import contextlib
+import copy
 import logging
+import os
+import shutil
 import warnings
 from collections.abc import Collection
+from datetime import datetime
 from functools import cached_property
+from logging import getLogRecordFactory
+from pathlib import Path
 from typing import TYPE_CHECKING
 from urllib.parse import urlencode
 
+import attrs
 from google.cloud import logging as gcp_logging
 from google.cloud.logging import Resource
 from google.cloud.logging.handlers.transports import BackgroundThreadTransport, Transport
@@ -35,6 +43,7 @@ from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.google.cloud.utils.credentials_provider import get_credentials_and_project_id
 from airflow.providers.google.common.consts import CLIENT_INFO
 from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.utils.log.logging_mixin import LoggingMixin
 
 try:
     from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
@@ -45,9 +54,12 @@ if not AIRFLOW_V_3_0_PLUS:
     from airflow.utils.log.trigger_handler import ctx_indiv_trigger
 
 if TYPE_CHECKING:
+    import structlog.typing
     from google.auth.credentials import Credentials
 
     from airflow.models import TaskInstance
+    from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
+    from airflow.utils.log.file_task_handler import LogResponse
 
 DEFAULT_LOGGER_NAME = "airflow"
 _GLOBAL_RESOURCE = Resource(type="global", labels={})
@@ -56,6 +68,209 @@ _DEFAULT_SCOPESS = frozenset(
     ["https://www.googleapis.com/auth/logging.read", "https://www.googleapis.com/auth/logging.write"]
 )
 
+LABEL_TASK_ID = "task_id"
+LABEL_DAG_ID = "dag_id"
+LABEL_LOGICAL_DATE = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date"
+LABEL_TRY_NUMBER = "try_number"
+
+
+@attrs.define(kw_only=True)
+class StackdriverRemoteLogIO(LoggingMixin):
+    """Remote log IO that streams logs to and reads from Google Cloud 
Stackdriver Logging."""
+
+    base_log_folder: Path = attrs.field(converter=Path)
+    delete_local_copy: bool = True
+
+    gcp_key_path: str | None = None
+    scopes: Collection[str] | None = _DEFAULT_SCOPESS
+    gcp_log_name: str = DEFAULT_LOGGER_NAME
+    transport_type: type[Transport] = BackgroundThreadTransport
+    resource: Resource = _GLOBAL_RESOURCE
+    labels: dict[str, str] | None = None
+
+    @cached_property
+    def credentials_and_project(self) -> tuple[Credentials, str]:
+        credentials, project = get_credentials_and_project_id(
+            key_path=self.gcp_key_path, scopes=self.scopes, 
disable_logging=True
+        )
+        return credentials, project
+
+    @cached_property
+    def _client(self) -> gcp_logging.Client:
+        """The Cloud Library API client."""
+        credentials, project = self.credentials_and_project
+        return gcp_logging.Client(
+            credentials=credentials,
+            project=project,
+            client_info=CLIENT_INFO,
+        )
+
+    @cached_property
+    def _logging_service_client(self) -> LoggingServiceV2Client:
+        """The Cloud logging service v2 client."""
+        credentials, _ = self.credentials_and_project
+        return LoggingServiceV2Client(
+            credentials=credentials,
+            client_info=CLIENT_INFO,
+        )
+
+    @cached_property
+    def transport(self) -> Transport:
+        """Object responsible for sending data to Stackdriver."""
+        return self.transport_type(self._client, self.gcp_log_name)
+
+    @cached_property
+    def processors(self) -> tuple[structlog.typing.Processor, ...]:
+        import structlog.stdlib
+
+        from airflow.sdk.log import relative_path_from_logger
+
+        log_record_factory = getLogRecordFactory()
+        _transport = self.transport
+
+        def proc(
+            logger: structlog.typing.WrappedLogger,
+            method_name: str,
+            event: structlog.typing.EventDict,
+        ):
+            if not logger or not relative_path_from_logger(logger):
+                return event
+
+            name = event.get("logger_name") or event.get("logger", "")
+            level = structlog.stdlib.NAME_TO_LEVEL.get(method_name.lower(), 
logging.INFO)
+            msg = copy.copy(event)
+            created = None
+            if ts := msg.pop("timestamp", None):
+                with contextlib.suppress(Exception):
+                    created = datetime.fromisoformat(ts)
+            record = log_record_factory(
+                name,
+                level,
+                pathname="",
+                lineno=0,
+                msg=msg,
+                args=(),
+                exc_info=None,
+                func=None,
+                sinfo=None,
+            )
+            if created is not None:
+                ct = created.timestamp()
+                record.created = ct
+                record.msecs = int((ct - int(ct)) * 1000) + 0.0
+
+            ti = getattr(record, "task_instance", None)
+            labels: dict[str, str] = {}
+            if self.labels:
+                labels.update(self.labels)
+            if ti:
+                labels.update(_task_instance_to_labels(ti))
+            _transport.send(record, str(msg.get("event", "")), 
resource=self.resource, labels=labels)
+            return event
+
+        return (proc,)
+
+    def upload(self, path: os.PathLike | str, ti: RuntimeTI) -> None:
+        """Flush the transport and optionally delete local log files."""
+        self.transport.flush()
+        if self.delete_local_copy:
+            base = self.base_log_folder.resolve()
+            raw = Path(path)
+            local_path = (raw if raw.is_absolute() else base / raw).resolve()
+            try:
+                local_path.relative_to(base)
+            except ValueError:
+                self.log.warning(
+                    "Skipping deletion: path %s is outside base_log_folder %s",
+                    local_path,
+                    base,
+                )
+                return
+            parent = local_path.parent
+            if parent.exists():
+                shutil.rmtree(parent, ignore_errors=True)
+
+    def read(self, relative_path: str, ti: RuntimeTI) -> LogResponse:
+        """Read logs from Stackdriver Logging using task instance labels."""
+        ti_labels = _task_instance_to_labels(ti)
+        log_filter = self.prepare_log_filter(ti_labels)
+        messages, end_of_log, _ = self.read_logs(log_filter, 
next_page_token=None, all_pages=True)
+        return [f"Reading remote log from Stackdriver for {relative_path}"], 
[messages] if messages else []
+
+    def prepare_log_filter(self, ti_labels: dict[str, str]) -> str:
+        def escape_label_key(key: str) -> str:
+            return f'"{key}"' if "." in key else key
+
+        def escape_label_value(value: str) -> str:
+            escaped_value = value.replace("\\", "\\\\").replace('"', '\\"')
+            return f'"{escaped_value}"'
+
+        _, project = self.credentials_and_project
+        log_filters = [
+            f"resource.type={escape_label_value(self.resource.type)}",
+            f'logName="projects/{project}/logs/{self.gcp_log_name}"',
+        ]
+
+        for key, value in self.resource.labels.items():
+            
log_filters.append(f"resource.labels.{escape_label_key(key)}={escape_label_value(value)}")
+
+        for key, value in ti_labels.items():
+            
log_filters.append(f"labels.{escape_label_key(key)}={escape_label_value(value)}")
+        return "\n".join(log_filters)
+
+    def read_logs(
+        self, log_filter: str, next_page_token: str | None, all_pages: bool
+    ) -> tuple[str, bool, str | None]:
+        messages = []
+        new_messages, next_page_token = self._read_single_logs_page(
+            log_filter=log_filter,
+            page_token=next_page_token,
+        )
+        messages.append(new_messages)
+        if all_pages:
+            while next_page_token:
+                new_messages, next_page_token = self._read_single_logs_page(
+                    log_filter=log_filter, page_token=next_page_token
+                )
+                messages.append(new_messages)
+
+            end_of_log = True
+            next_page_token = None
+        else:
+            end_of_log = not bool(next_page_token)
+        return "\n".join(messages), end_of_log, next_page_token
+
+    def _read_single_logs_page(self, log_filter: str, page_token: str | None = 
None) -> tuple[str, str]:
+        _, project = self.credentials_and_project
+        request = ListLogEntriesRequest(
+            resource_names=[f"projects/{project}"],
+            filter=log_filter,
+            page_token=page_token,
+            order_by="timestamp asc",
+            page_size=1000,
+        )
+        response = 
self._logging_service_client.list_log_entries(request=request)
+        page: ListLogEntriesResponse = next(response.pages)
+        messages: list[str] = []
+        for entry in page.entries:
+            if "message" in (entry.json_payload or {}):
+                messages.append(entry.json_payload["message"])  # type: ignore
+            elif entry.text_payload:
+                messages.append(entry.text_payload)
+        return "\n".join(messages), page.next_page_token
+
+
+def _task_instance_to_labels(ti) -> dict[str, str]:
+    """Convert a task instance to Stackdriver labels."""
+    return {
+        LABEL_TASK_ID: ti.task_id,
+        LABEL_DAG_ID: ti.dag_id,
+        LABEL_LOGICAL_DATE: str(ti.logical_date.isoformat())
+        if AIRFLOW_V_3_0_PLUS
+        else str(ti.execution_date.isoformat()),
+        LABEL_TRY_NUMBER: str(ti.try_number),
+    }
+
 
 class StackdriverTaskHandler(logging.Handler):
     """
@@ -88,10 +303,11 @@ class StackdriverTaskHandler(logging.Handler):
     :param labels: (Optional) Mapping of labels for the entry.
     """
 
-    LABEL_TASK_ID = "task_id"
-    LABEL_DAG_ID = "dag_id"
-    LABEL_LOGICAL_DATE = "logical_date" if AIRFLOW_V_3_0_PLUS else 
"execution_date"
-    LABEL_TRY_NUMBER = "try_number"
+    # Re-export module-level constants for back-compat with external code 
reading them off the class
+    LABEL_TASK_ID = LABEL_TASK_ID
+    LABEL_DAG_ID = LABEL_DAG_ID
+    LABEL_LOGICAL_DATE = LABEL_LOGICAL_DATE
+    LABEL_TRY_NUMBER = LABEL_TRY_NUMBER
     LOG_VIEWER_BASE_URL = "https://console.cloud.google.com/logs/viewer"
     LOG_NAME = "Google Stackdriver"
 
@@ -120,53 +336,23 @@ class StackdriverTaskHandler(logging.Handler):
             gcp_log_name = str(name)
 
         super().__init__()
-        self.gcp_key_path: str | None = gcp_key_path
-        self.scopes: Collection[str] | None = scopes
-        self.gcp_log_name: str = gcp_log_name
-        self.transport_type: type[Transport] = transport
-        self.resource: Resource = resource
+        self.io = StackdriverRemoteLogIO(
+            base_log_folder=Path("."),
+            gcp_key_path=gcp_key_path,
+            scopes=scopes,
+            gcp_log_name=gcp_log_name,
+            transport_type=transport,
+            resource=resource,
+            labels=labels,
+        )
         self.labels: dict[str, str] | None = labels
+        self.resource: Resource = resource
         self.task_instance_labels: dict[str, str] | None = {}
         self.task_instance_hostname = "default-hostname"
 
-    @cached_property
-    def _credentials_and_project(self) -> tuple[Credentials, str]:
-        credentials, project = get_credentials_and_project_id(
-            key_path=self.gcp_key_path, scopes=self.scopes, 
disable_logging=True
-        )
-        return credentials, project
-
-    @property
-    def _client(self) -> gcp_logging.Client:
-        """The Cloud Library API client."""
-        credentials, project = self._credentials_and_project
-        client = gcp_logging.Client(
-            credentials=credentials,
-            project=project,
-            client_info=CLIENT_INFO,
-        )
-        return client
-
-    @property
-    def _logging_service_client(self) -> LoggingServiceV2Client:
-        """The Cloud logging service v2 client."""
-        credentials, _ = self._credentials_and_project
-        client = LoggingServiceV2Client(
-            credentials=credentials,
-            client_info=CLIENT_INFO,
-        )
-        return client
-
-    @cached_property
-    def _transport(self) -> Transport:
-        """Object responsible for sending data to Stackdriver."""
-        # The Transport object is badly defined (no init) but in the docs 
client/name as constructor
-        # arguments are a requirement for any class that derives from 
Transport class, hence ignore:
-        return self.transport_type(self._client, self.gcp_log_name)
-
     def _get_labels(self, task_instance=None):
         if task_instance:
-            ti_labels = self._task_instance_to_labels(task_instance)
+            ti_labels = _task_instance_to_labels(task_instance)
         else:
             ti_labels = self.task_instance_labels
         labels: dict[str, str] | None
@@ -193,7 +379,7 @@ class StackdriverTaskHandler(logging.Handler):
         if not AIRFLOW_V_3_0_PLUS and getattr(record, ctx_indiv_trigger.name, 
None):
             ti = getattr(record, "task_instance", None)  # trigger context
         labels = self._get_labels(ti)
-        self._transport.send(record, message, resource=self.resource, 
labels=labels)
+        self.io.transport.send(record, message, resource=self.resource, 
labels=labels)
 
     def set_context(self, task_instance: TaskInstance) -> None:
         """
@@ -201,7 +387,7 @@ class StackdriverTaskHandler(logging.Handler):
 
         :param task_instance: Currently executed task
         """
-        self.task_instance_labels = 
self._task_instance_to_labels(task_instance)
+        self.task_instance_labels = _task_instance_to_labels(task_instance)
         self.task_instance_hostname = task_instance.hostname or 
"default-hostname"
 
     def read(
@@ -225,18 +411,18 @@ class StackdriverTaskHandler(logging.Handler):
         if not metadata:
             metadata = {}
 
-        ti_labels = self._task_instance_to_labels(task_instance)
+        ti_labels = _task_instance_to_labels(task_instance)
 
         if try_number is not None:
-            ti_labels[self.LABEL_TRY_NUMBER] = str(try_number)
+            ti_labels[LABEL_TRY_NUMBER] = str(try_number)
         else:
-            del ti_labels[self.LABEL_TRY_NUMBER]
+            del ti_labels[LABEL_TRY_NUMBER]
 
-        log_filter = self._prepare_log_filter(ti_labels)
+        log_filter = self.io.prepare_log_filter(ti_labels)
         next_page_token = metadata.get("next_page_token", None)
         all_pages = "download_logs" in metadata and metadata["download_logs"]
 
-        messages, end_of_log, next_page_token = self._read_logs(log_filter, 
next_page_token, all_pages)
+        messages, end_of_log, next_page_token = self.io.read_logs(log_filter, 
next_page_token, all_pages)
 
         new_metadata: dict[str, str | bool] = {"end_of_log": end_of_log}
 
@@ -245,102 +431,6 @@ class StackdriverTaskHandler(logging.Handler):
 
         return [((self.task_instance_hostname, messages),)], [new_metadata]
 
-    def _prepare_log_filter(self, ti_labels: dict[str, str]) -> str:
-        """
-        Prepare the filter that chooses which log entries to fetch.
-
-        More information:
-        
https://cloud.google.com/logging/docs/reference/v2/rest/v2/entries/list#body.request_body.FIELDS.filter
-        https://cloud.google.com/logging/docs/view/advanced-queries
-
-        :param ti_labels: Task Instance's labels that will be used to search 
for logs
-        :return: logs filter
-        """
-
-        def escape_label_key(key: str) -> str:
-            return f'"{key}"' if "." in key else key
-
-        def escale_label_value(value: str) -> str:
-            escaped_value = value.replace("\\", "\\\\").replace('"', '\\"')
-            return f'"{escaped_value}"'
-
-        _, project = self._credentials_and_project
-        log_filters = [
-            f"resource.type={escale_label_value(self.resource.type)}",
-            f'logName="projects/{project}/logs/{self.gcp_log_name}"',
-        ]
-
-        for key, value in self.resource.labels.items():
-            
log_filters.append(f"resource.labels.{escape_label_key(key)}={escale_label_value(value)}")
-
-        for key, value in ti_labels.items():
-            
log_filters.append(f"labels.{escape_label_key(key)}={escale_label_value(value)}")
-        return "\n".join(log_filters)
-
-    def _read_logs(
-        self, log_filter: str, next_page_token: str | None, all_pages: bool
-    ) -> tuple[str, bool, str | None]:
-        """
-        Send requests to the Stackdriver service and downloads logs.
-
-        :param log_filter: Filter specifying the logs to be downloaded.
-        :param next_page_token: The token of the page from which the log 
download will start.
-            If None is passed, it will start from the first page.
-        :param all_pages: If True is passed, all subpages will be downloaded. 
Otherwise, only the first
-            page will be downloaded
-        :return: A token that contains the following items:
-            * string with logs
-            * Boolean value describing whether there are more logs,
-            * token of the next page
-        """
-        messages = []
-        new_messages, next_page_token = self._read_single_logs_page(
-            log_filter=log_filter,
-            page_token=next_page_token,
-        )
-        messages.append(new_messages)
-        if all_pages:
-            while next_page_token:
-                new_messages, next_page_token = self._read_single_logs_page(
-                    log_filter=log_filter, page_token=next_page_token
-                )
-                messages.append(new_messages)
-                if not messages:
-                    break
-
-            end_of_log = True
-            next_page_token = None
-        else:
-            end_of_log = not bool(next_page_token)
-        return "\n".join(messages), end_of_log, next_page_token
-
-    def _read_single_logs_page(self, log_filter: str, page_token: str | None = 
None) -> tuple[str, str]:
-        """
-        Send requests to the Stackdriver service and downloads single pages 
with logs.
-
-        :param log_filter: Filter specifying the logs to be downloaded.
-        :param page_token: The token of the page to be downloaded. If None is 
passed, the first page will be
-            downloaded.
-        :return: Downloaded logs and next page token
-        """
-        _, project = self._credentials_and_project
-        request = ListLogEntriesRequest(
-            resource_names=[f"projects/{project}"],
-            filter=log_filter,
-            page_token=page_token,
-            order_by="timestamp asc",
-            page_size=1000,
-        )
-        response = 
self._logging_service_client.list_log_entries(request=request)
-        page: ListLogEntriesResponse = next(response.pages)
-        messages: list[str] = []
-        for entry in page.entries:
-            if "message" in (entry.json_payload or {}):
-                messages.append(entry.json_payload["message"])  # type: ignore
-            elif entry.text_payload:
-                messages.append(entry.text_payload)
-        return "\n".join(messages), page.next_page_token
-
     @classmethod
     def _task_instance_to_labels(cls, ti: TaskInstance) -> dict[str, str]:
         return {
@@ -375,12 +465,12 @@ class StackdriverTaskHandler(logging.Handler):
         :param try_number: task instance try_number to read logs from
         :return: URL to the external log collection service
         """
-        _, project_id = self._credentials_and_project
+        _, project_id = self.io.credentials_and_project
 
-        ti_labels = self._task_instance_to_labels(task_instance)
-        ti_labels[self.LABEL_TRY_NUMBER] = str(try_number)
+        ti_labels = _task_instance_to_labels(task_instance)
+        ti_labels[LABEL_TRY_NUMBER] = str(try_number)
 
-        log_filter = self._prepare_log_filter(ti_labels)
+        log_filter = self.io.prepare_log_filter(ti_labels)
 
         url_query_string = {
             "project": project_id,
@@ -393,4 +483,4 @@ class StackdriverTaskHandler(logging.Handler):
         return url
 
     def close(self) -> None:
-        self._transport.flush()
+        self.io.transport.flush()
diff --git a/providers/google/tests/unit/google/cloud/log/test_stackdriver_task_handler.py b/providers/google/tests/unit/google/cloud/log/test_stackdriver_task_handler.py
index 66c2ebe29fc..a55b62fa4f4 100644
--- a/providers/google/tests/unit/google/cloud/log/test_stackdriver_task_handler.py
+++ b/providers/google/tests/unit/google/cloud/log/test_stackdriver_task_handler.py
@@ -18,14 +18,19 @@ from __future__ import annotations
 
 import logging
 from contextlib import nullcontext
+from pathlib import Path
 from unittest import mock
+from unittest.mock import PropertyMock
 from urllib.parse import parse_qs, urlsplit
 
 import pytest
 from google.cloud.logging import Resource
 from google.cloud.logging_v2.types import ListLogEntriesRequest, 
ListLogEntriesResponse, LogEntry
 
-from airflow.providers.google.cloud.log.stackdriver_task_handler import 
StackdriverTaskHandler
+from airflow.providers.google.cloud.log.stackdriver_task_handler import (
+    StackdriverRemoteLogIO,
+    StackdriverTaskHandler,
+)
 from airflow.utils import timezone
 from airflow.utils.state import TaskInstanceState
 
@@ -50,6 +55,232 @@ def clean_stackdriver_handlers():
             del handler
 
 
+class TestStackdriverRemoteLogIO:
+    @pytest.fixture(autouse=True)
+    def _setup(self, tmp_path):
+        self.local_log_location = str(tmp_path / "local/stackdriver/logs")
+        self.io = StackdriverRemoteLogIO(
+            base_log_folder=self.local_log_location,
+            gcp_key_path="KEY_PATH",
+            gcp_log_name="airflow",
+            delete_local_copy=True,
+        )
+
+    
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id")
+    
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.LoggingServiceV2Client")
+    def test_read_logs(self, mock_client, mock_get_creds_and_project_id):
+        mock_client.return_value.list_log_entries.return_value.pages = iter(
+            [_create_list_log_entries_response_mock(["MSG1", "MSG2"], None)]
+        )
+        mock_get_creds_and_project_id.return_value = ("creds", "project_id")
+
+        ti = mock.MagicMock()
+        ti.task_id = "test_task"
+        ti.dag_id = "test_dag"
+        ti.try_number = 1
+        if AIRFLOW_V_3_0_PLUS:
+            ti.logical_date = timezone.datetime(2016, 1, 1)
+        else:
+            ti.execution_date = timezone.datetime(2016, 1, 1)
+
+        messages, logs = 
self.io.read("dag_id=test_dag/run_id=run1/task_id=test_task/attempt=1.log", ti)
+
+        assert len(messages) == 1
+        assert "Stackdriver" in messages[0]
+        assert logs == ["MSG1\nMSG2"]
+
+    
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id")
+    
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.LoggingServiceV2Client")
+    def test_read_logs_empty(self, mock_client, mock_get_creds_and_project_id):
+        mock_client.return_value.list_log_entries.return_value.pages = iter(
+            [_create_list_log_entries_response_mock([], None)]
+        )
+        mock_get_creds_and_project_id.return_value = ("creds", "project_id")
+
+        ti = mock.MagicMock()
+        ti.task_id = "test_task"
+        ti.dag_id = "test_dag"
+        ti.try_number = 1
+        if AIRFLOW_V_3_0_PLUS:
+            ti.logical_date = timezone.datetime(2016, 1, 1)
+        else:
+            ti.execution_date = timezone.datetime(2016, 1, 1)
+
+        messages, logs = self.io.read("test/path", ti)
+
+        assert len(messages) == 1
+        assert logs == []
+
+    
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id")
+    
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client")
+    def test_credentials(self, mock_client, mock_get_creds_and_project_id):
+        mock_get_creds_and_project_id.return_value = ("creds", "project_id")
+
+        _ = self.io._client
+
+        mock_get_creds_and_project_id.assert_called_once_with(
+            disable_logging=True,
+            key_path="KEY_PATH",
+            scopes=frozenset(
+                {
+                    "https://www.googleapis.com/auth/logging.write",
+                    "https://www.googleapis.com/auth/logging.read",
+                }
+            ),
+        )
+        mock_client.assert_called_once_with(credentials="creds", 
client_info=mock.ANY, project="project_id")
+
+    
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id")
+    
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client")
+    def test_transport_init(self, mock_client, mock_get_creds_and_project_id):
+        mock_get_creds_and_project_id.return_value = ("creds", "project_id")
+
+        transport_type = mock.MagicMock()
+        io = StackdriverRemoteLogIO(
+            base_log_folder=self.local_log_location,
+            gcp_log_name="test-log",
+            transport_type=transport_type,
+        )
+        _ = io.transport
+        transport_type.assert_called_once_with(mock_client.return_value, 
"test-log")
+
+    @mock.patch("shutil.rmtree")
+    @mock.patch(
+        
"airflow.providers.google.cloud.log.stackdriver_task_handler.StackdriverRemoteLogIO.transport",
+        new_callable=PropertyMock,
+    )
+    def test_upload_flushes_transport_and_deletes_local(self, 
mock_transport_prop, mock_rmtree):
+        io = StackdriverRemoteLogIO(
+            base_log_folder=self.local_log_location,
+            gcp_log_name="airflow",
+            delete_local_copy=True,
+        )
+        mock_transport = mock.MagicMock()
+        mock_transport_prop.return_value = mock_transport
+
+        base = Path(self.local_log_location)
+        base.mkdir(parents=True, exist_ok=True)
+        log_dir = base / "subdir"
+        log_dir.mkdir(parents=True, exist_ok=True)
+        log_file = log_dir / "test.log"
+        log_file.write_text("log content")
+
+        ti = mock.MagicMock()
+        io.upload(str(log_file), ti)
+
+        mock_transport.flush.assert_called_once()
+        mock_rmtree.assert_called_once_with(log_dir.resolve(), 
ignore_errors=True)
+
+    @mock.patch(
+        
"airflow.providers.google.cloud.log.stackdriver_task_handler.StackdriverRemoteLogIO.transport",
+        new_callable=PropertyMock,
+    )
+    def test_upload_no_delete(self, mock_transport_prop):
+        io = StackdriverRemoteLogIO(
+            base_log_folder=self.local_log_location,
+            gcp_log_name="airflow",
+            delete_local_copy=False,
+        )
+        mock_transport = mock.MagicMock()
+        mock_transport_prop.return_value = mock_transport
+
+        ti = mock.MagicMock()
+        io.upload("some/path.log", ti)
+
+        mock_transport.flush.assert_called_once()
+
+    
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id")
+    def test_prepare_log_filter(self, mock_get_creds_and_project_id):
+        mock_get_creds_and_project_id.return_value = ("creds", "project_id")
+
+        ti_labels = {
+            "task_id": "test_task",
+            "dag_id": "test_dag",
+            "try_number": "1",
+        }
+        log_filter = self.io.prepare_log_filter(ti_labels)
+
+        assert 'resource.type="global"' in log_filter
+        assert 'logName="projects/project_id/logs/airflow"' in log_filter
+        assert 'labels.task_id="test_task"' in log_filter
+        assert 'labels.dag_id="test_dag"' in log_filter
+
+    
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id")
+    def test_prepare_log_filter_with_custom_resource(self, 
mock_get_creds_and_project_id):
+        mock_get_creds_and_project_id.return_value = ("creds", "project_id")
+
+        io = StackdriverRemoteLogIO(
+            base_log_folder=self.local_log_location,
+            gcp_log_name="airflow",
+            resource=Resource(
+                type="cloud_composer_environment",
+                labels={
+                    "environment.name": "test-instance",
+                    "location": "europe-west-3",
+                },
+            ),
+        )
+        log_filter = io.prepare_log_filter({"task_id": "test"})
+
+        assert 'resource.type="cloud_composer_environment"' in log_filter
+        assert 'resource.labels."environment.name"="test-instance"' in 
log_filter
+        assert 'resource.labels.location="europe-west-3"' in log_filter
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="airflow.sdk.log only 
exists in Airflow 3+")
+    
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id")
+    
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client")
+    def test_processors_sends_to_transport(self, mock_client, 
mock_get_creds_and_project_id):
+        mock_get_creds_and_project_id.return_value = ("creds", "project_id")
+
+        mock_transport_type = mock.MagicMock()
+        with mock.patch("airflow.sdk.log.relative_path_from_logger", 
return_value="dag/task/1.log"):
+            io = StackdriverRemoteLogIO(
+                base_log_folder=self.local_log_location,
+                gcp_log_name="airflow",
+                labels={"env": "test"},
+                transport_type=mock_transport_type,
+            )
+            processors = io.processors
+            assert len(processors) == 1
+
+            proc = processors[0]
+            mock_logger = mock.MagicMock()
+
+            event = {
+                "event": "hello world",
+                "logger_name": "airflow.task",
+                "timestamp": "2026-01-15T10:30:00+00:00",
+            }
+            result = proc(mock_logger, "info", event)
+
+        assert result is event
+        mock_transport = mock_transport_type.return_value
+        mock_transport.send.assert_called_once()
+        record = mock_transport.send.call_args[0][0]
+        assert record.levelno == logging.INFO
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="airflow.sdk.log only 
exists in Airflow 3+")
+    
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id")
+    
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client")
+    def test_processors_skips_non_task_logger(self, mock_client, 
mock_get_creds_and_project_id):
+        mock_get_creds_and_project_id.return_value = ("creds", "project_id")
+
+        mock_transport_type = mock.MagicMock()
+        with mock.patch("airflow.sdk.log.relative_path_from_logger", 
return_value=None):
+            io = StackdriverRemoteLogIO(
+                base_log_folder=self.local_log_location,
+                gcp_log_name="airflow",
+                transport_type=mock_transport_type,
+            )
+            proc = io.processors[0]
+
+            event = {"event": "should not be sent"}
+            result = proc(mock.MagicMock(), "info", event)
+
+        assert result is event
+        mock_transport_type.return_value.send.assert_not_called()
+
+
 @pytest.mark.usefixtures("clean_stackdriver_handlers")
 
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id")
 
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client")
@@ -77,7 +308,6 @@ def test_should_pass_message_to_client(mock_client, 
mock_get_creds_and_project_i
 
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client")
 def test_should_use_configured_log_name(mock_client, 
mock_get_creds_and_project_id):
     import importlib
-    import logging
 
     from airflow import settings
     from airflow.config_templates import airflow_local_settings
@@ -85,9 +315,6 @@ def test_should_use_configured_log_name(mock_client, 
mock_get_creds_and_project_
     mock_get_creds_and_project_id.return_value = ("creds", "project_id")
 
     try:
-        # this is needed for Airflow 2.8 and below where default settings are 
triggering warning on
-        # extra "name" in the configuration of stackdriver handler. As of 
Airflow 2.9 this warning is not
-        # emitted.
         context_manager = nullcontext()
         with context_manager:
             with conf_vars(
@@ -99,12 +326,17 @@ def test_should_use_configured_log_name(mock_client, 
mock_get_creds_and_project_
                 importlib.reload(airflow_local_settings)
                 settings.configure_logging()
 
+                task_log = getattr(airflow_local_settings, "REMOTE_TASK_LOG", 
None)
+                if task_log is not None:
+                    # Airflow 3+ uses REMOTE_TASK_LOG instead of handler-based 
config
+                    assert isinstance(task_log, StackdriverRemoteLogIO)
+                    assert task_log.gcp_log_name == "path"
+                    return
+
+                # Older Airflow: stackdriver is wired as a logging handler
                 logger = logging.getLogger("airflow.task")
                 handler = logger.handlers[0]
                 assert isinstance(handler, StackdriverTaskHandler)
-                with mock.patch.object(handler, "transport_type") as 
transport_type_mock:
-                    logger.error("foo")
-                    
transport_type_mock.assert_called_once_with(mock_client.return_value, "path")
     finally:
         importlib.reload(airflow_local_settings)
         settings.configure_logging()
@@ -398,7 +630,7 @@ class TestStackdriverLoggingHandlerTask:
         mock_get_creds_and_project_id.return_value = ("creds", "project_id")
 
         stackdriver_task_handler = 
StackdriverTaskHandler(gcp_key_path="KEY_PATH")
-        client = stackdriver_task_handler._client
+        client = stackdriver_task_handler.io._client
 
         mock_get_creds_and_project_id.assert_called_once_with(
             disable_logging=True,


Reply via email to