villebro commented on code in PR #36368: URL: https://github.com/apache/superset/pull/36368#discussion_r2775052889
########## superset/tasks/context.py: ########## @@ -0,0 +1,673 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Concrete TaskContext implementation for GTF""" + +import logging +import threading +import time +import traceback +from typing import Any, Callable, cast, TYPE_CHECKING, TypeVar + +from flask import current_app +from superset_core.api.tasks import ( + TaskContext as CoreTaskContext, + TaskProperties, + TaskStatus, +) + +from superset.stats_logger import BaseStatsLogger +from superset.tasks.constants import ABORT_STATES +from superset.tasks.utils import progress_update + +if TYPE_CHECKING: + from superset.models.tasks import Task + from superset.tasks.manager import AbortListener + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +class TaskContext(CoreTaskContext): + """ + Concrete implementation of TaskContext for the Global Async Task Framework. + + Provides write-only access to task state. Tasks use this context to update + their progress and payload, and check for cancellation. Tasks should not + need to read their own state - they are the source of state, not consumers. 
+ """ + + # Type alias for handler failures: (handler_type, exception, stack_trace) + HandlerFailure = tuple[str, Exception, str] + + def __init__(self, task: "Task") -> None: + """ + Initialize TaskContext with a pre-fetched task entity. + + The task entity must be pre-fetched by the caller (executor) to ensure + caching works correctly and to enforce the pattern of single initial fetch. + + :param task: Pre-fetched Task entity (required) + """ + self._task_uuid = task.uuid + self._cleanup_handlers: list[Callable[[], None]] = [] + self._abort_handlers: list[Callable[[], None]] = [] + self._abort_listener: "AbortListener | None" = None + self._abort_detected = False + self._abort_handlers_completed = False # Track if all abort handlers finished + self._execution_completed = False # Set by executor after task work completes + + # Collected handler failures for unified reporting + self._handler_failures: list[TaskContext.HandlerFailure] = [] + + # Timeout timer state + self._timeout_timer: threading.Timer | None = None + self._timeout_triggered = False + + # Throttling state for update_task() + # These manage the minimum interval between DB writes + self._last_db_write_time: float | None = None + self._has_pending_updates: bool = False + self._deferred_flush_timer: threading.Timer | None = None + self._throttle_lock = threading.Lock() + + # Cached task entity - avoids repeated DB fetches. + # Updated only by _refresh_task() when checking external state changes. + self._task: "Task" = task + + # In-memory state caches - authoritative during execution + # These are initialized from the task entity and updated locally + # before being written to DB via targeted SQL updates. + # We copy the dicts to avoid mutating the Task's cached instances. 
+ self._properties_cache: TaskProperties = cast( + TaskProperties, {**task.properties_dict} + ) + self._payload_cache: dict[str, Any] = {**task.payload_dict} + + # Store Flask app reference for background thread database access + # Use _get_current_object() to get actual app, not proxy + try: + self._app = current_app._get_current_object() + # Cache stats logger to avoid repeated config lookups + self._stats_logger: BaseStatsLogger = current_app.config.get( + "STATS_LOGGER", BaseStatsLogger() + ) + except RuntimeError: + # Handle case where app context isn't available (e.g., tests) + self._app = None + self._stats_logger = BaseStatsLogger() + + def _refresh_task(self) -> "Task": + """ + Force refresh the task entity from the database. + + Use this method when you need to check for external state changes, + such as whether the task has been aborted by a concurrent operation. + + This method: + - Fetches fresh task entity from database + - Updates the cached _task reference + - Updates properties/payload caches from fresh data + + :returns: Fresh task entity from database + :raises ValueError: If task is not found + """ + from superset.daos.tasks import TaskDAO + + fresh_task = TaskDAO.find_one_or_none(uuid=self._task_uuid) + if not fresh_task: + raise ValueError(f"Task {self._task_uuid} not found") + + self._task = fresh_task + + # Update caches from fresh data (copy to avoid mutating Task's cache) + self._properties_cache = cast(TaskProperties, {**fresh_task.properties_dict}) + self._payload_cache = {**fresh_task.payload_dict} + + return self._task + + def update_task( + self, + progress: float | int | tuple[int, int] | None = None, + payload: dict[str, object] | None = None, + ) -> None: + """ + Update task progress and/or payload atomically. + + All parameters are optional. Payload is merged with existing cached data. 
+ In-memory caches are always updated immediately, but DB writes are + throttled according to TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL to prevent + excessive database load from eager tasks. + + Progress can be specified in three ways: + - float (0.0-1.0): Percentage only, e.g., 0.5 means 50% + - int: Count only (total unknown), e.g., 42 means "42 items processed" + - tuple[int, int]: Count and total, e.g., (3, 100) means "3 of 100" + The percentage is automatically computed from count/total. + + :param progress: Progress value, or None to leave unchanged + :param payload: Payload data to merge (dict), or None to leave unchanged + """ + has_updates = False + + # Handle progress updates - always update in-memory cache + if progress is not None: + progress_props = progress_update(progress) + if progress_props: + # Merge progress into cached properties + self._properties_cache.update(progress_props) + has_updates = True + else: + # Invalid progress format - progress_update returns empty dict + logger.warning( + "Invalid progress value for task %s: %s " + "(expected float, int, or tuple[int, int])", + self._task_uuid, + progress, + ) + + # Handle payload updates - always update in-memory cache + if payload is not None: + # Merge payload into cached payload + self._payload_cache.update(payload) + has_updates = True + + if not has_updates: + return + + # Get throttle interval from config + throttle_interval = current_app.config["TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL"] Review Comment: We have a convention to assume all `config` keys are defined. So `[]` is the convention, not `.get(<key>, <default>)`. ########## superset/tasks/context.py: ########## @@ -0,0 +1,673 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Concrete TaskContext implementation for GTF""" + +import logging +import threading +import time +import traceback +from typing import Any, Callable, cast, TYPE_CHECKING, TypeVar + +from flask import current_app +from superset_core.api.tasks import ( + TaskContext as CoreTaskContext, + TaskProperties, + TaskStatus, +) + +from superset.stats_logger import BaseStatsLogger +from superset.tasks.constants import ABORT_STATES +from superset.tasks.utils import progress_update + +if TYPE_CHECKING: + from superset.models.tasks import Task + from superset.tasks.manager import AbortListener + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +class TaskContext(CoreTaskContext): + """ + Concrete implementation of TaskContext for the Global Async Task Framework. + + Provides write-only access to task state. Tasks use this context to update + their progress and payload, and check for cancellation. Tasks should not + need to read their own state - they are the source of state, not consumers. + """ + + # Type alias for handler failures: (handler_type, exception, stack_trace) + HandlerFailure = tuple[str, Exception, str] + + def __init__(self, task: "Task") -> None: + """ + Initialize TaskContext with a pre-fetched task entity. + + The task entity must be pre-fetched by the caller (executor) to ensure + caching works correctly and to enforce the pattern of single initial fetch. 
+ + :param task: Pre-fetched Task entity (required) + """ + self._task_uuid = task.uuid + self._cleanup_handlers: list[Callable[[], None]] = [] + self._abort_handlers: list[Callable[[], None]] = [] + self._abort_listener: "AbortListener | None" = None + self._abort_detected = False + self._abort_handlers_completed = False # Track if all abort handlers finished + self._execution_completed = False # Set by executor after task work completes + + # Collected handler failures for unified reporting + self._handler_failures: list[TaskContext.HandlerFailure] = [] + + # Timeout timer state + self._timeout_timer: threading.Timer | None = None + self._timeout_triggered = False + + # Throttling state for update_task() + # These manage the minimum interval between DB writes + self._last_db_write_time: float | None = None + self._has_pending_updates: bool = False + self._deferred_flush_timer: threading.Timer | None = None + self._throttle_lock = threading.Lock() + + # Cached task entity - avoids repeated DB fetches. + # Updated only by _refresh_task() when checking external state changes. + self._task: "Task" = task + + # In-memory state caches - authoritative during execution + # These are initialized from the task entity and updated locally + # before being written to DB via targeted SQL updates. + # We copy the dicts to avoid mutating the Task's cached instances. 
+ self._properties_cache: TaskProperties = cast( + TaskProperties, {**task.properties_dict} + ) + self._payload_cache: dict[str, Any] = {**task.payload_dict} + + # Store Flask app reference for background thread database access + # Use _get_current_object() to get actual app, not proxy + try: + self._app = current_app._get_current_object() + # Cache stats logger to avoid repeated config lookups + self._stats_logger: BaseStatsLogger = current_app.config.get( + "STATS_LOGGER", BaseStatsLogger() + ) + except RuntimeError: + # Handle case where app context isn't available (e.g., tests) + self._app = None + self._stats_logger = BaseStatsLogger() + + def _refresh_task(self) -> "Task": + """ + Force refresh the task entity from the database. + + Use this method when you need to check for external state changes, + such as whether the task has been aborted by a concurrent operation. + + This method: + - Fetches fresh task entity from database + - Updates the cached _task reference + - Updates properties/payload caches from fresh data + + :returns: Fresh task entity from database + :raises ValueError: If task is not found + """ + from superset.daos.tasks import TaskDAO + + fresh_task = TaskDAO.find_one_or_none(uuid=self._task_uuid) + if not fresh_task: + raise ValueError(f"Task {self._task_uuid} not found") + + self._task = fresh_task + + # Update caches from fresh data (copy to avoid mutating Task's cache) + self._properties_cache = cast(TaskProperties, {**fresh_task.properties_dict}) + self._payload_cache = {**fresh_task.payload_dict} + + return self._task + + def update_task( + self, + progress: float | int | tuple[int, int] | None = None, + payload: dict[str, object] | None = None, + ) -> None: + """ + Update task progress and/or payload atomically. + + All parameters are optional. Payload is merged with existing cached data. 
+ In-memory caches are always updated immediately, but DB writes are + throttled according to TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL to prevent + excessive database load from eager tasks. + + Progress can be specified in three ways: + - float (0.0-1.0): Percentage only, e.g., 0.5 means 50% + - int: Count only (total unknown), e.g., 42 means "42 items processed" + - tuple[int, int]: Count and total, e.g., (3, 100) means "3 of 100" + The percentage is automatically computed from count/total. + + :param progress: Progress value, or None to leave unchanged + :param payload: Payload data to merge (dict), or None to leave unchanged + """ + has_updates = False + + # Handle progress updates - always update in-memory cache + if progress is not None: + progress_props = progress_update(progress) + if progress_props: + # Merge progress into cached properties + self._properties_cache.update(progress_props) + has_updates = True + else: + # Invalid progress format - progress_update returns empty dict + logger.warning( + "Invalid progress value for task %s: %s " + "(expected float, int, or tuple[int, int])", + self._task_uuid, + progress, + ) + + # Handle payload updates - always update in-memory cache + if payload is not None: + # Merge payload into cached payload + self._payload_cache.update(payload) + has_updates = True + + if not has_updates: + return + + # Get throttle interval from config + throttle_interval = current_app.config["TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL"] Review Comment: We have a convention to assume all `config` keys are defined. So `[]` is the convention, not `.get(<key>, <default>)`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
