This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c7cbced1267 Move SkipMixin and BranchMixIn to Task SDK (#62749)
c7cbced1267 is described below

commit c7cbced12678ca1c6b2b1dcbcdf83eec2b1897d7
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Mar 3 08:06:44 2026 +0000

    Move SkipMixin and BranchMixIn to Task SDK (#62749)
---
 airflow-core/src/airflow/models/__init__.py        |   7 +-
 airflow-core/src/airflow/models/skipmixin.py       |  22 ---
 airflow-core/src/airflow/operators/__init__.py     |   4 +-
 .../src/airflow/providers/common/compat/sdk.py     |   8 +
 .../providers/common/compat/standard/operators.py  |   2 +
 .../providers/common/compat/standard/utils.py      |  17 ++
 .../airflow/providers/standard/operators/branch.py |   2 +-
 .../providers/standard/operators/datetime.py       |   3 +-
 .../airflow/providers/standard/operators/hitl.py   |   4 +-
 .../providers/standard/operators/latest_only.py    |   2 +-
 .../airflow/providers/standard/operators/python.py |  10 +-
 .../providers/standard/operators/weekday.py        |   3 +-
 .../tests/unit/standard/operators/test_weekday.py  |   2 +-
 .../tests/unit/standard/utils/test_skipmixin.py    |   2 +-
 task-sdk/docs/api.rst                              |   6 +
 task-sdk/src/airflow/sdk/__init__.py               |   8 +
 task-sdk/src/airflow/sdk/__init__.pyi              |   8 +
 .../src/airflow/sdk/bases}/branch.py               |  11 +-
 task-sdk/src/airflow/sdk/bases/skipmixin.py        | 168 +++++++++++++++++
 .../src/airflow/sdk/definitions/mappedoperator.py  |   2 +-
 task-sdk/tests/task_sdk/bases/test_branch.py       | 121 ++++++++++++
 task-sdk/tests/task_sdk/bases/test_skipmixin.py    | 202 +++++++++++++++++++++
 task-sdk/tests/task_sdk/docs/test_public_api.py    |   1 +
 23 files changed, 561 insertions(+), 54 deletions(-)

diff --git a/airflow-core/src/airflow/models/__init__.py b/airflow-core/src/airflow/models/__init__.py
index 400151124e0..8e12325f568 100644
--- a/airflow-core/src/airflow/models/__init__.py
+++ b/airflow-core/src/airflow/models/__init__.py
@@ -116,7 +116,7 @@ __lazy_imports = {
     "Param": "airflow.sdk.definitions.param",
     "Pool": "airflow.models.pool",
     "RenderedTaskInstanceFields": "airflow.models.renderedtifields",
-    "SkipMixin": "airflow.models.skipmixin",
+    "SkipMixin": "airflow.sdk.bases.skipmixin",
     "TaskInstance": "airflow.models.taskinstance",
     "TaskReschedule": "airflow.models.taskreschedule",
     "Team": "airflow.models.team",
@@ -142,13 +142,13 @@ if TYPE_CHECKING:
     from airflow.models.log import Log
     from airflow.models.pool import Pool
     from airflow.models.renderedtifields import RenderedTaskInstanceFields
-    from airflow.models.skipmixin import SkipMixin
     from airflow.models.taskinstance import TaskInstance, clear_task_instances
     from airflow.models.taskinstancehistory import TaskInstanceHistory
     from airflow.models.taskreschedule import TaskReschedule
     from airflow.models.trigger import Trigger
     from airflow.models.variable import Variable
     from airflow.sdk import DAG, BaseOperator, BaseOperatorLink, Param
+    from airflow.sdk.bases.skipmixin import SkipMixin
     from airflow.sdk.bases.xcom import BaseXCom
     from airflow.sdk.definitions.mappedoperator import MappedOperator
     from airflow.sdk.execution_time.xcom import XCom
@@ -176,6 +176,9 @@ __deprecated_classes = {
     "baseoperatorlink": {
         "BaseOperatorLink": "airflow.sdk.BaseOperatorLink",
     },
+    "skipmixin": {
+        "SkipMixin": "airflow.sdk.bases.skipmixin.SkipMixin",
+    },
     "operator": {
         "BaseOperator": "airflow.sdk.BaseOperator",
         "Operator": "airflow.sdk.types.Operator",
diff --git a/airflow-core/src/airflow/models/skipmixin.py b/airflow-core/src/airflow/models/skipmixin.py
deleted file mode 100644
index 8e78881e469..00000000000
--- a/airflow-core/src/airflow/models/skipmixin.py
+++ /dev/null
@@ -1,22 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from airflow.providers.standard.utils.skipmixin import SkipMixin
-
-__all__ = ["SkipMixin"]
diff --git a/airflow-core/src/airflow/operators/__init__.py b/airflow-core/src/airflow/operators/__init__.py
index 84a526c134f..d377aef37b1 100644
--- a/airflow-core/src/airflow/operators/__init__.py
+++ b/airflow-core/src/airflow/operators/__init__.py
@@ -65,8 +65,8 @@ __deprecated_classes = {
         "SmoothOperator": 
"airflow.providers.standard.operators.smooth.SmoothOperator",
     },
     "branch": {
-        "BranchMixIn": 
"airflow.providers.standard.operators.branch.BranchMixIn",
-        "BaseBranchOperator": 
"airflow.providers.standard.operators.branch.BaseBranchOperator",
+        "BranchMixIn": "airflow.sdk.bases.branch.BranchMixIn",
+        "BaseBranchOperator": "airflow.sdk.bases.branch.BaseBranchOperator",
     }
 
 }
diff --git a/providers/common/compat/src/airflow/providers/common/compat/sdk.py b/providers/common/compat/src/airflow/providers/common/compat/sdk.py
index bcd28c4799f..30fa73ad076 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/sdk.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/sdk.py
@@ -70,6 +70,7 @@ if TYPE_CHECKING:
     )
     from airflow.sdk._shared.listeners import hookimpl as hookimpl
     from airflow.sdk._shared.observability.metrics.stats import Stats as Stats
+    from airflow.sdk.bases.branch import BaseBranchOperator as 
BaseBranchOperator, BranchMixIn as BranchMixIn
     from airflow.sdk.bases.decorator import (
         DecoratedMappedOperator as DecoratedMappedOperator,
         DecoratedOperator as DecoratedOperator,
@@ -80,6 +81,7 @@ if TYPE_CHECKING:
         task_decorator_factory as task_decorator_factory,
     )
     from airflow.sdk.bases.sensor import poke_mode_only as poke_mode_only
+    from airflow.sdk.bases.skipmixin import SkipMixin as SkipMixin
     from airflow.sdk.configuration import conf as conf
     from airflow.sdk.definitions.context import context_merge as context_merge
     from airflow.sdk.definitions.mappedoperator import MappedOperator as 
MappedOperator
@@ -149,6 +151,12 @@ _IMPORT_MAP: dict[str, str | tuple[str, ...]] = {
     # 
============================================================================
     "BaseHook": ("airflow.sdk", "airflow.hooks.base"),
     # 
============================================================================
+    # Branching
+    # 
============================================================================
+    "BaseBranchOperator": ("airflow.sdk.bases.branch", 
"airflow.providers.standard.operators.branch"),
+    "BranchMixIn": ("airflow.sdk.bases.branch", 
"airflow.providers.standard.operators.branch"),
+    "SkipMixin": ("airflow.sdk.bases.skipmixin", "airflow.models.skipmixin"),
+    # 
============================================================================
     # Sensors
     # 
============================================================================
     "BaseSensorOperator": ("airflow.sdk", "airflow.sensors.base"),
diff --git a/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py b/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
index 6096b54638b..b916a9d5f9e 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
@@ -27,8 +27,10 @@ from airflow.providers.common.compat.version_compat import (
 
 _IMPORT_MAP: dict[str, str | tuple[str, ...]] = {
     # Re-export from sdk (which handles Airflow 2.x/3.x fallbacks)
+    "BaseBranchOperator": "airflow.providers.common.compat.sdk",
     "BaseOperator": "airflow.providers.common.compat.sdk",
     "BaseAsyncOperator": "airflow.providers.common.compat.sdk",
+    "BranchMixIn": "airflow.providers.common.compat.sdk",
     "get_current_context": "airflow.providers.common.compat.sdk",
     "is_async_callable": "airflow.providers.common.compat.sdk",
     # Standard provider items with direct fallbacks
diff --git a/providers/common/compat/src/airflow/providers/common/compat/standard/utils.py b/providers/common/compat/src/airflow/providers/common/compat/standard/utils.py
index 3f7f4b2962f..c2e15a6a377 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/standard/utils.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/standard/utils.py
@@ -20,6 +20,23 @@ from __future__ import annotations
 from airflow.providers.common.compat._compat_utils import create_module_getattr
 
 _IMPORT_MAP: dict[str, str | tuple[str, ...]] = {
+    "SkipMixin": (
+        "airflow.sdk.bases.skipmixin",
+        "airflow.providers.standard.utils.skipmixin",
+        "airflow.models.skipmixin",
+    ),
+    "XCOM_SKIPMIXIN_KEY": (
+        "airflow.sdk.bases.skipmixin",
+        "airflow.providers.standard.utils.skipmixin",
+    ),
+    "XCOM_SKIPMIXIN_SKIPPED": (
+        "airflow.sdk.bases.skipmixin",
+        "airflow.providers.standard.utils.skipmixin",
+    ),
+    "XCOM_SKIPMIXIN_FOLLOWED": (
+        "airflow.sdk.bases.skipmixin",
+        "airflow.providers.standard.utils.skipmixin",
+    ),
     "write_python_script": (
         "airflow.providers.standard.utils.python_virtualenv",
         "airflow.utils.python_virtualenv",
diff --git a/providers/standard/src/airflow/providers/standard/operators/branch.py b/providers/standard/src/airflow/providers/standard/operators/branch.py
index 64796bd0f7a..f1ed3bfdf66 100644
--- a/providers/standard/src/airflow/providers/standard/operators/branch.py
+++ b/providers/standard/src/airflow/providers/standard/operators/branch.py
@@ -27,7 +27,7 @@ from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseOp
 if AIRFLOW_V_3_0_PLUS:
     from airflow.providers.standard.utils.skipmixin import SkipMixin
 else:
-    from airflow.models.skipmixin import SkipMixin
+    from airflow.models.skipmixin import SkipMixin  # type: ignore[no-redef]
 
 if TYPE_CHECKING:
     from airflow.providers.common.compat.sdk import Context
diff --git a/providers/standard/src/airflow/providers/standard/operators/datetime.py b/providers/standard/src/airflow/providers/standard/operators/datetime.py
index 914fc20828b..589e106fe31 100644
--- a/providers/standard/src/airflow/providers/standard/operators/datetime.py
+++ b/providers/standard/src/airflow/providers/standard/operators/datetime.py
@@ -20,8 +20,7 @@ import datetime
 from collections.abc import Iterable
 from typing import TYPE_CHECKING
 
-from airflow.providers.common.compat.sdk import AirflowException, timezone
-from airflow.providers.standard.operators.branch import BaseBranchOperator
+from airflow.providers.common.compat.sdk import AirflowException, 
BaseBranchOperator, timezone
 
 if TYPE_CHECKING:
     from airflow.providers.common.compat.sdk import Context
diff --git a/providers/standard/src/airflow/providers/standard/operators/hitl.py b/providers/standard/src/airflow/providers/standard/operators/hitl.py
index fc999bf8cd2..b009422be8d 100644
--- a/providers/standard/src/airflow/providers/standard/operators/hitl.py
+++ b/providers/standard/src/airflow/providers/standard/operators/hitl.py
@@ -28,11 +28,9 @@ from collections.abc import Collection, Mapping, Sequence
 from typing import TYPE_CHECKING, Any
 from urllib.parse import ParseResult, urlencode, urlparse, urlunparse
 
-from airflow.providers.common.compat.sdk import conf
+from airflow.providers.common.compat.sdk import BranchMixIn, SkipMixin, conf
 from airflow.providers.standard.exceptions import HITLRejectException, 
HITLTimeoutError, HITLTriggerEventError
-from airflow.providers.standard.operators.branch import BranchMixIn
 from airflow.providers.standard.triggers.hitl import HITLTrigger, 
HITLTriggerEventSuccessPayload
-from airflow.providers.standard.utils.skipmixin import SkipMixin
 from airflow.providers.standard.version_compat import BaseOperator
 from airflow.sdk.bases.notifier import BaseNotifier
 from airflow.sdk.definitions.param import ParamsDict
diff --git a/providers/standard/src/airflow/providers/standard/operators/latest_only.py b/providers/standard/src/airflow/providers/standard/operators/latest_only.py
index 087a8607d45..54f4df48983 100644
--- a/providers/standard/src/airflow/providers/standard/operators/latest_only.py
+++ b/providers/standard/src/airflow/providers/standard/operators/latest_only.py
@@ -25,7 +25,7 @@ from typing import TYPE_CHECKING
 
 import pendulum
 
-from airflow.providers.standard.operators.branch import BaseBranchOperator
+from airflow.providers.common.compat.sdk import BaseBranchOperator
 from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, 
AIRFLOW_V_3_2_PLUS
 from airflow.utils.types import DagRunType
 
diff --git a/providers/standard/src/airflow/providers/standard/operators/python.py b/providers/standard/src/airflow/providers/standard/operators/python.py
index cba2c1f31dd..31df8b747e9 100644
--- a/providers/standard/src/airflow/providers/standard/operators/python.py
+++ b/providers/standard/src/airflow/providers/standard/operators/python.py
@@ -50,7 +50,9 @@ from airflow.models.variable import Variable
 from airflow.providers.common.compat.sdk import (
     AirflowException,
     AirflowSkipException,
+    BaseBranchOperator,
     KeywordParameters,
+    SkipMixin,
     context_merge,
 )
 from airflow.providers.common.compat.standard.operators import (
@@ -67,14 +69,6 @@ from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLO
 from airflow.utils import hashlib_wrapper
 from airflow.utils.file import get_unique_dag_module_name
 
-if AIRFLOW_V_3_0_PLUS:
-    from airflow.providers.standard.operators.branch import BaseBranchOperator
-    from airflow.providers.standard.utils.skipmixin import SkipMixin
-else:
-    from airflow.models.skipmixin import SkipMixin
-    from airflow.operators.branch import BaseBranchOperator  # type: 
ignore[no-redef]
-
-
 log = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
diff --git a/providers/standard/src/airflow/providers/standard/operators/weekday.py b/providers/standard/src/airflow/providers/standard/operators/weekday.py
index 68c2c7cbe2a..e663bc7e4f0 100644
--- a/providers/standard/src/airflow/providers/standard/operators/weekday.py
+++ b/providers/standard/src/airflow/providers/standard/operators/weekday.py
@@ -20,8 +20,7 @@ from __future__ import annotations
 from collections.abc import Iterable
 from typing import TYPE_CHECKING
 
-from airflow.providers.common.compat.sdk import timezone
-from airflow.providers.standard.operators.branch import BaseBranchOperator
+from airflow.providers.common.compat.sdk import BaseBranchOperator, timezone
 from airflow.providers.standard.utils.weekday import WeekDay
 
 if TYPE_CHECKING:
diff --git a/providers/standard/tests/unit/standard/operators/test_weekday.py b/providers/standard/tests/unit/standard/operators/test_weekday.py
index fbb6af3555c..9836a8b9efb 100644
--- a/providers/standard/tests/unit/standard/operators/test_weekday.py
+++ b/providers/standard/tests/unit/standard/operators/test_weekday.py
@@ -26,9 +26,9 @@ from sqlalchemy import delete
 from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance as TI
 from airflow.providers.common.compat.sdk import AirflowException
+from airflow.providers.common.compat.standard.utils import 
XCOM_SKIPMIXIN_FOLLOWED, XCOM_SKIPMIXIN_KEY
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.providers.standard.operators.weekday import 
BranchDayOfWeekOperator
-from airflow.providers.standard.utils.skipmixin import 
XCOM_SKIPMIXIN_FOLLOWED, XCOM_SKIPMIXIN_KEY
 from airflow.providers.standard.utils.weekday import WeekDay
 from airflow.timetables.base import DataInterval
 from airflow.utils import timezone
diff --git a/providers/standard/tests/unit/standard/utils/test_skipmixin.py b/providers/standard/tests/unit/standard/utils/test_skipmixin.py
index 22611a57d21..ae1dbfa5963 100644
--- a/providers/standard/tests/unit/standard/utils/test_skipmixin.py
+++ b/providers/standard/tests/unit/standard/utils/test_skipmixin.py
@@ -44,7 +44,7 @@ if AIRFLOW_V_3_0_PLUS:
     from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 else:
     from airflow.decorators import task, task_group  # type: 
ignore[attr-defined,no-redef]
-    from airflow.models.skipmixin import SkipMixin
+    from airflow.models.skipmixin import SkipMixin  # type: ignore[no-redef]
 
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
 DEFAULT_DAG_RUN_ID = "test1"
diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst
index 3d25e470183..642cff1e34a 100644
--- a/task-sdk/docs/api.rst
+++ b/task-sdk/docs/api.rst
@@ -81,6 +81,8 @@ Bases
 -----
 .. autoapiclass:: airflow.sdk.BaseAsyncOperator
 
+.. autoapiclass:: airflow.sdk.BaseBranchOperator
+
 .. autoapiclass:: airflow.sdk.BaseOperator
 
 .. autoapiclass:: airflow.sdk.BaseSensorOperator
@@ -91,8 +93,12 @@ Bases
 
 .. autoapiclass:: airflow.sdk.BaseXCom
 
+.. autoapiclass:: airflow.sdk.BranchMixIn
+
 .. autoapiclass:: airflow.sdk.PokeReturnValue
 
+.. autoapiclass:: airflow.sdk.SkipMixin
+
 .. autoapiclass:: airflow.sdk.BaseHook
 
 Callbacks
diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py
index 7fd3f962bb5..22c12e74415 100644
--- a/task-sdk/src/airflow/sdk/__init__.py
+++ b/task-sdk/src/airflow/sdk/__init__.py
@@ -28,11 +28,13 @@ __all__ = [
     "AssetWatcher",
     "AsyncCallback",
     "BaseAsyncOperator",
+    "BaseBranchOperator",
     "BaseHook",
     "BaseNotifier",
     "BaseOperator",
     "BaseOperatorLink",
     "BaseSensorOperator",
+    "BranchMixIn",
     "Connection",
     "Context",
     "CronDataIntervalTimetable",
@@ -60,6 +62,7 @@ __all__ = [
     "PartitionMapper",
     "PokeReturnValue",
     "QuarterlyMapper",
+    "SkipMixin",
     "SyncCallback",
     "TaskGroup",
     "TaskInstanceState",
@@ -90,6 +93,7 @@ __version__ = "1.2.0"
 
 if TYPE_CHECKING:
     from airflow.sdk.api.datamodels._generated import DagRunState, 
TaskInstanceState, TriggerRule, WeightRule
+    from airflow.sdk.bases.branch import BaseBranchOperator, BranchMixIn
     from airflow.sdk.bases.hook import BaseHook
     from airflow.sdk.bases.notifier import BaseNotifier
     from airflow.sdk.bases.operator import (
@@ -101,6 +105,7 @@ if TYPE_CHECKING:
     )
     from airflow.sdk.bases.operatorlink import BaseOperatorLink
     from airflow.sdk.bases.sensor import BaseSensorOperator, PokeReturnValue
+    from airflow.sdk.bases.skipmixin import SkipMixin
     from airflow.sdk.configuration import AirflowSDKConfigParser
     from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, 
AssetAny, AssetWatcher
     from airflow.sdk.definitions.asset.decorators import asset
@@ -158,11 +163,13 @@ __lazy_imports: dict[str, str] = {
     "AssetWatcher": ".definitions.asset",
     "AsyncCallback": ".definitions.callback",
     "BaseAsyncOperator": ".bases.operator",
+    "BaseBranchOperator": ".bases.branch",
     "BaseHook": ".bases.hook",
     "BaseNotifier": ".bases.notifier",
     "BaseOperator": ".bases.operator",
     "BaseOperatorLink": ".bases.operatorlink",
     "BaseSensorOperator": ".bases.sensor",
+    "BranchMixIn": ".bases.branch",
     "Connection": ".definitions.connection",
     "Context": ".definitions.context",
     "CronDataIntervalTimetable": ".definitions.timetables.interval",
@@ -191,6 +198,7 @@ __lazy_imports: dict[str, str] = {
     "PokeReturnValue": ".bases.sensor",
     "QuarterlyMapper": ".definitions.partition_mappers.temporal",
     "SecretCache": ".execution_time.cache",
+    "SkipMixin": ".bases.skipmixin",
     "SyncCallback": ".definitions.callback",
     "TaskGroup": ".definitions.taskgroup",
     "TaskInstanceState": ".api.datamodels._generated",
diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi
index bbea94cb077..60b87aeec9a 100644
--- a/task-sdk/src/airflow/sdk/__init__.pyi
+++ b/task-sdk/src/airflow/sdk/__init__.pyi
@@ -21,6 +21,10 @@ from airflow.sdk.api.datamodels._generated import (
     TriggerRule as TriggerRule,
     WeightRule as WeightRule,
 )
+from airflow.sdk.bases.branch import (
+    BaseBranchOperator as BaseBranchOperator,
+    BranchMixIn as BranchMixIn,
+)
 from airflow.sdk.bases.hook import BaseHook as BaseHook
 from airflow.sdk.bases.notifier import BaseNotifier as BaseNotifier
 from airflow.sdk.bases.operator import (
@@ -35,6 +39,7 @@ from airflow.sdk.bases.sensor import (
     BaseSensorOperator as BaseSensorOperator,
     PokeReturnValue as PokeReturnValue,
 )
+from airflow.sdk.bases.skipmixin import SkipMixin as SkipMixin
 from airflow.sdk.configuration import AirflowSDKConfigParser
 from airflow.sdk.definitions.asset import (
     Asset as Asset,
@@ -100,11 +105,13 @@ __all__ = [
     "AssetOrTimeSchedule",
     "AssetWatcher",
     "BaseAsyncOperator",
+    "BaseBranchOperator",
     "BaseHook",
     "BaseNotifier",
     "BaseOperator",
     "BaseOperatorLink",
     "BaseSensorOperator",
+    "BranchMixIn",
     "Connection",
     "Context",
     "CronDataIntervalTimetable",
@@ -130,6 +137,7 @@ __all__ = [
     "PartitionMapper",
     "QuarterlyMapper",
     "SecretCache",
+    "SkipMixin",
     "TaskGroup",
     "TaskInstanceState",
     "TriggerRule",
diff --git a/providers/standard/src/airflow/providers/standard/operators/branch.py b/task-sdk/src/airflow/sdk/bases/branch.py
similarity index 91%
copy from providers/standard/src/airflow/providers/standard/operators/branch.py
copy to task-sdk/src/airflow/sdk/bases/branch.py
index 64796bd0f7a..97919b02ca4 100644
--- a/providers/standard/src/airflow/providers/standard/operators/branch.py
+++ b/task-sdk/src/airflow/sdk/bases/branch.py
@@ -22,15 +22,11 @@ from __future__ import annotations
 from collections.abc import Iterable
 from typing import TYPE_CHECKING
 
-from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, 
BaseOperator
-
-if AIRFLOW_V_3_0_PLUS:
-    from airflow.providers.standard.utils.skipmixin import SkipMixin
-else:
-    from airflow.models.skipmixin import SkipMixin
+from airflow.sdk.bases.operator import BaseOperator
+from airflow.sdk.bases.skipmixin import SkipMixin
 
 if TYPE_CHECKING:
-    from airflow.providers.common.compat.sdk import Context
+    from airflow.sdk.definitions.context import Context
     from airflow.sdk.types import RuntimeTaskInstanceProtocol
 
 
@@ -43,7 +39,6 @@ class BranchMixIn(SkipMixin):
         """Implement the handling of branching including logging."""
         self.log.info("Branch into %s", branches_to_execute)
         if branches_to_execute is None:
-            # When None is returned, skip all downstream tasks
             self.skip_all_except(context["ti"], None)
         else:
             branch_task_ids = self._expand_task_group_roots(context["ti"], 
branches_to_execute)
diff --git a/task-sdk/src/airflow/sdk/bases/skipmixin.py b/task-sdk/src/airflow/sdk/bases/skipmixin.py
new file mode 100644
index 00000000000..f4dd25f5f84
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/bases/skipmixin.py
@@ -0,0 +1,168 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from collections.abc import Iterable, Sequence
+from types import GeneratorType
+from typing import TYPE_CHECKING
+
+from airflow.sdk.definitions._internal.logging_mixin import LoggingMixin
+from airflow.sdk.exceptions import AirflowException, DownstreamTasksSkipped
+
+if TYPE_CHECKING:
+    from airflow.sdk.definitions._internal.node import DAGNode
+    from airflow.sdk.types import Operator, RuntimeTaskInstanceProtocol
+
+XCOM_SKIPMIXIN_KEY = "skipmixin_key"
+XCOM_SKIPMIXIN_SKIPPED = "skipped"
+XCOM_SKIPMIXIN_FOLLOWED = "followed"
+
+
+def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]:
+    # Circular: BaseOperator imports SkipMixin
+    from airflow.sdk.bases.operator import BaseOperator
+    from airflow.sdk.definitions.mappedoperator import MappedOperator
+
+    return [n for n in nodes if isinstance(n, (BaseOperator, MappedOperator))]
+
+
+class SkipMixin(LoggingMixin):
+    """A Mixin to skip Tasks Instances."""
+
+    @staticmethod
+    def _set_state_to_skipped(
+        tasks: Sequence[str | tuple[str, int]],
+        map_index: int | None,
+    ) -> None:
+        """Set state of task instances to skipped from the same dag run."""
+        # Only for non-mapped tasks — future mapped tasks have not been 
expanded yet
+        # and are handled by NotPreviouslySkippedDep.
+        if tasks and map_index == -1:
+            raise DownstreamTasksSkipped(tasks=tasks)
+
+    def skip(
+        self,
+        ti: RuntimeTaskInstanceProtocol,
+        tasks: Iterable[DAGNode],
+    ):
+        """
+        Set tasks instances to skipped from the same dag run.
+
+        If this instance has a ``task_id`` attribute, stores skipped task IDs 
to XCom
+        so that NotPreviouslySkippedDep knows these tasks should be skipped 
when cleared.
+        """
+        task_id: str | None = getattr(self, "task_id", None)
+        task_list = _ensure_tasks(tasks)
+        if not task_list:
+            return
+
+        task_ids_list = [d.task_id for d in task_list]
+
+        if task_id is not None:
+            ti.xcom_push(
+                key=XCOM_SKIPMIXIN_KEY,
+                value={XCOM_SKIPMIXIN_SKIPPED: task_ids_list},
+            )
+
+        self._set_state_to_skipped(task_ids_list, ti.map_index)
+
+    def skip_all_except(
+        self,
+        ti: RuntimeTaskInstanceProtocol,
+        branch_task_ids: None | str | Iterable[str],
+    ):
+        """
+        Implement the logic for a branching operator.
+
+        Given a single task ID or list of task IDs to follow, this skips all 
other tasks
+        immediately downstream of this operator.
+
+        branch_task_ids is stored to XCom so that NotPreviouslySkippedDep 
knows skipped tasks or
+        newly added tasks should be skipped when they are cleared.
+        """
+        if branch_task_ids and isinstance(branch_task_ids, GeneratorType):
+            branch_task_ids = list(branch_task_ids)
+        log = self.log  # Note: need to catch logger form instance, static 
logger breaks pytest
+        if isinstance(branch_task_ids, str):
+            branch_task_id_set = {branch_task_ids}
+        elif isinstance(branch_task_ids, Iterable):
+            # Handle the case where invalid values are passed as elements of 
an Iterable
+            # Non-string values are considered invalid elements
+            branch_task_id_set = set(branch_task_ids)
+            invalid_task_ids_type = {
+                (bti, type(bti).__name__) for bti in branch_task_id_set if not 
isinstance(bti, str)
+            }
+            if invalid_task_ids_type:
+                raise AirflowException(
+                    f"Unable to branch to the specified tasks. "
+                    f"The branching function returned invalid 
'branch_task_ids': {invalid_task_ids_type}. "
+                    f"Please check that your function returns an Iterable of 
valid task IDs that exist in your DAG."
+                )
+        elif branch_task_ids is None:
+            branch_task_id_set = set()
+        else:
+            raise AirflowException(
+                "'branch_task_ids' must be either None, a task ID, or an 
Iterable of IDs, "
+                f"but got {type(branch_task_ids).__name__!r}."
+            )
+
+        log.info("Following branch %s", branch_task_id_set)
+
+        if TYPE_CHECKING:
+            assert ti.task
+
+        task = ti.task
+        dag = ti.task.dag
+
+        valid_task_ids = set(dag.task_ids)
+        invalid_task_ids = branch_task_id_set - valid_task_ids
+        if invalid_task_ids:
+            raise AirflowException(
+                "'branch_task_ids' must contain only valid task_ids. "
+                f"Invalid tasks found: {invalid_task_ids}."
+            )
+
+        downstream_tasks = _ensure_tasks(task.downstream_list)
+
+        if downstream_tasks:
+            # For a branching workflow that looks like this, when "branch" 
does skip_all_except("task1"),
+            # we intuitively expect both "task1" and "join" to execute even 
though strictly speaking,
+            # "join" is also immediately downstream of "branch" and should 
have been skipped. Therefore,
+            # we need a special case here for such empty branches: Check 
downstream tasks of branch_task_ids.
+            # In case the task to skip is also downstream of branch_task_ids, 
we add it to branch_task_ids and
+            # exclude it from skipping.
+            #
+            # branch  ----->  join
+            #   \            ^
+            #     v        /
+            #       task1
+            #
+            for branch_task_id in list(branch_task_id_set):
+                
branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))
+
+            skip_tasks = [
+                (t.task_id, ti.map_index) for t in downstream_tasks if 
t.task_id not in branch_task_id_set
+            ]
+
+            follow_task_ids = [t.task_id for t in downstream_tasks if 
t.task_id in branch_task_id_set]
+            log.info("Skipping tasks %s", skip_tasks)
+            ti.xcom_push(
+                key=XCOM_SKIPMIXIN_KEY,
+                value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids},
+            )
+            self._set_state_to_skipped(skip_tasks, ti.map_index)  # type: 
ignore[arg-type]
diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
index f217306f4da..abc4c86ed85 100644
--- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -214,8 +214,8 @@ class OperatorPartial:
 
     def _expand(self, expand_input: ExpandInput, *, strict: bool) -> 
MappedOperator:
         from airflow.providers.standard.operators.empty import EmptyOperator
-        from airflow.providers.standard.utils.skipmixin import SkipMixin
         from airflow.sdk import BaseSensorOperator
+        from airflow.sdk.bases.skipmixin import SkipMixin
 
         self._expand_called = True
         ensure_xcomarg_return_value(expand_input.value)
diff --git a/task-sdk/tests/task_sdk/bases/test_branch.py b/task-sdk/tests/task_sdk/bases/test_branch.py
new file mode 100644
index 00000000000..ff3b3515b45
--- /dev/null
+++ b/task-sdk/tests/task_sdk/bases/test_branch.py
@@ -0,0 +1,121 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, Mock
+
+import pytest
+
+from airflow.sdk.bases.branch import BaseBranchOperator, BranchMixIn
+from airflow.sdk.bases.operator import BaseOperator
+from airflow.sdk.definitions.dag import DAG
+from airflow.sdk.definitions.taskgroup import TaskGroup
+from airflow.sdk.exceptions import DownstreamTasksSkipped
+from airflow.sdk.types import RuntimeTaskInstanceProtocol
+
+
+class TestBranchMixIn:
+    def test_do_branch_with_none_skips_all(self):
+        """do_branch(context, None) should skip all downstream tasks."""
+        mixin = BranchMixIn()
+
+        downstream1 = MagicMock(spec=BaseOperator, task_id="down1")
+        downstream2 = MagicMock(spec=BaseOperator, task_id="down2")
+
+        mock_task = MagicMock(spec=BaseOperator)
+        mock_task.downstream_list = [downstream1, downstream2]
+        mock_dag = MagicMock(spec=DAG)
+        mock_dag.task_ids = ["branch", "down1", "down2"]
+        mock_task.dag = mock_dag
+
+        ti = Mock(spec=RuntimeTaskInstanceProtocol, map_index=-1, task=mock_task)
+        context = {"ti": ti}
+
+        with pytest.raises(DownstreamTasksSkipped) as exc_info:
+            mixin.do_branch(context, None)
+
+        assert set(exc_info.value.tasks) == {("down1", -1), ("down2", -1)}
+
+    def test_do_branch_with_string(self):
+        """do_branch(context, 'down1') should follow down1 and skip others."""
+        mixin = BranchMixIn()
+
+        downstream1 = MagicMock(spec=BaseOperator, task_id="down1")
+        downstream1.get_flat_relative_ids.return_value = set()
+        downstream2 = MagicMock(spec=BaseOperator, task_id="down2")
+
+        mock_task = MagicMock(spec=BaseOperator)
+        mock_task.downstream_list = [downstream1, downstream2]
+        mock_dag = MagicMock(spec=DAG)
+        mock_dag.task_ids = ["branch", "down1", "down2"]
+        mock_dag.get_task.return_value = downstream1
+        mock_dag.task_group_dict = {}
+        mock_task.dag = mock_dag
+
+        ti = Mock(spec=RuntimeTaskInstanceProtocol, map_index=-1, task=mock_task)
+        context = {"ti": ti}
+
+        with pytest.raises(DownstreamTasksSkipped) as exc_info:
+            mixin.do_branch(context, "down1")
+
+        assert exc_info.value.tasks == [("down2", -1)]
+
+    def test_expand_task_group_roots(self):
+        """_expand_task_group_roots should expand task group into root task IDs."""
+        mixin = BranchMixIn()
+
+        mock_tg = MagicMock(spec=TaskGroup)
+        mock_root1 = MagicMock(spec=BaseOperator, task_id="tg.root1")
+        mock_root2 = MagicMock(spec=BaseOperator, task_id="tg.root2")
+        mock_tg.roots = [mock_root1, mock_root2]
+        mock_tg.group_id = "tg"
+
+        mock_dag = MagicMock(spec=DAG)
+        mock_dag.task_group_dict = {"tg": mock_tg}
+        mock_task = MagicMock(spec=BaseOperator)
+        mock_task.dag = mock_dag
+
+        ti = Mock(spec=RuntimeTaskInstanceProtocol, task=mock_task)
+
+        result = list(mixin._expand_task_group_roots(ti, ["tg"]))
+        assert result == ["tg.root1", "tg.root2"]
+
+    def test_expand_task_group_roots_passthrough(self):
+        """_expand_task_group_roots should pass through regular task IDs."""
+        mixin = BranchMixIn()
+
+        mock_dag = MagicMock(spec=DAG)
+        mock_dag.task_group_dict = {}
+        mock_task = MagicMock(spec=BaseOperator)
+        mock_task.dag = mock_dag
+
+        ti = Mock(spec=RuntimeTaskInstanceProtocol, task=mock_task)
+
+        result = list(mixin._expand_task_group_roots(ti, "regular_task"))
+        assert result == ["regular_task"]
+
+
+class TestBaseBranchOperator:
+    def test_choose_branch_not_implemented(self):
+        """BaseBranchOperator.choose_branch should raise NotImplementedError."""
+        op = BaseBranchOperator.__new__(BaseBranchOperator)
+        with pytest.raises(NotImplementedError):
+            op.choose_branch({})
+
+    def test_inherits_from_skipmixin_flag(self):
+        assert BaseBranchOperator.inherits_from_skipmixin is True
diff --git a/task-sdk/tests/task_sdk/bases/test_skipmixin.py b/task-sdk/tests/task_sdk/bases/test_skipmixin.py
new file mode 100644
index 00000000000..9518270e229
--- /dev/null
+++ b/task-sdk/tests/task_sdk/bases/test_skipmixin.py
@@ -0,0 +1,202 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, Mock
+
+import pytest
+
+from airflow.sdk.bases.operator import BaseOperator
+from airflow.sdk.bases.skipmixin import (
+    XCOM_SKIPMIXIN_FOLLOWED,
+    XCOM_SKIPMIXIN_KEY,
+    XCOM_SKIPMIXIN_SKIPPED,
+    SkipMixin,
+    _ensure_tasks,
+)
+from airflow.sdk.definitions.dag import DAG
+from airflow.sdk.definitions.mappedoperator import MappedOperator
+from airflow.sdk.exceptions import AirflowException, DownstreamTasksSkipped
+from airflow.sdk.types import RuntimeTaskInstanceProtocol
+
+
+class TestEnsureTasks:
+    def test_filters_non_operators(self):
+        """Only BaseOperator and MappedOperator instances should be returned."""
+        op = MagicMock(spec=BaseOperator)
+        mapped = MagicMock(spec=MappedOperator)
+        other = MagicMock()  # not an operator
+
+        result = _ensure_tasks([op, mapped, other])
+        assert result == [op, mapped]
+
+    def test_empty_input(self):
+        assert _ensure_tasks([]) == []
+
+
+class TestSkipMixin:
+    def test_skip_pushes_xcom_and_raises(self):
+        """skip() should push skipped task IDs to XCom and raise DownstreamTasksSkipped."""
+        mixin = SkipMixin()
+        mixin.task_id = "branch_task"
+
+        task1 = MagicMock(spec=BaseOperator, task_id="task1")
+        ti = Mock(spec=RuntimeTaskInstanceProtocol, map_index=-1)
+
+        with pytest.raises(DownstreamTasksSkipped) as exc_info:
+            mixin.skip(ti=ti, tasks=[task1])
+
+        ti.xcom_push.assert_called_once_with(
+            key=XCOM_SKIPMIXIN_KEY,
+            value={XCOM_SKIPMIXIN_SKIPPED: ["task1"]},
+        )
+        assert exc_info.value.tasks == ["task1"]
+
+    def test_skip_none_tasks(self):
+        """skip() should return None when no tasks are provided."""
+        ti = Mock(spec=RuntimeTaskInstanceProtocol)
+        assert SkipMixin().skip(ti=ti, tasks=[]) is None
+
+    def test_skip_mapped_task_does_not_raise(self):
+        """skip() should not raise when map_index != -1 (mapped tasks)."""
+        mixin = SkipMixin()
+        mixin.task_id = "branch_task"
+
+        task1 = MagicMock(spec=BaseOperator, task_id="task1")
+        ti = Mock(spec=RuntimeTaskInstanceProtocol, map_index=2)
+
+        # Should not raise — mapped tasks are handled by NotPreviouslySkippedDep
+        mixin.skip(ti=ti, tasks=[task1])
+
+    def test_skip_without_task_id_does_not_push_xcom(self):
+        """skip() should not push XCom when the mixin has no task_id."""
+        mixin = SkipMixin()
+        # No task_id attribute set
+
+        task1 = MagicMock(spec=BaseOperator, task_id="task1")
+        ti = Mock(spec=RuntimeTaskInstanceProtocol, map_index=-1)
+
+        with pytest.raises(DownstreamTasksSkipped):
+            mixin.skip(ti=ti, tasks=[task1])
+
+        ti.xcom_push.assert_not_called()
+
+    def test_skip_all_except_with_none_skips_all(self):
+        """skip_all_except(None) should skip all downstream tasks."""
+        mixin = SkipMixin()
+
+        downstream1 = MagicMock(spec=BaseOperator, task_id="down1")
+        downstream2 = MagicMock(spec=BaseOperator, task_id="down2")
+
+        mock_task = MagicMock(spec=BaseOperator)
+        mock_task.downstream_list = [downstream1, downstream2]
+        mock_dag = MagicMock(spec=DAG)
+        mock_dag.task_ids = ["task1", "down1", "down2"]
+        mock_task.dag = mock_dag
+
+        ti = Mock(spec=RuntimeTaskInstanceProtocol, map_index=-1, task=mock_task)
+
+        with pytest.raises(DownstreamTasksSkipped) as exc_info:
+            mixin.skip_all_except(ti=ti, branch_task_ids=None)
+
+        assert set(exc_info.value.tasks) == {("down1", -1), ("down2", -1)}
+
+    def test_skip_all_except_with_string_branch(self):
+        """skip_all_except('down1') should skip down2 but not down1."""
+        mixin = SkipMixin()
+
+        downstream1 = MagicMock(spec=BaseOperator, task_id="down1")
+        downstream1.get_flat_relative_ids.return_value = set()
+        downstream2 = MagicMock(spec=BaseOperator, task_id="down2")
+
+        mock_task = MagicMock(spec=BaseOperator)
+        mock_task.downstream_list = [downstream1, downstream2]
+        mock_dag = MagicMock(spec=DAG)
+        mock_dag.task_ids = ["task1", "down1", "down2"]
+        mock_dag.get_task.return_value = downstream1
+        mock_task.dag = mock_dag
+
+        ti = Mock(spec=RuntimeTaskInstanceProtocol, map_index=-1, task=mock_task)
+
+        with pytest.raises(DownstreamTasksSkipped) as exc_info:
+            mixin.skip_all_except(ti=ti, branch_task_ids="down1")
+
+        assert exc_info.value.tasks == [("down2", -1)]
+        ti.xcom_push.assert_called_once_with(
+            key=XCOM_SKIPMIXIN_KEY,
+            value={XCOM_SKIPMIXIN_FOLLOWED: ["down1"]},
+        )
+
+    def test_skip_all_except_invalid_type_raises(self):
+        """skip_all_except() should raise when branch_task_ids is an invalid type."""
+        mixin = SkipMixin()
+        mock_task = MagicMock(spec=BaseOperator)
+        mock_task.dag = MagicMock(spec=DAG)
+        ti = Mock(spec=RuntimeTaskInstanceProtocol, task=mock_task)
+
+        with pytest.raises(AirflowException, match="must be either None, a task ID, or an Iterable"):
+            mixin.skip_all_except(ti=ti, branch_task_ids=42)
+
+    def test_skip_all_except_invalid_iterable_element_raises(self):
+        """skip_all_except() should raise when branch_task_ids contains non-string elements."""
+        mixin = SkipMixin()
+        mock_task = MagicMock(spec=BaseOperator)
+        mock_task.dag = MagicMock(spec=DAG)
+        ti = Mock(spec=RuntimeTaskInstanceProtocol, task=mock_task)
+
+        with pytest.raises(AirflowException, match="invalid 'branch_task_ids'"):
+            mixin.skip_all_except(ti=ti, branch_task_ids=["task1", 42])
+
+    def test_skip_all_except_invalid_task_id_raises(self):
+        """skip_all_except() should raise when branch_task_ids contains non-existent task IDs."""
+        mixin = SkipMixin()
+
+        mock_task = MagicMock(spec=BaseOperator)
+        mock_dag = MagicMock(spec=DAG)
+        mock_dag.task_ids = ["task1", "down1"]
+        mock_task.dag = mock_dag
+
+        ti = Mock(spec=RuntimeTaskInstanceProtocol, task=mock_task)
+
+        with pytest.raises(AirflowException, match="must contain only valid task_ids"):
+            mixin.skip_all_except(ti=ti, branch_task_ids="nonexistent")
+
+    def test_skip_all_except_generator_branch_task_ids(self):
+        """skip_all_except() should handle generator branch_task_ids."""
+        mixin = SkipMixin()
+
+        downstream1 = MagicMock(spec=BaseOperator, task_id="down1")
+        downstream1.get_flat_relative_ids.return_value = set()
+        downstream2 = MagicMock(spec=BaseOperator, task_id="down2")
+
+        mock_task = MagicMock(spec=BaseOperator)
+        mock_task.downstream_list = [downstream1, downstream2]
+        mock_dag = MagicMock(spec=DAG)
+        mock_dag.task_ids = ["task1", "down1", "down2"]
+        mock_dag.get_task.return_value = downstream1
+        mock_task.dag = mock_dag
+
+        ti = Mock(spec=RuntimeTaskInstanceProtocol, map_index=-1, task=mock_task)
+
+        def gen():
+            yield "down1"
+
+        with pytest.raises(DownstreamTasksSkipped) as exc_info:
+            mixin.skip_all_except(ti=ti, branch_task_ids=gen())
+
+        assert exc_info.value.tasks == [("down2", -1)]
diff --git a/task-sdk/tests/task_sdk/docs/test_public_api.py b/task-sdk/tests/task_sdk/docs/test_public_api.py
index e8d7dd68e9e..98391927f8a 100644
--- a/task-sdk/tests/task_sdk/docs/test_public_api.py
+++ b/task-sdk/tests/task_sdk/docs/test_public_api.py
@@ -64,6 +64,7 @@ def test_airflow_sdk_no_unexpected_exports():
         "crypto",
         "providers_manager_runtime",
         "lineage",
+        "types",
     }
     unexpected = actual - public - ignore
     assert not unexpected, f"Unexpected exports in airflow.sdk: 
{sorted(unexpected)}"


Reply via email to