michael-s-molina commented on code in PR #36368: URL: https://github.com/apache/superset/pull/36368#discussion_r2774194513
########## superset/tasks/context.py: ########## @@ -0,0 +1,673 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Concrete TaskContext implementation for GTF""" + +import logging +import threading +import time +import traceback +from typing import Any, Callable, cast, TYPE_CHECKING, TypeVar + +from flask import current_app +from superset_core.api.tasks import ( + TaskContext as CoreTaskContext, + TaskProperties, + TaskStatus, +) + +from superset.stats_logger import BaseStatsLogger +from superset.tasks.constants import ABORT_STATES +from superset.tasks.utils import progress_update + +if TYPE_CHECKING: + from superset.models.tasks import Task + from superset.tasks.manager import AbortListener + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +class TaskContext(CoreTaskContext): + """ + Concrete implementation of TaskContext for the Global Async Task Framework. + + Provides write-only access to task state. Tasks use this context to update + their progress and payload, and check for cancellation. Tasks should not + need to read their own state - they are the source of state, not consumers. 
+ """ + + # Type alias for handler failures: (handler_type, exception, stack_trace) + HandlerFailure = tuple[str, Exception, str] + + def __init__(self, task: "Task") -> None: + """ + Initialize TaskContext with a pre-fetched task entity. + + The task entity must be pre-fetched by the caller (executor) to ensure + caching works correctly and to enforce the pattern of single initial fetch. + + :param task: Pre-fetched Task entity (required) + """ + self._task_uuid = task.uuid + self._cleanup_handlers: list[Callable[[], None]] = [] + self._abort_handlers: list[Callable[[], None]] = [] + self._abort_listener: "AbortListener | None" = None + self._abort_detected = False + self._abort_handlers_completed = False # Track if all abort handlers finished + self._execution_completed = False # Set by executor after task work completes + + # Collected handler failures for unified reporting + self._handler_failures: list[TaskContext.HandlerFailure] = [] + + # Timeout timer state + self._timeout_timer: threading.Timer | None = None + self._timeout_triggered = False + + # Throttling state for update_task() + # These manage the minimum interval between DB writes + self._last_db_write_time: float | None = None + self._has_pending_updates: bool = False + self._deferred_flush_timer: threading.Timer | None = None + self._throttle_lock = threading.Lock() + + # Cached task entity - avoids repeated DB fetches. + # Updated only by _refresh_task() when checking external state changes. + self._task: "Task" = task + + # In-memory state caches - authoritative during execution + # These are initialized from the task entity and updated locally + # before being written to DB via targeted SQL updates. + # We copy the dicts to avoid mutating the Task's cached instances. 
+ self._properties_cache: TaskProperties = cast( + TaskProperties, {**task.properties_dict} + ) + self._payload_cache: dict[str, Any] = {**task.payload_dict} + + # Store Flask app reference for background thread database access + # Use _get_current_object() to get actual app, not proxy + try: + self._app = current_app._get_current_object() + # Cache stats logger to avoid repeated config lookups + self._stats_logger: BaseStatsLogger = current_app.config.get( + "STATS_LOGGER", BaseStatsLogger() + ) + except RuntimeError: + # Handle case where app context isn't available (e.g., tests) + self._app = None + self._stats_logger = BaseStatsLogger() + + def _refresh_task(self) -> "Task": + """ + Force refresh the task entity from the database. + + Use this method when you need to check for external state changes, + such as whether the task has been aborted by a concurrent operation. + + This method: + - Fetches fresh task entity from database + - Updates the cached _task reference + - Updates properties/payload caches from fresh data + + :returns: Fresh task entity from database + :raises ValueError: If task is not found + """ + from superset.daos.tasks import TaskDAO + + fresh_task = TaskDAO.find_one_or_none(uuid=self._task_uuid) + if not fresh_task: + raise ValueError(f"Task {self._task_uuid} not found") + + self._task = fresh_task + + # Update caches from fresh data (copy to avoid mutating Task's cache) + self._properties_cache = cast(TaskProperties, {**fresh_task.properties_dict}) + self._payload_cache = {**fresh_task.payload_dict} + + return self._task + + def update_task( + self, + progress: float | int | tuple[int, int] | None = None, + payload: dict[str, object] | None = None, + ) -> None: + """ + Update task progress and/or payload atomically. + + All parameters are optional. Payload is merged with existing cached data. 
+ In-memory caches are always updated immediately, but DB writes are + throttled according to TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL to prevent + excessive database load from eager tasks. + + Progress can be specified in three ways: + - float (0.0-1.0): Percentage only, e.g., 0.5 means 50% + - int: Count only (total unknown), e.g., 42 means "42 items processed" + - tuple[int, int]: Count and total, e.g., (3, 100) means "3 of 100" + The percentage is automatically computed from count/total. + + :param progress: Progress value, or None to leave unchanged + :param payload: Payload data to merge (dict), or None to leave unchanged + """ + has_updates = False + + # Handle progress updates - always update in-memory cache + if progress is not None: + progress_props = progress_update(progress) + if progress_props: + # Merge progress into cached properties + self._properties_cache.update(progress_props) + has_updates = True + else: + # Invalid progress format - progress_update returns empty dict + logger.warning( + "Invalid progress value for task %s: %s " + "(expected float, int, or tuple[int, int])", + self._task_uuid, + progress, + ) + + # Handle payload updates - always update in-memory cache + if payload is not None: + # Merge payload into cached payload + self._payload_cache.update(payload) + has_updates = True + + if not has_updates: + return + + # Get throttle interval from config + throttle_interval = current_app.config["TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL"] Review Comment: `current_app.config["TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL"]` uses direct key access which raises `KeyError` if missing. **Suggested fix:** Use `.get()` with a default: `current_app.config.get("TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL", 2)` ########## superset/migrations/versions/2025_12_18_0220_create_tasks_table.py: ########## @@ -0,0 +1,218 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Create tasks and task_subscriber tables for Global Task Framework (GTF) + +Revision ID: 4b2a8c9d3e1f +Revises: 9787190b3d89 +Create Date: 2025-12-18 02:20:00.000000 + +""" + +from sqlalchemy import ( + Column, + DateTime, + Integer, + String, + Text, + UniqueConstraint, +) + +from superset.migrations.shared.utils import ( + create_fks_for_table, + create_index, + create_table, + drop_fks_for_table, + drop_index, + drop_table, +) + +# revision identifiers, used by Alembic. +revision = "4b2a8c9d3e1f" +down_revision = "9787190b3d89" + +TASKS_TABLE = "tasks" +TASK_SUBSCRIBERS_TABLE = "task_subscribers" + + +def upgrade(): + """ + Create tasks and task_subscribers tables for the Global Task Framework (GTF). + + This migration creates: + 1. tasks table - unified tracking for all long running tasks + 2. 
task_subscribers table - multi-user task subscriptions for shared tasks + + The scope feature allows tasks to be: + - private: user-specific (default) + - shared: multi-user collaborative tasks + - system: admin-only background tasks + """ + # Create tasks table + create_table( + TASKS_TABLE, + Column("id", Integer, primary_key=True), + Column("uuid", String(36), nullable=False, unique=True), + Column("task_key", String(256), nullable=False), Review Comment: The model defines `task_key` with `index=True`, but the migration doesn't create an explicit index for `task_key`. Only `dedup_key` gets a unique index. This could cause performance issues on task lookups by key. **Suggested fix:** Add `create_index(TASKS_TABLE, "idx_tasks_task_key", ["task_key"])` to the migration. ########## superset/tasks/api.py: ########## @@ -0,0 +1,494 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Task REST API""" + +import logging +import uuid +from typing import TYPE_CHECKING + +from flask import Response +from flask_appbuilder.api import expose, protect, safe +from flask_appbuilder.models.sqla.interface import SQLAInterface + +from superset.commands.tasks.cancel import CancelTaskCommand +from superset.commands.tasks.exceptions import ( + TaskAbortFailedError, + TaskForbiddenError, + TaskNotAbortableError, + TaskNotFoundError, + TaskPermissionDeniedError, +) +from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod +from superset.extensions import event_logger +from superset.models.tasks import Task +from superset.tasks.filters import TaskFilter +from superset.tasks.schemas import ( + openapi_spec_methods_override, + TaskCancelRequestSchema, + TaskCancelResponseSchema, + TaskResponseSchema, + TaskStatusResponseSchema, +) +from superset.views.base_api import ( + BaseSupersetModelRestApi, + RelatedFieldFilter, + statsd_metrics, +) +from superset.views.filters import BaseFilterRelatedUsers, FilterRelatedOwners + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + + +def _is_valid_uuid(value: str) -> bool: + """Check if a string is a valid UUID format.""" + try: + uuid.UUID(value) + return True + except (ValueError, AttributeError): + return False + + +class TaskRestApi(BaseSupersetModelRestApi): + """REST API for task management""" + + datamodel = SQLAInterface(Task) + resource_name = "task" + allow_browser_login = True + + class_permission_name = "Task" + + # Map cancel and status to write/read permissions + method_permission_name = { + **MODEL_API_RW_METHOD_PERMISSION_MAP, + "cancel": "write", + "status": "read", + } + + include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | { + "cancel", + "status", + "related_subscribers", + "related", + } + + list_columns = [ + "id", + "uuid", + "task_type", + "task_key", + "task_name", + "scope", + "status", + "created_on", + "created_on_delta_humanized", + "changed_on", 
+ "changed_by.first_name", + "changed_by.last_name", + "started_at", + "ended_at", + "created_by.id", + "created_by.first_name", + "created_by.last_name", + "user_id", + "payload", + "properties", + "duration_seconds", + "subscriber_count", + "subscribers", + ] + + list_select_columns = list_columns + ["created_by_fk", "changed_by_fk"] + + show_columns = list_columns + + order_columns = [ + "task_type", + "scope", + "status", + "created_on", + "changed_on", + "started_at", + "ended_at", + ] + + search_columns = [ + "task_type", + "task_key", + "task_name", + "scope", + "status", + "created_by", + "created_on", + ] + + base_order = ("created_on", "desc") + base_filters = [["id", TaskFilter, lambda: []]] + + # Related field configuration for filter dropdowns + allowed_rel_fields = {"created_by"} + related_field_filters = { + "created_by": RelatedFieldFilter("first_name", FilterRelatedOwners), + } + base_related_field_filters = { + "created_by": [["id", BaseFilterRelatedUsers, lambda: []]], + } + + show_model_schema = TaskResponseSchema() + list_model_schema = TaskResponseSchema() + cancel_request_schema = TaskCancelRequestSchema() + + openapi_spec_tag = "Tasks" + openapi_spec_component_schemas = ( + TaskResponseSchema, + TaskCancelRequestSchema, + TaskCancelResponseSchema, + TaskStatusResponseSchema, + ) + openapi_spec_methods = openapi_spec_methods_override + + @expose("/<uuid_or_id>", methods=("GET",)) + @protect() + @safe + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get", + log_to_statsd=False, + ) + def get(self, uuid_or_id: str) -> Response: + """Get a task. 
+ --- + get: + summary: Get a task + parameters: + - in: path + schema: + type: string + name: uuid_or_id + description: The UUID or ID of the task + responses: + 200: + description: Task detail + content: + application/json: + schema: + type: object + properties: + result: + $ref: '#/components/schemas/TaskResponseSchema' + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + """ + from superset.daos.tasks import TaskDAO + + try: + # Try to find by UUID first, then by ID + if _is_valid_uuid(uuid_or_id): + task = TaskDAO.find_one_or_none(uuid=uuid_or_id) + else: + task = TaskDAO.find_by_id(int(uuid_or_id)) + + if not task: + return self.response_404() + + result = self.show_model_schema.dump(task) + return self.response(200, result=result) + except (ValueError, TypeError): + return self.response_404() + + @expose("/<uuid_or_id>/status", methods=("GET",)) + @protect() + @safe + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.status", + log_to_statsd=False, + ) + def status(self, uuid_or_id: str) -> Response: + """Get only the status of a task (lightweight for polling). 
+ --- + get: + summary: Get task status + parameters: + - in: path + schema: + type: string + name: uuid_or_id + description: The UUID or ID of the task + responses: + 200: + description: Task status + content: + application/json: + schema: + type: object + properties: + status: + type: string + description: Current status of the task + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + """ + from superset.daos.tasks import TaskDAO + + try: + # Try to find by UUID first, then by ID + if _is_valid_uuid(uuid_or_id): + task = TaskDAO.find_one_or_none(uuid=uuid_or_id) + else: + task = TaskDAO.find_by_id(int(uuid_or_id)) + + if not task: + return self.response_404() + + return self.response(200, status=task.status) + except (ValueError, TypeError): + return self.response_404() + + @expose("/<uuid_or_id>/cancel", methods=("POST",)) + @protect() + @safe + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.cancel", + log_to_statsd=False, + ) + def cancel(self, uuid_or_id: str) -> Response: + """Cancel a task. + --- + post: + summary: Cancel a task + description: > + Cancel a task. 
The behavior depends on task scope and subscriber + count: + + - **Private tasks**: Aborts the task + - **Shared tasks (single subscriber)**: Aborts the task + - **Shared tasks (multiple subscribers)**: Removes current user's + subscription; the task continues for other subscribers + - **Shared tasks with force=true (admin only)**: Aborts task for + all subscribers + + The `action` field in the response indicates what happened: + - `aborted`: Task was terminated + - `unsubscribed`: User was removed from task (task continues) + parameters: + - in: path + schema: + type: string + name: uuid_or_id + description: The UUID or ID of the task to cancel + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/TaskCancelRequestSchema' + responses: + 200: + description: Task cancelled successfully + content: + application/json: + schema: + $ref: '#/components/schemas/TaskCancelResponseSchema' + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + """ + return self._execute_cancel(uuid_or_id) + + def _execute_cancel(self, uuid_or_id: str) -> Response: + """Execute the cancel operation with error handling.""" + try: + task_uuid = self._resolve_task_uuid(uuid_or_id) + if task_uuid is None: + return self.response_404() + + command, updated_task = self._run_cancel_command(task_uuid) + return self._build_cancel_response(command, updated_task) + + except TaskNotFoundError: + return self.response_404() + except (TaskForbiddenError, TaskPermissionDeniedError) as ex: + if isinstance(ex, TaskPermissionDeniedError): + logger.warning( + "Permission denied cancelling task %s: %s", + uuid_or_id, + str(ex), + ) + return self.response_403() + except TaskNotAbortableError as ex: + logger.warning("Task %s is not cancellable: %s", uuid_or_id, str(ex)) + return self.response_422(message=str(ex)) + except TaskAbortFailedError as ex: + 
logger.error( + "Error cancelling task %s: %s", uuid_or_id, str(ex), exc_info=True + ) + return self.response_422(message=str(ex)) + except (ValueError, TypeError): + return self.response_404() + + def _resolve_task_uuid(self, uuid_or_id: str) -> str | None: + """Resolve a UUID or ID to a task UUID.""" + from superset.daos.tasks import TaskDAO + + if _is_valid_uuid(uuid_or_id): + return uuid_or_id + + task = TaskDAO.find_by_id(uuid_or_id) + return task.uuid if task else None + + def _run_cancel_command(self, task_uuid: str) -> tuple[CancelTaskCommand, "Task"]: + """Parse request and run the cancel command.""" + from flask import request + + force = False + # Use get_json with silent=True to handle missing Content-Type gracefully + json_data = request.get_json(silent=True) + if json_data: + parsed = self.cancel_request_schema.load(json_data) + force = parsed.get("force", False) + + command = CancelTaskCommand(task_uuid, force=force) + updated_task = command.run() + return command, updated_task + + def _build_cancel_response( + self, command: CancelTaskCommand, updated_task: "Task" + ) -> Response: + """Build the response for a successful cancel operation.""" + action = command.action_taken + message = ( + "Task cancelled" + if action == "aborted" + else "You have been removed from this task" + ) + result = { + "message": message, + "action": action, + "task": self.show_model_schema.dump(updated_task), + } + return self.response(200, **result) + + @expose("/related/subscribers", methods=("GET",)) + @protect() + @safe + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" + ".related_subscribers", + log_to_statsd=False, + ) + def related_subscribers(self) -> Response: + """Get users who are subscribers to tasks. + --- + get: + summary: Get related subscribers + description: > + Returns a list of users who are subscribed to tasks, for use in filter + dropdowns. 
Results can be filtered by a search query parameter. + parameters: + - in: query + schema: + type: string + name: q + description: Search query to filter subscribers by name + responses: + 200: + description: List of subscribers + content: + application/json: + schema: + type: object + properties: + count: + type: integer + description: Total number of matching subscribers + result: + type: array + items: + type: object + properties: + value: + type: integer + description: User ID + text: + type: string + description: User display name + 401: + $ref: '#/components/responses/401' + """ + from flask import request + + from superset import db, security_manager + from superset.models.task_subscribers import TaskSubscriber + + # Get search query + + # Get user model + user_model = security_manager.user_model + + # Query distinct users who are task subscribers + query = ( + db.session.query(user_model.id, user_model.first_name, user_model.last_name) + .join(TaskSubscriber, user_model.id == TaskSubscriber.user_id) + .distinct() + ) + + # Apply search filter if provided + if search_query := request.args.get("q", ""): Review Comment: The `search_query` from `request.args.get("q", "")` is used in a LIKE pattern without escaping `%` and `_` characters. While SQLAlchemy's `ilike` prevents SQL injection, users can craft patterns like `%admin%` to match unintended results. **Suggested fix:** Escape LIKE special characters before building the pattern: ```python search_query = search_query.replace("%", r"\%").replace("_", r"\_") ``` ########## superset/models/tasks.py: ########## @@ -0,0 +1,364 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Task model for Global Task Framework (GTF)""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timezone +from typing import Any, cast + +from flask_appbuilder import Model +from sqlalchemy import ( + Column, + DateTime, + Integer, + String, + Text, +) +from sqlalchemy.orm import relationship +from superset_core.api.models import Task as CoreTask +from superset_core.api.tasks import TaskProperties, TaskStatus + +from superset.models.helpers import AuditMixinNullable +from superset.models.task_subscribers import TaskSubscriber +from superset.tasks.utils import ( + error_update, + get_finished_dedup_key, + parse_properties, + serialize_properties, +) +from superset.utils import json + + +class Task(CoreTask, AuditMixinNullable, Model): + """ + Concrete Task model for the Global Task Framework (GTF). + + This model represents async tasks in Superset, providing unified tracking + for all background operations including SQL queries, thumbnail generation, + reports, and other async operations. + + Non-filterable fields (progress, error info, execution config) are stored + in a `properties` JSON blob for schema flexibility. 
+ """ + + __tablename__ = "tasks" + + # Primary key and identifiers + id = Column(Integer, primary_key=True) + uuid = Column( + String(36), nullable=False, unique=True, default=lambda: str(uuid.uuid4()) + ) + + # Task metadata (filterable) + task_key = Column(String(256), nullable=False, index=True) # For deduplication + task_type = Column(String(100), nullable=False, index=True) # e.g., 'sql_execution' + task_name = Column(String(256), nullable=True) # Human readable name + scope = Column( + String(20), nullable=False, index=True, default="private" + ) # private/shared/system + status = Column( + String(50), nullable=False, index=True, default=TaskStatus.PENDING.value + ) + dedup_key = Column( + String(64), nullable=False, unique=True, index=True + ) # Hashed deduplication key (SHA-256 = 64 chars, UUID = 36 chars) + + # Timestamps + started_at = Column(DateTime, nullable=True) + ended_at = Column(DateTime, nullable=True) + + # User context for execution + user_id = Column(Integer, nullable=True) + + # Task-specific output data (set by task code via ctx.update_task(payload=...)) + payload = Column(Text, nullable=True, default="{}") + + # Properties JSON blob - contains runtime state and execution config: + # - is_abortable: bool - has abort handler registered + # - progress_percent: float - progress 0.0-1.0 + # - progress_current: int - current iteration count + # - progress_total: int - total iterations + # - error_message: str - human-readable error message + # - exception_type: str - exception class name + # - stack_trace: str - full formatted traceback + # - timeout: int - timeout in seconds + properties = Column(Text, nullable=True, default="{}") + + # Relationships + subscribers = relationship( Review Comment: When listing tasks, the `subscribers` relationship is lazy-loaded: ```python subscribers = relationship( TaskSubscriber, back_populates="task", cascade="all, delete-orphan", ) # Default lazy loading ``` **Impact:** For a list of 25 tasks, this causes 26 
queries (1 for tasks + 25 for subscribers). **Suggested fix:** Use `selectinload(Task.subscribers)` for list operations. ########## superset/tasks/api.py: ########## @@ -0,0 +1,494 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Task REST API""" + +import logging +import uuid +from typing import TYPE_CHECKING + +from flask import Response +from flask_appbuilder.api import expose, protect, safe +from flask_appbuilder.models.sqla.interface import SQLAInterface + +from superset.commands.tasks.cancel import CancelTaskCommand +from superset.commands.tasks.exceptions import ( + TaskAbortFailedError, + TaskForbiddenError, + TaskNotAbortableError, + TaskNotFoundError, + TaskPermissionDeniedError, +) +from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod +from superset.extensions import event_logger +from superset.models.tasks import Task +from superset.tasks.filters import TaskFilter +from superset.tasks.schemas import ( + openapi_spec_methods_override, + TaskCancelRequestSchema, + TaskCancelResponseSchema, + TaskResponseSchema, + TaskStatusResponseSchema, +) +from superset.views.base_api import ( + BaseSupersetModelRestApi, + RelatedFieldFilter, + statsd_metrics, +) +from superset.views.filters import 
BaseFilterRelatedUsers, FilterRelatedOwners + +if TYPE_CHECKING: + pass + Review Comment: ```suggestion ``` ########## superset-frontend/src/pages/TaskList/index.tsx: ########## @@ -0,0 +1,658 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +import { + FeatureFlag, + isFeatureEnabled, + SupersetClient, +} from '@superset-ui/core'; +import { t, useTheme } from '@apache-superset/core'; +import { useMemo, useCallback, useState } from 'react'; +import { Tooltip, Label, Modal, Checkbox } from '@superset-ui/core/components'; +import { + CreatedInfo, + ListView, + ListViewFilterOperator as FilterOperator, + type ListViewFilters, + FacePile, +} from 'src/components'; +import { Icons } from '@superset-ui/core/components/Icons'; +import withToasts from 'src/components/MessageToasts/withToasts'; +import SubMenu from 'src/features/home/SubMenu'; +import { useListViewResource } from 'src/views/CRUD/hooks'; +import { createErrorHandler, createFetchRelated } from 'src/views/CRUD/utils'; +import TaskStatusIcon from 'src/features/tasks/TaskStatusIcon'; +import TaskPayloadPopover from 'src/features/tasks/TaskPayloadPopover'; +import TaskStackTracePopover from 'src/features/tasks/TaskStackTracePopover'; +import { formatDuration } from 'src/features/tasks/timeUtils'; +import { + Task, + TaskStatus, + TaskScope, + canAbortTask, + isTaskAborting, + TaskSubscriber, +} from 'src/features/tasks/types'; +import { isUserAdmin } from 'src/dashboard/util/permissionUtils'; +import getBootstrapData from 'src/utils/getBootstrapData'; + +const PAGE_SIZE = 25; + +/** + * Typed cell props for react-table columns. + * Replaces `: any` for better type safety in Cell render functions. 
+ */ +interface TaskCellProps { + row: { + original: Task; + }; +} + +interface TaskListProps { + addDangerToast: (msg: string) => void; + addSuccessToast: (msg: string) => void; + user: { + userId: string | number; + firstName: string; + lastName: string; + }; +} + +function TaskList({ addDangerToast, addSuccessToast, user }: TaskListProps) { + const theme = useTheme(); + + // Check if GTF feature flag is enabled + if (!isFeatureEnabled(FeatureFlag.GlobalTaskFramework)) { + return ( + <> + <SubMenu name={t('Tasks')} /> + <div + style={{ + display: 'flex', + flexDirection: 'column', + alignItems: 'center', + justifyContent: 'center', + height: '50vh', + color: theme.colorTextSecondary, + }} + > + <h3>{t('Feature Not Enabled')}</h3> + <p> + {t( + 'The Global Task Framework is not enabled. Please contact your administrator to enable the GLOBAL_TASK_FRAMEWORK feature flag.', + )} + </p> + </div> + </> + ); + } + + const { + state: { loading, resourceCount: tasksCount, resourceCollection: tasks }, + fetchData, + refreshData, + } = useListViewResource<Task>('task', t('task'), addDangerToast); + + // Get full user with roles to check admin status + const bootstrapData = getBootstrapData(); + const fullUser = bootstrapData?.user; + const isAdmin = useMemo(() => isUserAdmin(fullUser), [fullUser]); + + // State for cancel confirmation modal + const [cancelModalTask, setCancelModalTask] = useState<Task | null>(null); + const [forceCancel, setForceCancel] = useState(false); + + // Determine dialog message based on task context + const getCancelDialogMessage = useCallback((task: Task) => { + const isSharedTask = task.scope === TaskScope.Shared; + const subscriberCount = task.subscriber_count || 0; + const otherSubscribers = subscriberCount - 1; + + // If it's going to abort (private, system, or last subscriber) + if (!isSharedTask || subscriberCount <= 1) { + return t('This will cancel the task.'); + } + + // Shared task with multiple subscribers + return t( + "You'll be 
removed from this task. It will continue running for %s other subscriber(s).", + otherSubscribers, + ); + }, []); + + // Get force abort message for admin checkbox + const getForceAbortMessage = useCallback((task: Task) => { + const subscriberCount = task.subscriber_count || 0; + return t( + 'This will abort (stop) the task for all %s subscriber(s).', + subscriberCount, + ); + }, []); + + // Check if current user is subscribed to a task + const isUserSubscribed = useCallback( + (task: Task) => + task.subscribers?.some( + (sub: TaskSubscriber) => sub.user_id === user.userId, + ) ?? false, + [user.userId], + ); + + // Check if force cancel option should be shown (for admins on shared tasks) + const showForceCancelOption = useCallback( + (task: Task) => { + const isSharedTask = task.scope === TaskScope.Shared; + const subscriberCount = task.subscriber_count || 0; + const userSubscribed = isUserSubscribed(task); + // Show for admins on shared tasks when: + // - Not subscribed (can only abort, so show checkbox pre-checked disabled), OR + // - Multiple subscribers (can choose between unsubscribe and force abort) + // Don't show when admin is the sole subscriber - cancel will abort anyway + return ( + isAdmin && isSharedTask && (subscriberCount > 1 || !userSubscribed) + ); + }, + [isAdmin, isUserSubscribed], + ); + + // Check if force cancel checkbox should be disabled (admin not subscribed) + const isForceCancelDisabled = useCallback( + (task: Task) => isAdmin && !isUserSubscribed(task), + [isAdmin, isUserSubscribed], + ); + + const handleTaskCancel = useCallback( + (task: Task, force: boolean = false) => { + SupersetClient.post({ + endpoint: `/api/v1/task/${task.uuid}/cancel`, + jsonPayload: force ? 
{ force: true } : {}, + }).then( + ({ json }) => { + refreshData(); + const { action } = json as { action: string }; + if (action === 'aborted') { + addSuccessToast( + t('Task cancelled: %s', task.task_name || task.task_key), + ); + } else { + addSuccessToast( + t( + 'You have been removed from task: %s', + task.task_name || task.task_key, + ), + ); + } + }, + createErrorHandler(errMsg => + addDangerToast( + t('There was an issue cancelling the task: %s', errMsg), + ), + ), + ); + }, + [addDangerToast, addSuccessToast, refreshData], + ); + + // Handle opening the cancel modal - set initial forceCancel state + const openCancelModal = useCallback( + (task: Task) => { + // Pre-check force cancel if admin is not subscribed + const shouldPreCheck = isAdmin && !isUserSubscribed(task); + setForceCancel(shouldPreCheck); + setCancelModalTask(task); + }, + [isAdmin, isUserSubscribed], + ); + + // Handle modal confirmation + const handleCancelConfirm = useCallback(() => { + if (cancelModalTask) { + handleTaskCancel(cancelModalTask, forceCancel); + setCancelModalTask(null); + setForceCancel(false); + } + }, [cancelModalTask, forceCancel, handleTaskCancel]); + + // Handle modal close + const handleCancelModalClose = useCallback(() => { + setCancelModalTask(null); + setForceCancel(false); + }, []); + + const columns = useMemo( + () => [ + { + Cell: ({ + row: { + original: { task_name, task_key, uuid }, + }, + }: TaskCellProps) => { + // Display preference: task_name > task_key + const displayText = task_name || task_key; + const truncated = + displayText.length > 30 + ? 
`${displayText.slice(0, 30)}...` + : displayText; + + // Build tooltip with all identifiers + const tooltipLines = []; + if (task_name) tooltipLines.push(`Name: ${task_name}`); + tooltipLines.push(`Key: ${task_key}`); + tooltipLines.push(`UUID: ${uuid}`); + const tooltipText = tooltipLines.join('\n'); + + return ( + <Tooltip + title={ + <span style={{ whiteSpace: 'pre-line' }}>{tooltipText}</span> + } + placement="top" + > + <span>{truncated}</span> + </Tooltip> + ); + }, + accessor: 'task_name', + Header: t('Task'), + size: 'xl', + id: 'task', + }, + { + Cell: ({ + row: { + original: { status, properties, duration_seconds }, + }, + }: TaskCellProps) => ( + <TaskStatusIcon + status={status as TaskStatus} + progressPercent={properties?.progress_percent} + progressCurrent={properties?.progress_current} + progressTotal={properties?.progress_total} + durationSeconds={duration_seconds} + errorMessage={properties?.error_message} + exceptionType={properties?.exception_type} + /> + ), + accessor: 'status', + Header: t('Status'), + size: 'xs', + id: 'status', + }, + { + accessor: 'task_type', + Header: t('Type'), + size: 'md', + id: 'task_type', + }, + { + Cell: ({ + row: { + original: { scope }, + }, + }: TaskCellProps) => { + const scopeConfig: Record< + TaskScope, + { label: string; type: 'default' | 'info' | 'warning' } + > = { + [TaskScope.Private]: { label: t('Private'), type: 'default' }, + [TaskScope.Shared]: { label: t('Shared'), type: 'info' }, + [TaskScope.System]: { label: t('System'), type: 'warning' }, + }; + + const config = scopeConfig[scope as TaskScope] || { + label: scope, + type: 'default' as const, + }; + + return <Label type={config.type}>{config.label}</Label>; + }, + accessor: 'scope', + Header: t('Scope'), + size: 'sm', + id: 'scope', + }, + { + Cell: ({ + row: { + original: { subscriber_count, subscribers }, + }, + }: TaskCellProps) => { + if (!subscribers || subscriber_count === 0) { + return '-'; + } + + // Convert subscribers to FacePile format 
+ const users = subscribers.map((sub: any) => ({ Review Comment: ```suggestion const users = subscribers.map((sub: TaskSubscriber) => ({ ``` ########## superset/models/task_subscribers.py: ########## @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TaskSubscriber model for tracking multi-user task subscriptions""" + +from datetime import datetime + +from flask_appbuilder import Model +from sqlalchemy import Column, DateTime, ForeignKey, Integer, UniqueConstraint +from sqlalchemy.orm import relationship +from superset_core.api.models import TaskSubscriber as CoreTaskSubscriber + +from superset.models.helpers import AuditMixinNullable + + +class TaskSubscriber(CoreTaskSubscriber, AuditMixinNullable, Model): + """ + Model for tracking task subscriptions in shared tasks. + + This model enables multi-user collaboration on async tasks. When a user + schedules a shared task with the same parameters as an existing task, + they are automatically subscribed to that task instead of creating a + duplicate. + + Subscribers can unsubscribe from shared tasks. When the last subscriber + unsubscribes, the task is automatically aborted. 
+ """ + + __tablename__ = "task_subscribers" + + id = Column(Integer, primary_key=True) + task_id = Column( + Integer, ForeignKey("tasks.id", ondelete="CASCADE"), nullable=False + ) + user_id = Column( + Integer, ForeignKey("ab_user.id", ondelete="CASCADE"), nullable=False + ) + subscribed_at = Column(DateTime, nullable=False, default=datetime.utcnow) Review Comment: ```suggestion subscribed_at = Column( DateTime, nullable=False, default=lambda: datetime.now(timezone.utc) ) ``` Note: `default=` must be a callable — writing `default=datetime.now(timezone.utc)` would evaluate the timestamp once at class-definition time and reuse it for every row, so wrap it in a lambda. Also add `timezone` to the `from datetime import datetime` import. ########## superset/tasks/api.py: ########## @@ -0,0 +1,494 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Task REST API""" + +import logging +import uuid +from typing import TYPE_CHECKING + +from flask import Response +from flask_appbuilder.api import expose, protect, safe +from flask_appbuilder.models.sqla.interface import SQLAInterface + +from superset.commands.tasks.cancel import CancelTaskCommand +from superset.commands.tasks.exceptions import ( + TaskAbortFailedError, + TaskForbiddenError, + TaskNotAbortableError, + TaskNotFoundError, + TaskPermissionDeniedError, +) +from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod +from superset.extensions import event_logger +from superset.models.tasks import Task +from superset.tasks.filters import TaskFilter +from superset.tasks.schemas import ( + openapi_spec_methods_override, + TaskCancelRequestSchema, + TaskCancelResponseSchema, + TaskResponseSchema, + TaskStatusResponseSchema, +) +from superset.views.base_api import ( + BaseSupersetModelRestApi, + RelatedFieldFilter, + statsd_metrics, +) +from superset.views.filters import BaseFilterRelatedUsers, FilterRelatedOwners + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + + +def _is_valid_uuid(value: str) -> bool: + """Check if a string is a valid UUID format.""" + try: + uuid.UUID(value) + return True + except (ValueError, AttributeError): + return False + + +class TaskRestApi(BaseSupersetModelRestApi): + """REST API for task management""" + + datamodel = SQLAInterface(Task) + resource_name = "task" + allow_browser_login = True + + class_permission_name = "Task" + + # Map cancel and status to write/read permissions + method_permission_name = { + **MODEL_API_RW_METHOD_PERMISSION_MAP, + "cancel": "write", + "status": "read", + } + + include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | { + "cancel", + "status", + "related_subscribers", + "related", + } + + list_columns = [ + "id", + "uuid", + "task_type", + "task_key", + "task_name", + "scope", + "status", + "created_on", + "created_on_delta_humanized", + "changed_on", 
+ "changed_by.first_name", + "changed_by.last_name", + "started_at", + "ended_at", + "created_by.id", + "created_by.first_name", + "created_by.last_name", + "user_id", + "payload", + "properties", + "duration_seconds", + "subscriber_count", + "subscribers", + ] + + list_select_columns = list_columns + ["created_by_fk", "changed_by_fk"] + + show_columns = list_columns + + order_columns = [ + "task_type", + "scope", + "status", + "created_on", + "changed_on", + "started_at", + "ended_at", + ] + + search_columns = [ + "task_type", + "task_key", + "task_name", + "scope", + "status", + "created_by", + "created_on", + ] + + base_order = ("created_on", "desc") + base_filters = [["id", TaskFilter, lambda: []]] + + # Related field configuration for filter dropdowns + allowed_rel_fields = {"created_by"} + related_field_filters = { + "created_by": RelatedFieldFilter("first_name", FilterRelatedOwners), + } + base_related_field_filters = { + "created_by": [["id", BaseFilterRelatedUsers, lambda: []]], + } + + show_model_schema = TaskResponseSchema() + list_model_schema = TaskResponseSchema() + cancel_request_schema = TaskCancelRequestSchema() + + openapi_spec_tag = "Tasks" + openapi_spec_component_schemas = ( + TaskResponseSchema, + TaskCancelRequestSchema, + TaskCancelResponseSchema, + TaskStatusResponseSchema, + ) + openapi_spec_methods = openapi_spec_methods_override + + @expose("/<uuid_or_id>", methods=("GET",)) + @protect() + @safe + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get", + log_to_statsd=False, + ) + def get(self, uuid_or_id: str) -> Response: + """Get a task. 
+ --- + get: + summary: Get a task + parameters: + - in: path + schema: + type: string + name: uuid_or_id + description: The UUID or ID of the task + responses: + 200: + description: Task detail + content: + application/json: + schema: + type: object + properties: + result: + $ref: '#/components/schemas/TaskResponseSchema' + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + """ + from superset.daos.tasks import TaskDAO + + try: + # Try to find by UUID first, then by ID + if _is_valid_uuid(uuid_or_id): + task = TaskDAO.find_one_or_none(uuid=uuid_or_id) + else: + task = TaskDAO.find_by_id(int(uuid_or_id)) + + if not task: + return self.response_404() + + result = self.show_model_schema.dump(task) + return self.response(200, result=result) + except (ValueError, TypeError): + return self.response_404() + + @expose("/<uuid_or_id>/status", methods=("GET",)) + @protect() + @safe + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.status", + log_to_statsd=False, + ) + def status(self, uuid_or_id: str) -> Response: + """Get only the status of a task (lightweight for polling). 
+ --- + get: + summary: Get task status + parameters: + - in: path + schema: + type: string + name: uuid_or_id + description: The UUID or ID of the task + responses: + 200: + description: Task status + content: + application/json: + schema: + type: object + properties: + status: + type: string + description: Current status of the task + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + """ + from superset.daos.tasks import TaskDAO + + try: + # Try to find by UUID first, then by ID + if _is_valid_uuid(uuid_or_id): Review Comment: The `/status` endpoint fetches the entire task entity just to return `task.status`: ```python task = TaskDAO.find_one_or_none(uuid=uuid_or_id) return self.response(200, status=task.status) ``` **Suggested fix:** Add a `TaskDAO.get_status(uuid)` method that uses `SELECT status FROM tasks WHERE uuid = ?`. ########## superset/commands/distributed_lock/acquire.py: ########## @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import logging +from datetime import datetime, timedelta +from functools import partial +from typing import Any + +import redis +from sqlalchemy.exc import SQLAlchemyError + +from superset.commands.distributed_lock.base import ( + BaseDistributedLockCommand, + get_default_lock_ttl, + get_redis_client, +) +from superset.daos.key_value import KeyValueDAO +from superset.exceptions import AcquireDistributedLockFailedException +from superset.key_value.exceptions import ( + KeyValueCodecEncodeException, + KeyValueUpsertFailedError, +) +from superset.key_value.types import KeyValueResource +from superset.utils.decorators import on_error, transaction + +logger = logging.getLogger(__name__) + + +class AcquireDistributedLock(BaseDistributedLockCommand): + """ + Acquire a distributed lock with automatic backend selection. + + Uses Redis SET NX EX when SIGNAL_CACHE_CONFIG is configured, + otherwise falls back to KeyValue table. + + Raises AcquireDistributedLockFailedException if: + - Lock is already held by another process + - Redis connection fails + """ + + ttl_seconds: int + + def __init__( + self, + namespace: str, + params: dict[str, Any] | None = None, + ttl_seconds: int | None = None, + ) -> None: + super().__init__(namespace, params) + self.ttl_seconds = ttl_seconds or get_default_lock_ttl() + + def run(self) -> None: + if (redis_client := get_redis_client()) is not None: + self._acquire_redis(redis_client) + else: + self._acquire_kv() + + def _acquire_redis(self, redis_client: Any) -> None: + """Acquire lock using Redis SET NX EX (atomic).""" + try: + # SET NX EX: Set if not exists, with expiration + # Returns True if lock acquired, None if already exists + acquired = redis_client.set( + self.redis_lock_key, + "1", + nx=True, + ex=self.ttl_seconds, + ) + + if not acquired: + logger.debug("Redis lock on %s already taken", self.redis_lock_key) + raise AcquireDistributedLockFailedException("Lock already taken") + + logger.debug( + 
"Acquired Redis lock: %s (TTL=%ds)", + self.redis_lock_key, + self.ttl_seconds, + ) + + except redis.RedisError as ex: + logger.error("Redis lock error for %s: %s", self.redis_lock_key, ex) + raise AcquireDistributedLockFailedException( + f"Redis lock failed: {ex}" + ) from ex + + @transaction( + on_error=partial( + on_error, + catches=( + KeyValueCodecEncodeException, + KeyValueUpsertFailedError, + SQLAlchemyError, + ), + reraise=AcquireDistributedLockFailedException, + ), + ) + def _acquire_kv(self) -> None: + """Acquire lock using KeyValue table (database).""" + # Delete expired entries first to prevent stale locks from blocking + KeyValueDAO.delete_expired_entries(self.resource) + + # Create entry - unique constraint will raise if lock already exists + KeyValueDAO.create_entry( + resource=KeyValueResource.LOCK, + value={"value": True}, + codec=self.codec, + key=self.key, + expires_on=datetime.now() + timedelta(seconds=self.ttl_seconds), Review Comment: ```suggestion expires_on=datetime.now(timezone.utc) + timedelta(seconds=self.ttl_seconds), ``` Note: `timezone` is not currently imported — also change the import to `from datetime import datetime, timedelta, timezone`. ########## superset/tasks/schemas.py: ########## @@ -0,0 +1,200 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Task API schemas""" + +from marshmallow import fields, Schema +from marshmallow.fields import Method + +# RISON/JSON schemas for query parameters +get_delete_ids_schema = {"type": "array", "items": {"type": "string"}} + +# Field descriptions +uuid_description = "The unique identifier (UUID) of the task" +task_key_description = "The task identifier used for deduplication" +task_type_description = ( + "The type of task (e.g., 'sql_execution', 'thumbnail_generation')" +) +task_name_description = "Human-readable name for the task" +status_description = "Current status of the task" +created_on_description = "Timestamp when the task was created" +changed_on_description = "Timestamp when the task was last updated" +started_at_description = "Timestamp when the task started execution" +ended_at_description = "Timestamp when the task completed or failed" +created_by_description = "User who created the task" +user_id_description = "ID of the user context for task execution" +payload_description = "Task-specific data in JSON format" +properties_description = ( + "Runtime state and execution config. Contains: is_abortable, progress_percent, " + "progress_current, progress_total, error_message, exception_type, stack_trace, " + "timeout" +) +duration_seconds_description = ( + "Duration in seconds - for finished tasks: execution time, " + "for running tasks: time since start, for pending: queue time" +) +scope_description = ( + "Task scope: 'private' (user-specific), 'shared' (multi-user), " + "or 'system' (admin-only)" +) +subscriber_count_description = ( + "Number of users subscribed to this task (for shared tasks)" +) +subscribers_description = "List of users subscribed to this task (for shared tasks)" + + +class UserSchema(Schema): + """Schema for user information""" + + id = fields.Int() + first_name = fields.String() + last_name = fields.String() + + +class TaskResponseSchema(Schema): + """ + Schema for task response. + + Used for both list and detail endpoints. 
+ """ + + id = fields.Int(metadata={"description": "Internal task ID"}) + uuid = fields.String(metadata={"description": uuid_description}) + task_key = fields.String(metadata={"description": task_key_description}) + task_type = fields.String(metadata={"description": task_type_description}) + task_name = fields.String( + metadata={"description": task_name_description}, allow_none=True + ) + status = fields.String(metadata={"description": status_description}) + created_on = fields.DateTime(metadata={"description": created_on_description}) + created_on_delta_humanized = Method( + "get_created_on_delta_humanized", + metadata={"description": "Humanized time since creation"}, + ) + changed_on = fields.DateTime(metadata={"description": changed_on_description}) + changed_by = fields.Nested(UserSchema, allow_none=True) + started_at = fields.DateTime( + metadata={"description": started_at_description}, allow_none=True + ) + ended_at = fields.DateTime( + metadata={"description": ended_at_description}, allow_none=True + ) + created_by = fields.Nested(UserSchema, allow_none=True) + user_id = fields.Int(metadata={"description": user_id_description}, allow_none=True) + payload = Method("get_payload_dict", metadata={"description": payload_description}) + properties = Method( + "get_properties", metadata={"description": properties_description} + ) + duration_seconds = Method( + "get_duration", + metadata={"description": duration_seconds_description}, + ) + scope = fields.String(metadata={"description": scope_description}) + subscriber_count = Method( + "get_subscriber_count", metadata={"description": subscriber_count_description} + ) + subscribers = Method( + "get_subscribers", metadata={"description": subscribers_description} + ) + + def get_payload_dict(self, obj: object) -> dict[str, object] | None: + """Get payload as dictionary""" + return obj.payload_dict # type: ignore[attr-defined] + + def get_properties(self, obj: object) -> dict[str, object]: + """Get properties dict, 
filtering stack_trace if SHOW_STACKTRACE is disabled.""" + from flask import current_app + + properties = dict(obj.properties_dict) # type: ignore[attr-defined] + + # Remove stack_trace unless SHOW_STACKTRACE is enabled + if not current_app.config.get("SHOW_STACKTRACE", False): + properties.pop("stack_trace", None) + + return properties + + def get_duration(self, obj: object) -> float | None: + """Get duration in seconds""" + return obj.duration_seconds # type: ignore[attr-defined] + + def get_created_on_delta_humanized(self, obj: object) -> str: + """Get humanized time since creation""" + return obj.created_on_delta_humanized() # type: ignore[attr-defined] + + def get_subscriber_count(self, obj: object) -> int: + """Get number of subscribers""" + return obj.subscriber_count # type: ignore[attr-defined] + + def get_subscribers(self, obj: object) -> list[dict[str, object]]: Review Comment: The `get_subscribers` method causes N+1 queries: ```python for sub in obj.subscribers: # 1 query to load subscribers (lazy) subscribers.append({ "first_name": sub.user.first_name if sub.user else None, # N queries! "last_name": sub.user.last_name if sub.user else None, }) ``` The `user` relationship in `TaskSubscriber` (line 56 of `task_subscribers.py`) has no eager loading: ```python user = relationship("User", foreign_keys=[user_id]) # Default lazy loading ``` **Impact:** For a task with 10 subscribers, this causes 11 queries (1 for subscribers + 10 for users). **Suggested fix:** Add `lazy="joined"` to the user relationship, or use `selectinload(TaskSubscriber.user)` in queries. ########## superset/tasks/ambient_context.py: ########## @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Ambient context management for the Global Task Framework (GTF)""" + +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Iterator + +from superset.tasks.context import TaskContext + +# Global context variable for ambient context pattern +# This is thread-safe and async-safe via Python's contextvars +_current_context: ContextVar[TaskContext | None] = ContextVar( + "task_context", default=None +) + + +def get_context() -> TaskContext: + """ + Get the current task context from contextvars. + + This function provides ambient access to the task context without + requiring it to be passed as a parameter. It can only be called + from within a task execution. + + :returns: The current TaskContext + :raises RuntimeError: If called outside a task execution context + + Example: + >>> @task() + >>> def my_task(chart_id: int) -> None: + >>> ctx = get_context() # Access ambient context + >>> + >>> # Update progress and payload atomically + >>> ctx.update_task( + >>> progress=0.5, + >>> payload={"chart_id": chart_id} + >>> ) + """ + ctx = _current_context.get() + if ctx is None: + raise RuntimeError( + "get_context() called outside task execution context. 
" + "This function can only be called from within an @async_task " Review Comment: ```suggestion "This function can only be called from within an @task " ``` ########## superset/tasks/manager.py: ########## @@ -0,0 +1,762 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Task manager for the Global Task Framework (GTF)""" + +from __future__ import annotations + +import logging +import threading +import time +from typing import Any, Callable, TYPE_CHECKING + +import redis +from superset_core.api.tasks import TaskProperties, TaskScope + +from superset.async_events.cache_backend import ( + RedisCacheBackend, + RedisSentinelCacheBackend, +) +from superset.extensions import cache_manager +from superset.tasks.constants import ABORT_STATES, TERMINAL_STATES +from superset.tasks.utils import generate_random_task_key + +if TYPE_CHECKING: + from flask import Flask + + from superset.models.tasks import Task + +logger = logging.getLogger(__name__) + + +class AbortListener: + """ + Handle for a background abort listener. + + Returned by TaskManager.listen_for_abort() to allow stopping the listener. 
+ """ + + def __init__( + self, + task_uuid: str, + thread: threading.Thread, + stop_event: threading.Event, + pubsub: redis.client.PubSub | None = None, + ) -> None: + self._task_uuid = task_uuid + self._thread = thread + self._stop_event = stop_event + self._pubsub = pubsub + + def stop(self) -> None: + """Stop the abort listener.""" + self._stop_event.set() + + # Close pub/sub subscription if active + if self._pubsub is not None: + try: + self._pubsub.unsubscribe() + self._pubsub.close() + except Exception as ex: + logger.debug("Error closing pub/sub during stop: %s", ex) + + # Wait for thread to finish (with timeout to avoid blocking indefinitely) + if self._thread.is_alive(): + self._thread.join(timeout=2.0) + + # Check if thread is still running after timeout + if self._thread.is_alive(): + # Thread is a daemon, so it will be killed when process exits. + # Log warning but continue - cleanup will still proceed. + logger.warning( + "Abort listener thread for task %s did not terminate within " + "2 seconds. Thread will be terminated when process exits.", + self._task_uuid, + ) + else: + logger.debug("Stopped abort listener for task %s", self._task_uuid) + else: + logger.debug("Stopped abort listener for task %s", self._task_uuid) + + +class TaskManager: + """ + Handles task creation, scheduling, and abort notifications. + + The TaskManager is responsible for: + 1. Creating task entries in the metastore (Task model) + 2. Scheduling task execution via Celery + 3. Handling deduplication (returning existing active task if duplicate) + 4. Managing real-time abort notifications (optional) + + Redis pub/sub is opt-in via SIGNAL_CACHE_CONFIG configuration. When not + configured, tasks use database polling for abort detection. 
+ """ + + # Class-level state (initialized once via init_app) + _channel_prefix: str = "gtf:abort:" + _completion_channel_prefix: str = "gtf:complete:" + _initialized: bool = False + + # Backward compatibility alias - prefer importing from superset.tasks.constants + TERMINAL_STATES = TERMINAL_STATES + + @classmethod + def init_app(cls, app: Flask) -> None: + """ + Initialize the TaskManager with Flask app config. + + Redis connection is managed by CacheManager - this just reads channel prefixes. + + :param app: Flask application instance + """ + if cls._initialized: + return + + cls._channel_prefix = app.config.get("TASKS_ABORT_CHANNEL_PREFIX", "gtf:abort:") + cls._completion_channel_prefix = app.config.get( + "TASKS_COMPLETION_CHANNEL_PREFIX", "gtf:complete:" + ) + + cls._initialized = True + + @classmethod + def _get_cache(cls) -> RedisCacheBackend | RedisSentinelCacheBackend | None: + """ + Get the signal cache backend. + + :returns: The signal cache backend, or None if not configured + """ + return cache_manager.signal_cache + + @classmethod + def is_pubsub_available(cls) -> bool: + """ + Check if Redis pub/sub backend is configured and available. + + :returns: True if Redis is available for pub/sub, False otherwise + """ + return cls._get_cache() is not None + + @classmethod + def get_abort_channel(cls, task_uuid: str) -> str: + """ + Get the abort channel name for a task. + + :param task_uuid: UUID of the task + :returns: Channel name for the task's abort notifications + """ + return f"{cls._channel_prefix}{task_uuid}" + + @classmethod + def publish_abort(cls, task_uuid: str) -> bool: + """ + Publish an abort message to the task's channel. 
+ + :param task_uuid: UUID of the task to abort + :returns: True if message was published, False if Redis unavailable + """ + cache = cls._get_cache() + if not cache: + return False + + try: + channel = cls.get_abort_channel(task_uuid) + subscriber_count = cache.publish(channel, "abort") + logger.debug( + "Published abort to channel %s (%d subscribers)", + channel, + subscriber_count, + ) + return True + except redis.RedisError as ex: + logger.error("Failed to publish abort for task %s: %s", task_uuid, ex) + return False + + @classmethod + def get_completion_channel(cls, task_uuid: str) -> str: + """ + Get the completion channel name for a task. + + :param task_uuid: UUID of the task + :returns: Channel name for the task's completion notifications + """ + return f"{cls._completion_channel_prefix}{task_uuid}" + + @classmethod + def publish_completion(cls, task_uuid: str, status: str) -> bool: + """ + Publish a completion message to the task's channel. + + Called when task reaches terminal state (SUCCESS, FAILURE, ABORTED, TIMED_OUT). + This notifies any waiters (e.g., sync callers waiting for an existing task). + + :param task_uuid: UUID of the completed task + :param status: Final status of the task + :returns: True if message was published, False if Redis unavailable + """ + cache = cls._get_cache() + if not cache: + return False + + try: + channel = cls.get_completion_channel(task_uuid) + subscriber_count = cache.publish(channel, status) + logger.debug( + "Published completion to channel %s (status=%s, %d subscribers)", + channel, + status, + subscriber_count, + ) + return True + except redis.RedisError as ex: + logger.error("Failed to publish completion for task %s: %s", task_uuid, ex) + return False + + @classmethod + def wait_for_completion( + cls, + task_uuid: str, + timeout: float | None = None, + poll_interval: float = 1.0, + app: Any = None, + ) -> "Task": + """ + Block until task reaches terminal state. 
+ + Uses Redis pub/sub if configured for low-latency, low-CPU waiting. + Uses database polling if Redis is not configured. + + :param task_uuid: UUID of the task to wait for + :param timeout: Maximum time to wait in seconds (None = no limit) + :param poll_interval: Interval for database polling (seconds) + :param app: Flask app for database access + :returns: Task in terminal state + :raises TimeoutError: If timeout expires before task completes + :raises ValueError: If task not found + """ + from superset.daos.tasks import TaskDAO + + start_time = time.monotonic() + + def time_remaining() -> float | None: + if timeout is None: + return None + elapsed = time.monotonic() - start_time + remaining = timeout - elapsed + return remaining if remaining > 0 else 0 + + def get_task() -> "Task | None": + if app: + with app.app_context(): + return TaskDAO.find_one_or_none(uuid=task_uuid) + return TaskDAO.find_one_or_none(uuid=task_uuid) + + # Check current state first + task = get_task() + if not task: + raise ValueError(f"Task {task_uuid} not found") + + if task.status in cls.TERMINAL_STATES: + return task + + logger.info( Review Comment: It seems all `logger.info` calls in this file could be `logger.debug`? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
