This is an automated email from the ASF dual-hosted git repository.

jason810496 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new beb5956ec2d Go-SDK: support UP_FOR_RETRY in coordinator mode with e2e coverage (#68554)
beb5956ec2d is described below

commit beb5956ec2dcbadbf3d371b83c80f39a384f693e
Author: Jason(Zhe-You) Liu <[email protected]>
AuthorDate: Mon Jun 15 21:33:27 2026 +0800

    Go-SDK: support UP_FOR_RETRY in coordinator mode with e2e coverage (#68554)
    
    * Fix Go SDK not retrying failed tasks
    
    In coordinator mode, failing tasks with retries available were being marked as terminal FAILED because the Go SDK did not pass the UP_FOR_RETRY state back to the supervisor. This adds ShouldRetry and MaxTries to TIRunContext and handles emitting RetryTaskMsg in the task execution path.
    
    * Verify Go SDK UP_FOR_RETRY end-to-end in e2e bundle
    
    Make the Go example bundle's 'load' task fail on its first attempt and
    succeed on the retry (retries=1), and assert end-to-end that the task
    passes through UP_FOR_RETRY and ends success on try_number 2. This
    exercises the RetryTask path the Go coordinator now emits.
    
    * Fix retry_reason field naming
    
    ---------
    
    Co-authored-by: Arnav <[email protected]>
---
 .../go_sdk_tests/test_go_sdk_dag.py                | 60 +++++++++++++++-------
 go-sdk/dags/go_examples.py                         | 20 +++++---
 go-sdk/example/bundle/main.go                      | 15 +++++-
 go-sdk/pkg/execution/frames.go                     | 13 +++++
 go-sdk/pkg/execution/integration_test.go           | 56 ++++++++++++++++++++
 go-sdk/pkg/execution/messages.go                   | 26 +++++++++-
 go-sdk/pkg/execution/messages_test.go              | 13 +++++
 go-sdk/pkg/execution/task_runner.go                | 29 +++++++----
 8 files changed, 194 insertions(+), 38 deletions(-)

diff --git a/airflow-e2e-tests/tests/airflow_e2e_tests/go_sdk_tests/test_go_sdk_dag.py b/airflow-e2e-tests/tests/airflow_e2e_tests/go_sdk_tests/test_go_sdk_dag.py
index ad78b2dd3b1..02c1ccff6bd 100644
--- a/airflow-e2e-tests/tests/airflow_e2e_tests/go_sdk_tests/test_go_sdk_dag.py
+++ b/airflow-e2e-tests/tests/airflow_e2e_tests/go_sdk_tests/test_go_sdk_dag.py
@@ -38,8 +38,9 @@ Python tasks::
   msgpack-over-IPC coordinator protocol.
 * ``python_task_1`` (Python) pushes an XCom; ``extract`` (Go) fetches the
   ``test_http`` connection and returns ``{go_version, timestamp}``; ``transform``
-  (Go) reads ``my_variable``; ``load`` (Go) returns an error on purpose;
-  ``python_task_2`` (Python) pulls and re-emits the Go ``extract`` task's XCom.
+  (Go) reads ``my_variable``; ``load`` (Go) fails on its first attempt and
+  succeeds on the retry (``retries=1``); ``python_task_2`` (Python) pulls and
+  re-emits the Go ``extract`` task's XCom.
 
 The Dag is triggered exactly once by the module-scoped ``completed_run`` fixture;
 each test asserts a different facet of that single run. Together they confirm,
@@ -47,7 +48,8 @@ end-to-end:
 
 1. ``ExecutableCoordinator`` discovers the AFBNDL01 bundle by dag_id and runs the
    binary in coordinator mode for every Go task, reporting ``SucceedTask`` for
-   extract/transform and a failed ``TaskState`` for load.
+   extract/transform and -- because ``load`` has ``retries=1`` -- a ``RetryTask``
+   (UP_FOR_RETRY) for its first failing attempt, after which the retry succeeds.
 2. Connection / Variable reads and XCom writes work through the Task Execution
    API, XCom values keep their types (the ``timestamp`` stays an ``int``), and
    XCom crosses the Python <-> Go boundary in both directions.
@@ -56,6 +58,9 @@ end-to-end:
 4. The runtime context surfaced to Go tasks via ``sdk.CurrentContext`` is fully
    populated on the coordinator path: the task-instance identifiers and the Dag
    run's scheduling timestamps (logical_date / data_interval_start/end).
+5. A Go task that fails with retries left emits ``RetryTask`` rather than a
+   terminal ``FAILED``, so the supervisor marks it UP_FOR_RETRY and re-runs it;
+   ``load`` therefore ends ``success`` on its second attempt (try_number 2).
 """
 
 from __future__ import annotations
@@ -86,6 +91,10 @@ class _CompletedRun:
     run_id: str
     state: str
     ti_states: dict[str, str]
+    ti_attrs: dict[str, dict]
+
+    def try_number(self, task_id: str) -> int | None:
+        return self.ti_attrs.get(task_id, {}).get("try_number")
 
     def xcom(self, task_id: str, key: str = "return_value"):
         return self.client.get_xcom_value(dag_id=_DAG_ID, task_id=task_id, run_id=self.run_id, key=key).get(
@@ -121,17 +130,18 @@ def completed_run() -> _CompletedRun:
     run_id = resp["dag_run_id"]
     state = client.wait_for_dag_run(dag_id=_DAG_ID, run_id=run_id, timeout=_GO_TASK_TIMEOUT)
     ti_resp = client.get_task_instances(dag_id=_DAG_ID, run_id=run_id)
-    ti_states = {ti["task_id"]: ti.get("state") for ti in ti_resp.get("task_instances", [])}
-    return _CompletedRun(client=client, run_id=run_id, state=state, ti_states=ti_states)
+    ti_attrs = {ti["task_id"]: ti for ti in ti_resp.get("task_instances", [])}
+    ti_states = {task_id: ti.get("state") for task_id, ti in ti_attrs.items()}
+    return _CompletedRun(client=client, run_id=run_id, state=state, ti_states=ti_states, ti_attrs=ti_attrs)
 
 
 def test_task_states(completed_run: _CompletedRun):
-    """Every task ends in its expected state (the Go ``load`` task fails on purpose)."""
+    """Every task ends ``success`` (the Go ``load`` task succeeds on its retry)."""
     expected = {
         "python_task_1": "success",
         "extract": "success",
         "transform": "success",
-        "load": "failed",
+        "load": "success",
         "python_task_2": "success",
     }
     for task_id, want in expected.items():
@@ -140,14 +150,35 @@ def test_task_states(completed_run: _CompletedRun):
         )
 
 
-def test_dag_run_failed(completed_run: _CompletedRun):
-    """The failing ``load`` leaf makes the overall run fail."""
-    assert completed_run.state == "failed", (
-        f"expected the run to fail because 'load' fails; got {completed_run.state!r}. "
+def test_dag_run_succeeded(completed_run: _CompletedRun):
+    """The run succeeds once ``load`` recovers on its retry."""
+    assert completed_run.state == "success", (
+        f"expected the run to succeed because 'load' recovers on retry; got {completed_run.state!r}. "
         f"task states: {completed_run.ti_states}"
     )
 
 
+def test_load_retried_then_succeeded(completed_run: _CompletedRun):
+    """``load`` fails once (UP_FOR_RETRY) then succeeds on the second attempt.
+
+    The Go coordinator must emit ``RetryTask`` (not terminal ``FAILED``) when the
+    task fails with retries left, so the supervisor re-runs it. The end state is
+    ``success`` reached on ``try_number`` 2, and each attempt's log reflects the
+    first failure and then the recovery.
+    """
+    assert completed_run.ti_states.get("load") == "success", completed_run.ti_states
+    assert completed_run.try_number("load") == 2, (
+        f"'load' should have run twice (fail then retry); try_number="
+        f"{completed_run.try_number('load')!r}, ti: {completed_run.ti_attrs.get('load')}"
+    )
+    assert "Please fail" in completed_run.logs("load", try_number=1), (
+        "load's first attempt should log its failure message 'Please fail'"
+    )
+    assert "Recovered on retry" in completed_run.logs("load", try_number=2), (
+        "load's retry attempt should log 'Recovered on retry'"
+    )
+
+
 def test_python_task_1_pushes_xcom(completed_run: _CompletedRun):
     """The upstream Python task's XCom is available (Python -> XCom)."""
     assert completed_run.xcom("python_task_1") == "value_from_python_task_1"
@@ -230,10 +261,3 @@ def test_transform_logs_show_variable_read(completed_run: _CompletedRun):
     assert "Obtained variable" in completed_run.logs("transform"), (
         "transform task should log 'Obtained variable'"
     )
-
-
-def test_load_logs_show_failure(completed_run: _CompletedRun):
-    """The Go 'load' task's error surfaces in its task log."""
-    assert "Please fail" in completed_run.logs("load"), (
-        "load task log should contain its failure message 'Please fail'"
-    )
diff --git a/go-sdk/dags/go_examples.py b/go-sdk/dags/go_examples.py
index 523b6b8bf23..6c9ecf7b455 100644
--- a/go-sdk/dags/go_examples.py
+++ b/go-sdk/dags/go_examples.py
@@ -29,9 +29,10 @@ exercises XCom across the language boundary, the same way
   routed to the ``ExecutableCoordinator``, which locates the bundle by dag_id and
   runs the binary in coordinator mode. ``extract`` returns a map (pushed as its
   ``return_value`` XCom); ``transform`` reads the ``my_variable`` variable.
-* ``load`` returns an error on purpose. It is a leaf (not upstream of
-  ``python_task_2``) so its failure is observable while leaving the Go -> Python
-  XCom hop intact.
+* ``load`` (``retries=1``) returns an error on its first attempt and succeeds
+  on the retry, exercising the UP_FOR_RETRY path through the Go coordinator. It
+  is a leaf (not upstream of ``python_task_2``) so its retry is observable
+  while leaving the Go -> Python XCom hop intact.
 * ``python_task_2`` (Python) pulls the Go ``extract`` task's XCom and re-emits
   it, demonstrating the Go -> Python direction end-to-end.
 
@@ -43,6 +44,8 @@ default Python executor and are independent of the bundle.
 
 from __future__ import annotations
 
+from datetime import timedelta
+
 from airflow.sdk import dag, task
 
 
@@ -61,7 +64,10 @@ def extract(): ...
 def transform(): ...
 
 
-@task.stub(queue="golang")
+# ``load`` fails on its first attempt and succeeds on the retry, exercising the
+# UP_FOR_RETRY path through the Go coordinator. The short ``retry_delay`` keeps
+# the end-to-end run fast.
+@task.stub(queue="golang", retries=1, retry_delay=timedelta(seconds=5))
 def load(): ...
 
 
@@ -78,9 +84,9 @@ def simple_dag():
     extracted = extract()
     transformed = transform()
     python_task_1() >> extracted >> transformed
-    # ``load`` fails on purpose; keep it a leaf (not upstream of python_task_2)
-    # so the failure is observable without skipping the Python task that pulls
-    # the Go XCom.
+    # ``load`` fails once then succeeds on retry; keep it a leaf (not upstream
+    # of python_task_2) so its retry is observable without affecting the Python
+    # task that pulls the Go XCom.
     transformed >> [load(), python_task_2(extracted)]
 
 
diff --git a/go-sdk/example/bundle/main.go b/go-sdk/example/bundle/main.go
index 8f7513dc355..7f4f1d22dcf 100644
--- a/go-sdk/example/bundle/main.go
+++ b/go-sdk/example/bundle/main.go
@@ -137,6 +137,17 @@ func transform(ctx sdk.TIRunContext, client sdk.VariableClient, log *slog.Logger
        return nil
 }
 
-func load() error {
-       return fmt.Errorf("Please fail")
+// load fails on its first attempt and succeeds on the retry. With retries
+// configured on the stub task, the first failure makes the supervisor mark the
+// task UP_FOR_RETRY -- which only works because the Go SDK now emits a
+// RetryTask frame (instead of a terminal FAILED) when ti_context.should_retry
+// is set. The retry then runs this task again and it returns nil.
+func load(ctx sdk.TIRunContext, log *slog.Logger) error {
+       tryNumber := ctx.TaskInstance().TryNumber
+       if tryNumber == 1 {
+               log.InfoContext(ctx, "Please fail", "try_number", tryNumber)
+               return fmt.Errorf("Please fail")
+       }
+       log.InfoContext(ctx, "Recovered on retry", "try_number", tryNumber)
+       return nil
 }
diff --git a/go-sdk/pkg/execution/frames.go b/go-sdk/pkg/execution/frames.go
index 5d7d7ca3649..a8f65770fe2 100644
--- a/go-sdk/pkg/execution/frames.go
+++ b/go-sdk/pkg/execution/frames.go
@@ -242,6 +242,19 @@ func mapStringOr(m map[string]any, key string, def string) string {
        return s
 }
 
+// mapBoolOr extracts a bool value from a map, returning the default if missing.
+func mapBoolOr(m map[string]any, key string, def bool) bool {
+       v, ok := m[key]
+       if !ok {
+               return def
+       }
+       b, ok := v.(bool)
+       if !ok {
+               return def
+       }
+       return b
+}
+
 // mapStringPtr extracts a nullable string value from a map. It returns nil
 // when the key is missing or the value is nil (i.e. JSON null / Python None),
 // and a pointer to the string when a value is present. Use this for fields
diff --git a/go-sdk/pkg/execution/integration_test.go b/go-sdk/pkg/execution/integration_test.go
index 47f63b6d26e..90870449f5c 100644
--- a/go-sdk/pkg/execution/integration_test.go
+++ b/go-sdk/pkg/execution/integration_test.go
@@ -106,6 +106,34 @@ func TestTaskRunnerFailure(t *testing.T) {
        assert.Equal(t, "failed", result["state"])
 }
 
+func TestTaskRunnerRetry(t *testing.T) {
+       bundle := buildBundle(t, func(r bundlev1.Registry) {
+               r.AddDag("test_dag").AddTask(failingTask)
+       })
+
+       details := &StartupDetails{
+               TI: TaskInstanceInfo{
+                       ID:       "550e8400-e29b-41d4-a716-446655440000",
+                       DagID:    "test_dag",
+                       TaskID:   "failingTask",
+                       RunID:    "run1",
+                       MapIndex: -1,
+               },
+               BundleInfo: BundleInfoMsg{Name: "test", Version: "1.0"},
+               TIContext: TIRunContext{
+                       ShouldRetry: true,
+                       MaxTries:    3,
+               },
+       }
+
+       logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+       comm := NewCoordinatorComm(bytes.NewReader(nil), io.Discard, logger)
+
+       result := RunTask(context.Background(), bundle, details, comm, logger)
+       assert.Equal(t, "RetryTask", result["type"])
+       assert.Equal(t, "task failed intentionally", result["retry_reason"])
+}
+
 func TestTaskRunnerTaskNotFound(t *testing.T) {
        bundle := buildBundle(t, func(r bundlev1.Registry) {
                r.AddDag("test_dag").AddTask(simpleTask)
@@ -153,6 +181,34 @@ func TestTaskRunnerPanic(t *testing.T) {
        assert.Equal(t, "failed", result["state"])
 }
 
+func TestTaskRunnerPanicRetry(t *testing.T) {
+       bundle := buildBundle(t, func(r bundlev1.Registry) {
+               r.AddDag("test_dag").AddTask(panicTask)
+       })
+
+       details := &StartupDetails{
+               TI: TaskInstanceInfo{
+                       ID:       "550e8400-e29b-41d4-a716-446655440000",
+                       DagID:    "test_dag",
+                       TaskID:   "panicTask",
+                       RunID:    "run1",
+                       MapIndex: -1,
+               },
+               BundleInfo: BundleInfoMsg{Name: "test", Version: "1.0"},
+               TIContext: TIRunContext{
+                       ShouldRetry: true,
+                       MaxTries:    3,
+               },
+       }
+
+       logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+       comm := NewCoordinatorComm(bytes.NewReader(nil), io.Discard, logger)
+
+       result := RunTask(context.Background(), bundle, details, comm, logger)
+       assert.Equal(t, "RetryTask", result["type"])
+       assert.Contains(t, result["retry_reason"], "panic: something went wrong")
+}
+
 func TestRunTaskHonorsContextCancellation(t *testing.T) {
        bundle := buildBundle(t, func(r bundlev1.Registry) {
                r.AddDag("test_dag").AddTaskWithName("ctxcheck",
diff --git a/go-sdk/pkg/execution/messages.go b/go-sdk/pkg/execution/messages.go
index 4cb85cebbb6..352215ac92f 100644
--- a/go-sdk/pkg/execution/messages.go
+++ b/go-sdk/pkg/execution/messages.go
@@ -104,21 +104,29 @@ type TIRunContext struct {
        LogicalDate       *time.Time
        DataIntervalStart *time.Time
        DataIntervalEnd   *time.Time
+       MaxTries          int
+       ShouldRetry       bool
 }
 
 func decodeTIRunContext(m map[string]any) (TIRunContext, error) {
        if m == nil {
                return TIRunContext{}, nil
        }
+       // max_tries / should_retry live at the top level of ti_context and drive
+       // the retry decision, so they must be read regardless of whether the
+       // nested dag_run object is present.
+       ctx := TIRunContext{
+               MaxTries:    mapIntOr(m, "max_tries", 0),
+               ShouldRetry: mapBoolOr(m, "should_retry", false),
+       }
        // The scheduling timestamps live on the nested dag_run object in the
        // supervisor's TIRunContext schema (ti_context.dag_run.logical_date, ...),
        // not at the top level of ti_context. See task-sdk's
        // airflow.sdk.api.datamodels._generated.{TIRunContext,DagRun}.
        dagRun := mapMap(m, "dag_run")
        if dagRun == nil {
-               return TIRunContext{}, nil
+               return ctx, nil
        }
-       ctx := TIRunContext{}
        for _, f := range []struct {
                key string
                dst **time.Time
@@ -388,6 +396,20 @@ func (m TaskStateMsg) toMap() map[string]any {
        }
 }
 
+// RetryTaskMsg is sent as a terminal message when a task fails but has retries.
+type RetryTaskMsg struct {
+       EndDate time.Time
+       Reason  string
+}
+
+func (m RetryTaskMsg) toMap() map[string]any {
+       return map[string]any{
+               "type":         "RetryTask",
+               "end_date":     m.EndDate.UTC().Format(time.RFC3339Nano),
+               "retry_reason": m.Reason,
+       }
+}
+
 // Message dispatch.
 
 // decodeIncomingBody dispatches decoding of a body map based on its "type" field.
diff --git a/go-sdk/pkg/execution/messages_test.go b/go-sdk/pkg/execution/messages_test.go
index e1e1e3efce2..44b93ce1fed 100644
--- a/go-sdk/pkg/execution/messages_test.go
+++ b/go-sdk/pkg/execution/messages_test.go
@@ -54,6 +54,8 @@ func TestDecodeStartupDetails(t *testing.T) {
                                "data_interval_start": "2024-01-14T00:00:00Z",
                                "data_interval_end":   "2024-01-15T00:00:00Z",
                        },
+                       "max_tries":    int64(3),
+                       "should_retry": true,
                },
        }
 
@@ -83,6 +85,8 @@ func TestDecodeStartupDetails(t *testing.T) {
                time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC),
                *details.TIContext.DataIntervalEnd,
        )
+       assert.Equal(t, 3, details.TIContext.MaxTries)
+       assert.True(t, details.TIContext.ShouldRetry)
 }
 
 func TestDecodeStartupDetails_MalformedStartDate(t *testing.T) {
@@ -418,6 +422,15 @@ func TestTaskStateMsgToMap_PreservesSubsecondPrecision(t *testing.T) {
        assert.Equal(t, "2024-01-15T10:30:00.123456789Z", m["end_date"])
 }
 
+func TestRetryTaskMsgToMap(t *testing.T) {
+       endDate := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+       msg := RetryTaskMsg{EndDate: endDate, Reason: "test error"}
+       m := msg.toMap()
+       assert.Equal(t, "RetryTask", m["type"])
+       assert.Equal(t, "test error", m["retry_reason"])
+       assert.Equal(t, "2024-01-15T10:30:00Z", m["end_date"])
+}
+
 func TestTaskStateConstants_WireValues(t *testing.T) {
        // Pin each enum constant to the exact wire string Python's
        // TaskInstanceState expects. Renaming these constants is fine;
diff --git a/go-sdk/pkg/execution/task_runner.go b/go-sdk/pkg/execution/task_runner.go
index f8bbd539751..46510955a8b 100644
--- a/go-sdk/pkg/execution/task_runner.go
+++ b/go-sdk/pkg/execution/task_runner.go
@@ -19,6 +19,7 @@ package execution
 
 import (
        "context"
+       "fmt"
        "log/slog"
        "runtime/debug"
        "time"
@@ -114,7 +115,7 @@ func RunTask(
        ctx = context.WithValue(ctx, sdkcontext.SdkClientContextKey, sdk.Client(client))
        ctx = context.WithValue(ctx, sdkcontext.RuntimeContextKey, runtimeContext)
 
-       return executeTask(ctx, task, logger)
+       return executeTask(ctx, task, details.TIContext.ShouldRetry, logger)
 }
 
 // mapIndexPtr converts the supervisor's map_index (which uses -1 as the
@@ -131,6 +132,7 @@ func mapIndexPtr(mapIndex int) *int {
 func executeTask(
        ctx context.Context,
        task bundlev1.Task,
+       shouldRetry bool,
        logger *slog.Logger,
 ) (result map[string]any) {
        defer func() {
@@ -139,19 +141,28 @@ func executeTask(
                                "error", r,
                                "stack", string(debug.Stack()),
                        )
-                       result = TaskStateMsg{
-                               State:   TaskStateFailed,
-                               EndDate: time.Now().UTC(),
-                       }.toMap()
+                       if shouldRetry {
+                               result = RetryTaskMsg{
+                                       EndDate: time.Now().UTC(),
+                                       Reason:  fmt.Sprintf("panic: %v", r),
+                               }.toMap()
+                       } else {
+                               result = TaskStateMsg{
+                                       State:   TaskStateFailed,
+                                       EndDate: time.Now().UTC(),
+                               }.toMap()
+                       }
                }
        }()
 
        if err := task.Execute(ctx, logger); err != nil {
                logger.ErrorContext(ctx, "Task failed", "error", err)
-               // TODO(https://github.com/apache/airflow/issues/67797): emit RetryTask
-               // (UP_FOR_RETRY) when ti_context.should_retry is set. Today every
-               // failure maps to terminal FAILED because the supervisor honors this
-               // frame on exit 0 and we never send RetryTask, so retries are lost.
+               if shouldRetry {
+                       return RetryTaskMsg{
+                               EndDate: time.Now().UTC(),
+                               Reason:  err.Error(),
+                       }.toMap()
+               }
                return TaskStateMsg{
                        State:   TaskStateFailed,
                        EndDate: time.Now().UTC(),

Reply via email to