This is an automated email from the ASF dual-hosted git repository.

amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 99d88c13a0e Add default parameter to task and asset state get() accessors (#67842)
99d88c13a0e is described below

commit 99d88c13a0eaaa5398a1f6e6dc6a48b0e59dc1d2
Author: Amogh Desai <[email protected]>
AuthorDate: Wed Jun 3 09:02:44 2026 +0530

    Add default parameter to task and asset state get() accessors (#67842)
---
 .../airflow/example_dags/example_asset_store.py    |  5 +--
 task-sdk/src/airflow/sdk/bases/resumablemixin.py   |  9 ++++-
 task-sdk/src/airflow/sdk/execution_time/context.py | 28 ++++++++-----
 .../tests/task_sdk/bases/test_resumablemixin.py    | 22 +++++++++++
 .../tests/task_sdk/execution_time/test_context.py  | 46 ++++++++++++++++++++++
 .../task_sdk/execution_time/test_task_runner.py    | 39 ++++++++++++++++++
 6 files changed, 134 insertions(+), 15 deletions(-)

diff --git a/airflow-core/src/airflow/example_dags/example_asset_store.py b/airflow-core/src/airflow/example_dags/example_asset_store.py
index 6e6d30e3501..febef84e6ad 100644
--- a/airflow-core/src/airflow/example_dags/example_asset_store.py
+++ b/airflow-core/src/airflow/example_dags/example_asset_store.py
@@ -53,14 +53,13 @@ with DAG(
     def load(asset_store=None):
         state = asset_store[ORDERS]
 
-        # First run: watermark is None — fall back to epoch start.
-        watermark = state.get("watermark") or "2026-01-01T00:00:00+00:00"
+        watermark = state.get("watermark", default="2026-01-01T00:00:00+00:00")
         records = _fetch_records(since=watermark)
         row_count = len(records)
 
         now = datetime.now(tz=timezone.utc).isoformat()
         state.set("watermark", now)
-        state.set("total_runs", (state.get("total_runs") or 0) + 1)
+        state.set("total_runs", state.get("total_runs", default=0) + 1)
         state.set(
             "last_run_summary",
             {
diff --git a/task-sdk/src/airflow/sdk/bases/resumablemixin.py b/task-sdk/src/airflow/sdk/bases/resumablemixin.py
index 55e2d97ff7a..68620924bcf 100644
--- a/task-sdk/src/airflow/sdk/bases/resumablemixin.py
+++ b/task-sdk/src/airflow/sdk/bases/resumablemixin.py
@@ -126,14 +126,19 @@ class ResumableJobMixin:
 
         external_id = self.submit_job(context)
 
-        if task_store is not None:
+        if task_store is not None and external_id is not None:
             task_store.set(self.external_id_key, external_id)
 
         self.poll_until_complete(external_id, context)
         return self.get_job_result(external_id, context)
 
     def submit_job(self, context: Context) -> JsonValue:
-        """Submit the job to the external system. Return its external ID."""
+        """
+        Submit the job to the external system. Return its external ID.
+
+        The returned ID must not be ``None``, a ``None`` return is treated as
+        "no ID available" and the ID will not be persisted to task state.
+        """
         raise NotImplementedError
 
     def get_job_status(self, external_id: JsonValue) -> str:
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py
index 5971c076ccb..7a601e6f320 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -506,9 +506,9 @@ class TaskStoreAccessor:
     # is not implemented yet cos it's unclear whether task state values will be
     # used in templates.
 
-    def get(self, key: str) -> JsonValue:
+    def get(self, key: str, default: JsonValue = None) -> JsonValue:
         """
-        Return the stored value, or ``None`` if the key does not exist.
+        Return the stored value, or ``default`` if the key does not exist.
 
         Supported types: ``str``, ``int``, ``float``, ``bool``, ``list``, ``dict``.
         ``datetime`` is not JSON-serializable; store it as ``value.isoformat()`` and
@@ -535,12 +535,14 @@ class TaskStoreAccessor:
                     key,
                 )
             return stored
-        return None
+        return default
 
     def set(self, key: str, value: JsonValue, *, retention: timedelta | None = None) -> None:
         """
         Write or overwrite the value for the given key.
 
+        ``value`` must not be ``None``.
+
         ``retention`` is an optional key that controls when this key expires:
 
         - ``timedelta(...)`` — expire after the given duration (e.g. ``timedelta(hours=6)``).
@@ -550,6 +552,9 @@ class TaskStoreAccessor:
         from airflow.sdk.execution_time.comms import SetTaskStore
         from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
 
+        if value is None:
+            raise ValueError("Cannot set value as None")
+
         # expires_at is always resolved on the worker in UTC before being sent.
         now = datetime.now(tz=timezone.utc)
         if retention is NEVER_EXPIRE:
@@ -640,8 +645,8 @@ class AssetStoreAccessor:
             return f"<AssetStoreAccessor name={self._name!r}>"
         return f"<AssetStoreAccessor uri={self._uri!r}>"
 
-    def get(self, key: str) -> JsonValue:
-        """Return the stored value, or ``None`` if the key does not exist."""
+    def get(self, key: str, default: JsonValue = None) -> JsonValue:
+        """Return the stored value, or ``default`` if the key does not exist."""
         from airflow.sdk.execution_time.comms import (
             AssetStoreResult,
             ErrorResponse,
@@ -674,13 +679,16 @@ class AssetStoreAccessor:
                     self._name or self._uri,
                 )
             return stored
-        return None
+        return default
 
     def set(self, key: str, value: JsonValue) -> None:
-        """Write or overwrite the value for the given key."""
+        """Write or overwrite the value for the given key. ``value`` must not be ``None``."""
         from airflow.sdk.execution_time.comms import SetAssetStoreByName, SetAssetStoreByUri, ToSupervisor
         from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
 
+        if value is None:
+            raise ValueError("Cannot set value as None")
+
         # if custom backend is configured, store the value on the custom backend, and return the reference
         # to the stored value to store in the DB
         backend = _get_worker_state_backend()
@@ -797,9 +805,9 @@ class AssetStoreAccessors:
             return next(iter(self._by_name.values()))
         return next(iter(self._by_uri.values()))
 
-    def get(self, key: str) -> JsonValue:
-        """Return the stored value for the single-inlet task, or ``None`` if not found."""
-        return self._single_accessor().get(key)
+    def get(self, key: str, default: JsonValue = None) -> JsonValue:
+        """Return the stored value for the single-inlet or single-outlet task, or ``default`` if not found."""
+        return self._single_accessor().get(key, default)
 
     def set(self, key: str, value: JsonValue) -> None:
         """Write or overwrite the value for the single-inlet task."""
diff --git a/task-sdk/tests/task_sdk/bases/test_resumablemixin.py b/task-sdk/tests/task_sdk/bases/test_resumablemixin.py
index 4d0f8fa8cdb..6796ec1029c 100644
--- a/task-sdk/tests/task_sdk/bases/test_resumablemixin.py
+++ b/task-sdk/tests/task_sdk/bases/test_resumablemixin.py
@@ -164,6 +164,28 @@ class TestRetryWithDifferentJobStatuses:
         assert op.polled_ids == ["job-002"]
 
 
+class TestNoneExternalId:
+    def test_none_external_id_is_not_stored(self):
+        """submit_job() returning None must not call task_state.set()."""
+
+        class NoneIdOp(ConcreteResumableOperator):
+            def submit_job(self, context) -> JsonValue:
+                return None
+
+            def poll_until_complete(self, external_id, context) -> None:
+                pass
+
+            def get_job_result(self, external_id, context) -> str:
+                return "done"
+
+        op = NoneIdOp(task_id="test_task")
+        task_state = FakeTaskState()
+
+        op.execute_resumable(make_context(task_state))
+
+        assert task_state._store == {}
+
+
 class TestExternalIdKey:
     def test_custom_key_used_for_storage_and_retrieval(self):
         class CustomKeyOp(ConcreteResumableOperator):
diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py
index 2531d003cc0..9bcaa513d90 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_context.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_context.py
@@ -1123,6 +1123,24 @@ class TestTaskStoreAccessor:
 
         assert result is None
 
+    def test_get_returns_default_when_key_missing(self, mock_supervisor_comms):
+        mock_supervisor_comms.send.return_value = ErrorResponse(
+            error=ErrorType.TASK_STORE_NOT_FOUND, detail={"key": "job_id"}
+        )
+
+        result = TaskStoreAccessor(ti_id=self.TI_ID, scope=self.SCOPE).get("job_id", default="default-id")
+
+        assert result == "default-id"
+
+    def test_get_ignores_default_when_key_exists(self, mock_supervisor_comms):
+        mock_supervisor_comms.send.return_value = TaskStoreResult(value="job-001")
+
+        result = TaskStoreAccessor(ti_id=self.TI_ID, scope=self.SCOPE).get(
+            "job_id", default="do-not-start-here"
+        )
+
+        assert result == "job-001"
+
     def test_get_raises_on_error(self, mock_supervisor_comms):
         mock_supervisor_comms.send.return_value = ErrorResponse(
             error=ErrorType.GENERIC_ERROR, detail={"message": "server error"}
@@ -1131,6 +1149,10 @@ class TestTaskStoreAccessor:
         with pytest.raises(AirflowRuntimeError):
             TaskStoreAccessor(ti_id=self.TI_ID, scope=self.SCOPE).get("some_key")
 
+    def test_set_none_raises(self, mock_supervisor_comms):
+        with pytest.raises(ValueError, match="Cannot set value as None"):
+            TaskStoreAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set("job_id", None)
+
     def test_set_operation_with_global_retention(self, mock_supervisor_comms, time_machine):
         """set() with no retention uses global default_retention_days config."""
 
@@ -1287,6 +1309,26 @@ class TestAssetStoreAccessor:
 
         assert result is None
 
+    def test_get_returns_default_when_key_missing(self, mock_supervisor_comms):
+        mock_supervisor_comms.send.return_value = ErrorResponse(
+            error=ErrorType.ASSET_STORE_NOT_FOUND, detail={"key": "watermark"}
+        )
+
+        result = AssetStoreAccessor(name=self.ASSET_NAME).get(
+            "watermark", default="2026-01-01T00:00:00+00:00"
+        )
+
+        assert result == "2026-01-01T00:00:00+00:00"
+
+    def test_get_ignores_default_when_key_exists(self, mock_supervisor_comms):
+        mock_supervisor_comms.send.return_value = AssetStoreResult(value="2026-06-01T00:00:00+00:00")
+
+        result = AssetStoreAccessor(name=self.ASSET_NAME).get(
+            "watermark", default="2026-01-01T00:00:00+00:00"
+        )
+
+        assert result == "2026-06-01T00:00:00+00:00"
+
     def test_get_raises_on_error(self, mock_supervisor_comms):
         mock_supervisor_comms.send.return_value = ErrorResponse(
             error=ErrorType.GENERIC_ERROR, detail={"message": "server error"}
@@ -1304,6 +1346,10 @@ class TestAssetStoreAccessor:
             SetAssetStoreByName(name=self.ASSET_NAME, key="watermark", value="2026-04-30T00:00:00Z")
         )
 
+    def test_set_none_raises(self, mock_supervisor_comms):
+        with pytest.raises(ValueError, match="Cannot set value as None"):
+            AssetStoreAccessor(name=self.ASSET_NAME).set("watermark", None)
+
     def test_delete_operation(self, mock_supervisor_comms):
         mock_supervisor_comms.send.return_value = OKResponse(ok=True)
 
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 893419cc25c..f8953dc2329 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -5325,6 +5325,24 @@ class TestTaskInstanceStateOperations:
         )
         mock_supervisor_comms.send.assert_any_call(GetTaskStore(ti_id=runtime_ti.id, key="job_id"))
 
+    def test_task_state_get_returns_default_when_key_missing(self, create_runtime_ti, mock_supervisor_comms):
+        captured = {}
+
+        class MyOperator(BaseOperator):
+            def execute(self, context):
+                captured["result"] = context["task_store"].get(
+                    "watermark", default="2026-01-01T00:00:00+00:00"
+                )
+
+        mock_supervisor_comms.send.return_value = ErrorResponse(
+            error=ErrorType.TASK_STORE_NOT_FOUND, detail={"key": "watermark"}
+        )
+        task = MyOperator(task_id="t")
+        runtime_ti = create_runtime_ti(task=task)
+        run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+
+        assert captured["result"] == "2026-01-01T00:00:00+00:00"
+
     def test_task_state_set_sends_typed_values(self, create_runtime_ti, mock_supervisor_comms, time_machine):
         """set() accepts any JsonValue — dict, int, list — not just strings."""
 
@@ -5441,6 +5459,27 @@ class TestTaskInstanceStateOperations:
         )
         mock_supervisor_comms.send.assert_any_call(GetAssetStoreByName(name="my_asset", key="watermark"))
 
+    def test_asset_state_get_returns_default_when_key_missing(self, create_runtime_ti, mock_supervisor_comms):
+        watched = Asset(name="my_asset", uri="s3://bucket/data")
+        captured = {}
+
+        class WatcherOperator(BaseOperator):
+            def execute(self, context):
+                captured["result"] = context["asset_store"].get(
+                    "watermark", default="2026-01-01T00:00:00+00:00"
+                )
+
+        task = WatcherOperator(task_id="t", inlets=[watched])
+        runtime_ti = create_runtime_ti(task=task)
+        mock_supervisor_comms.send.side_effect = lambda msg: (
+            ErrorResponse(error=ErrorType.ASSET_STORE_NOT_FOUND, detail={"key": "watermark"})
+            if isinstance(msg, GetAssetStoreByName)
+            else TestTaskInstanceStateOperations._watcher_side_effect(msg)
+        )
+        run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+
+        assert captured["result"] == "2026-01-01T00:00:00+00:00"
+
     def test_asset_state_delete(self, create_runtime_ti, mock_supervisor_comms):
         watched = Asset(name="my_asset", uri="s3://bucket/data")
 

Reply via email to