villebro commented on code in PR #36368:
URL: https://github.com/apache/superset/pull/36368#discussion_r2749796624
##########
superset/daos/tasks.py:
##########
@@ -0,0 +1,309 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Task DAO for Global Task Framework (GTF)"""
+
+import logging
+from datetime import datetime, timezone
+from typing import Any
+
+from superset_core.api.tasks import TaskProperties, TaskScope, TaskStatus
+
+from superset.daos.base import BaseDAO
+from superset.daos.exceptions import DAODeleteFailedError
+from superset.extensions import db
+from superset.models.task_subscribers import TaskSubscriber
+from superset.models.tasks import Task
+from superset.tasks.constants import ABORTABLE_STATES
+from superset.tasks.filters import TaskFilter
+from superset.tasks.utils import get_active_dedup_key
+
+logger = logging.getLogger(__name__)
+
+
+class TaskDAO(BaseDAO[Task]):
+    """
+    Concrete TaskDAO for the Global Task Framework (GTF).
+
+    Provides database access operations for async tasks including
+    creation, status management, filtering, and subscription management
+    for shared tasks.
+    """
+
+    base_filter = TaskFilter
+
+    @classmethod
+    def find_by_task_key(
+        cls,
+        task_type: str,
+        task_key: str,
+        scope: TaskScope | str = TaskScope.PRIVATE,
+        user_id: int | None = None,
+    ) -> Task | None:
+        """
+        Find active task by type, key, scope, and user.
+
+        Uses dedup_key internally for efficient querying with a unique index.
+        Only returns tasks that are active (pending or in progress).
+
+        Uniqueness logic by scope:
+        - private: scope + task_type + task_key + user_id
+        - shared/system: scope + task_type + task_key (user-agnostic)
+
+        :param task_type: Task type to filter by
+        :param task_key: Task identifier for deduplication
+        :param scope: Task scope (private/shared/system)
+        :param user_id: User ID (required for private tasks)
+        :returns: Task instance or None if not found or not active
+        """
+        dedup_key = get_active_dedup_key(
+            scope=scope,
+            task_type=task_type,
+            task_key=task_key,
+            user_id=user_id,
+        )
+
+        # Simple single-column query with unique index
+        return db.session.query(Task).filter(Task.dedup_key == dedup_key).one_or_none()
+
+    @classmethod
+    def create_task(
+        cls,
+        task_type: str,
+        task_key: str,
+        scope: TaskScope | str = TaskScope.PRIVATE,
+        user_id: int | None = None,
+        payload: dict[str, Any] | None = None,
+        properties: TaskProperties | None = None,
+        **kwargs: Any,
+    ) -> Task:
+        """
+        Create a new task record in the database.
+
+        This is a pure data operation - assumes caller holds lock and has
+        already checked for existing tasks. Business logic (create vs join)
+        is handled by SubmitTaskCommand.
+
+        :param task_type: Type of task to create
+        :param task_key: Task identifier (required)
+        :param scope: Task scope (private/shared/system), defaults to private
+        :param user_id: User ID creating the task
+        :param payload: Optional user-defined context data (dict)
+        :param properties: Optional framework-managed runtime state (e.g., timeout)
+        :param kwargs: Additional task attributes (e.g., task_name)
+        :returns: Created Task instance
+        """
+        # Build dedup_key for active task
+        dedup_key = get_active_dedup_key(
+            scope=scope,
+            task_type=task_type,
+            task_key=task_key,
+            user_id=user_id,
+        )
+
+        # Handle both TaskScope enum and string values
+        scope_value = scope.value if isinstance(scope, TaskScope) else scope
+
+        # Note: properties is handled separately via update_properties()
+        # because it's a hybrid property with only a getter
+        task_data = {
+            "task_type": task_type,
+            "task_key": task_key,
+            "scope": scope_value,
+            "status": TaskStatus.PENDING.value,
+            "dedup_key": dedup_key,
+            **kwargs,
+        }
+
+        # Handle payload - serialize to JSON if dict provided
+        if payload:
+            from superset.utils import json
+
+            task_data["payload"] = json.dumps(payload)
+
+        if user_id is not None:
+            task_data["user_id"] = user_id
+
+        task = cls.create(attributes=task_data)
+
+        # Set properties after creation (hybrid property with getter only)
+        if properties:
+            task.update_properties(properties)
+
+        # Flush to get the task ID (auto-incremented primary key)
+        db.session.flush()
+
+        # Auto-subscribe creator for all tasks
+        # This enables consistent subscriber display across all task types
+        if user_id:
+            cls.add_subscriber(task.id, user_id)
+            logger.info(
+                "Creator %s auto-subscribed to task: %s (scope: %s)",
+                user_id,
+                task_key,
+                scope_value,
+            )
+
+        logger.info(
+            "Created new async task: %s (type: %s, scope: %s)",
+            task_key,
+            task_type,
+            scope_value,
+        )
+        return task
+
+    @classmethod
+    def abort_task(cls, task_uuid: str, skip_base_filter: bool = False) -> Task | None:
+        """
+        Abort a task by UUID.
+
+        This is a pure data operation. Business logic (subscriber count checks,
+        permission validation) is handled by CancelTaskCommand which holds the lock.
+
+        Abort behavior by status:
+        - PENDING: Goes directly to ABORTED (always abortable)
+        - IN_PROGRESS with is_abortable=True: Goes to ABORTING
+        - IN_PROGRESS with is_abortable=False/None: Raises TaskNotAbortableError
+        - ABORTING: Returns task (idempotent)
+        - Finished statuses: Returns None
+
+        Note: Caller is responsible for calling TaskManager.publish_abort() AFTER
+        the transaction commits if task.status == ABORTING. This prevents race
+        conditions where listeners check the DB before the status is visible.
+
+        :param task_uuid: UUID of task to abort
+        :param skip_base_filter: If True, skip base filter (for admin abortions)
+        :returns: Task if aborted/aborting, None if not found or already finished
+        :raises TaskNotAbortableError: If in-progress task has no abort handler
+        """
+        from superset.commands.tasks.exceptions import TaskNotAbortableError
+
+        task = cls.find_one_or_none(skip_base_filter=skip_base_filter, uuid=task_uuid)
+        if not task:
+            return None
+
+        # Already aborting - idempotent success
+        if task.status == TaskStatus.ABORTING.value:
+            logger.info("Task %s is already aborting", task_uuid)
+            return task
+
+        # Already finished - cannot abort
+        if task.status not in ABORTABLE_STATES:
+            return None
+
+        # PENDING: Go directly to ABORTED
+        if task.status == TaskStatus.PENDING.value:
+            task.set_status(TaskStatus.ABORTED)
+            logger.info("Aborted pending task: %s (scope: %s)", task_uuid, task.scope)
+            return task
+
+        # IN_PROGRESS: Check if abortable
+        if task.status == TaskStatus.IN_PROGRESS.value:
+            if task.properties.get("is_abortable") is not True:
+                raise TaskNotAbortableError(
+                    f"Task {task_uuid} is in progress but has not registered "
+                    "an abort handler (is_abortable is not true)"
+                )

+            # Transition to ABORTING (not ABORTED yet)
+            task.status = TaskStatus.ABORTING.value
+            db.session.merge(task)
+            logger.info("Set task %s to ABORTING (scope: %s)", task_uuid, task.scope)
+
+            # NOTE: publish_abort is NOT called here - caller handles it after commit
+            # This prevents race conditions where listeners check DB before commit
+
+            return task
+
+        return None
+
+    # Subscription management methods
+
+    @classmethod
+    def add_subscriber(cls, task_id: int, user_id: int) -> bool:
+        """
+        Add a user as a subscriber to a task.
+
+        :param task_id: ID of the task
+        :param user_id: ID of the user to subscribe
+        :returns: True if subscriber was added, False if already exists
+        """
+        # Check first to avoid IntegrityError which invalidates the session
+        # in nested transaction contexts (IntegrityError can't be recovered from)
+        existing = (
+            db.session.query(TaskSubscriber)
+            .filter_by(task_id=task_id, user_id=user_id)
+            .first()
+        )
+        if existing:
+            logger.debug(
+                "Subscriber %s already subscribed to task %s", user_id, task_id
+            )
+            return False
+
+        subscription = TaskSubscriber(
+            task_id=task_id,
+            user_id=user_id,
+            subscribed_at=datetime.now(timezone.utc),
+        )
+        db.session.add(subscription)
+        db.session.flush()

Review Comment:
   We have locking for this
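To make the comment concrete: if the command layer already holds a per-task lock when it calls into the DAO (as the `create_task` docstring assumes), the SELECT-then-INSERT in `add_subscriber` cannot race with a concurrent subscriber, so the `IntegrityError` concern largely goes away. Below is a minimal sketch of that call pattern; `task_lock` and `subscribe_user` are hypothetical stand-ins, not the framework's actual locking API.

```python
from contextlib import contextmanager
from typing import Iterator

from superset.daos.tasks import TaskDAO  # module added in this PR
from superset.models.tasks import Task  # Task model referenced by this PR


@contextmanager
def task_lock(task_uuid: str) -> Iterator[None]:
    """Hypothetical stand-in for GTF's per-task distributed lock.

    The real implementation would acquire a lock keyed on the task UUID on
    enter and release it on exit; this placeholder only shows the shape.
    """
    # acquire(f"gtf:task:{task_uuid}")  # assumed framework primitive
    try:
        yield
    finally:
        # release(f"gtf:task:{task_uuid}")  # assumed framework primitive
        pass


def subscribe_user(task: Task, user_id: int) -> bool:
    """Command-layer helper that serializes subscribe operations per task.

    With the lock held, the SELECT-then-INSERT inside TaskDAO.add_subscriber
    cannot race with another subscriber, so the pre-check only serves the
    idempotent True/False return value rather than guarding against an
    IntegrityError.
    """
    with task_lock(str(task.uuid)):
        return TaskDAO.add_subscriber(task.id, user_id)
```

Under that assumption, the DAO comment about `IntegrityError` invalidating the session could be softened to point at the lock instead, keeping the existence check purely for the idempotent return value.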
