This is an automated email from the ASF dual-hosted git repository.
jason810496 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new beb5956ec2d Go-SDK: support UP_FOR_RETRY in coordinator mode with e2e
coverage (#68554)
beb5956ec2d is described below
commit beb5956ec2dcbadbf3d371b83c80f39a384f693e
Author: Jason(Zhe-You) Liu <[email protected]>
AuthorDate: Mon Jun 15 21:33:27 2026 +0800
Go-SDK: support UP_FOR_RETRY in coordinator mode with e2e coverage (#68554)
* Fix Go SDK not retrying failed tasks
In coordinator mode, failing tasks that still had retries available were
marked as terminal FAILED, because the Go SDK mapped every failure to a
failed task state and never sent RetryTask (UP_FOR_RETRY) to the supervisor.
This adds ShouldRetry and MaxTries to TIRunContext and emits RetryTaskMsg
from the task execution path when ShouldRetry is set.
* Verify Go SDK UP_FOR_RETRY end-to-end in e2e bundle
Make the Go example bundle's 'load' task fail on its first attempt and
succeed on the retry (retries=1), and assert end-to-end that the task
passes through UP_FOR_RETRY and ends in success on try_number 2. This
exercises the RetryTask message that the Go SDK now emits in coordinator mode.
* Fix retry_reason field naming
---------
Co-authored-by: Arnav <[email protected]>
---
.../go_sdk_tests/test_go_sdk_dag.py | 60 +++++++++++++++-------
go-sdk/dags/go_examples.py | 20 +++++---
go-sdk/example/bundle/main.go | 15 +++++-
go-sdk/pkg/execution/frames.go | 13 +++++
go-sdk/pkg/execution/integration_test.go | 56 ++++++++++++++++++++
go-sdk/pkg/execution/messages.go | 26 +++++++++-
go-sdk/pkg/execution/messages_test.go | 13 +++++
go-sdk/pkg/execution/task_runner.go | 29 +++++++----
8 files changed, 194 insertions(+), 38 deletions(-)
diff --git
a/airflow-e2e-tests/tests/airflow_e2e_tests/go_sdk_tests/test_go_sdk_dag.py
b/airflow-e2e-tests/tests/airflow_e2e_tests/go_sdk_tests/test_go_sdk_dag.py
index ad78b2dd3b1..02c1ccff6bd 100644
--- a/airflow-e2e-tests/tests/airflow_e2e_tests/go_sdk_tests/test_go_sdk_dag.py
+++ b/airflow-e2e-tests/tests/airflow_e2e_tests/go_sdk_tests/test_go_sdk_dag.py
@@ -38,8 +38,9 @@ Python tasks::
msgpack-over-IPC coordinator protocol.
* ``python_task_1`` (Python) pushes an XCom; ``extract`` (Go) fetches the
``test_http`` connection and returns ``{go_version, timestamp}``;
``transform``
- (Go) reads ``my_variable``; ``load`` (Go) returns an error on purpose;
- ``python_task_2`` (Python) pulls and re-emits the Go ``extract`` task's XCom.
+ (Go) reads ``my_variable``; ``load`` (Go) fails on its first attempt and
+ succeeds on the retry (``retries=1``); ``python_task_2`` (Python) pulls and
+ re-emits the Go ``extract`` task's XCom.
The Dag is triggered exactly once by the module-scoped ``completed_run``
fixture;
each test asserts a different facet of that single run. Together they confirm,
@@ -47,7 +48,8 @@ end-to-end:
1. ``ExecutableCoordinator`` discovers the AFBNDL01 bundle by dag_id and runs
the
binary in coordinator mode for every Go task, reporting ``SucceedTask`` for
- extract/transform and a failed ``TaskState`` for load.
+ extract/transform and -- because ``load`` has ``retries=1`` -- a
``RetryTask``
+ (UP_FOR_RETRY) for its first failing attempt, after which the retry
succeeds.
2. Connection / Variable reads and XCom writes work through the Task Execution
API, XCom values keep their types (the ``timestamp`` stays an ``int``), and
XCom crosses the Python <-> Go boundary in both directions.
@@ -56,6 +58,9 @@ end-to-end:
4. The runtime context surfaced to Go tasks via ``sdk.CurrentContext`` is fully
populated on the coordinator path: the task-instance identifiers and the Dag
run's scheduling timestamps (logical_date / data_interval_start/end).
+5. A Go task that fails with retries left emits ``RetryTask`` rather than a
+ terminal ``FAILED``, so the supervisor marks it UP_FOR_RETRY and re-runs it;
+ ``load`` therefore ends ``success`` on its second attempt (try_number 2).
"""
from __future__ import annotations
@@ -86,6 +91,10 @@ class _CompletedRun:
run_id: str
state: str
ti_states: dict[str, str]
+ ti_attrs: dict[str, dict]
+
+ def try_number(self, task_id: str) -> int | None:
+ return self.ti_attrs.get(task_id, {}).get("try_number")
def xcom(self, task_id: str, key: str = "return_value"):
return self.client.get_xcom_value(dag_id=_DAG_ID, task_id=task_id,
run_id=self.run_id, key=key).get(
@@ -121,17 +130,18 @@ def completed_run() -> _CompletedRun:
run_id = resp["dag_run_id"]
state = client.wait_for_dag_run(dag_id=_DAG_ID, run_id=run_id,
timeout=_GO_TASK_TIMEOUT)
ti_resp = client.get_task_instances(dag_id=_DAG_ID, run_id=run_id)
- ti_states = {ti["task_id"]: ti.get("state") for ti in
ti_resp.get("task_instances", [])}
- return _CompletedRun(client=client, run_id=run_id, state=state,
ti_states=ti_states)
+ ti_attrs = {ti["task_id"]: ti for ti in ti_resp.get("task_instances", [])}
+ ti_states = {task_id: ti.get("state") for task_id, ti in ti_attrs.items()}
+ return _CompletedRun(client=client, run_id=run_id, state=state,
ti_states=ti_states, ti_attrs=ti_attrs)
def test_task_states(completed_run: _CompletedRun):
- """Every task ends in its expected state (the Go ``load`` task fails on
purpose)."""
+ """Every task ends ``success`` (the Go ``load`` task succeeds on its
retry)."""
expected = {
"python_task_1": "success",
"extract": "success",
"transform": "success",
- "load": "failed",
+ "load": "success",
"python_task_2": "success",
}
for task_id, want in expected.items():
@@ -140,14 +150,35 @@ def test_task_states(completed_run: _CompletedRun):
)
-def test_dag_run_failed(completed_run: _CompletedRun):
- """The failing ``load`` leaf makes the overall run fail."""
- assert completed_run.state == "failed", (
- f"expected the run to fail because 'load' fails; got
{completed_run.state!r}. "
+def test_dag_run_succeeded(completed_run: _CompletedRun):
+ """The run succeeds once ``load`` recovers on its retry."""
+ assert completed_run.state == "success", (
+ f"expected the run to succeed because 'load' recovers on retry; got
{completed_run.state!r}. "
f"task states: {completed_run.ti_states}"
)
+def test_load_retried_then_succeeded(completed_run: _CompletedRun):
+ """``load`` fails once (UP_FOR_RETRY) then succeeds on the second attempt.
+
+ The Go coordinator must emit ``RetryTask`` (not terminal ``FAILED``) when
the
+ task fails with retries left, so the supervisor re-runs it. The end state
is
+ ``success`` reached on ``try_number`` 2, and each attempt's log reflects
the
+ first failure and then the recovery.
+ """
+ assert completed_run.ti_states.get("load") == "success",
completed_run.ti_states
+ assert completed_run.try_number("load") == 2, (
+ f"'load' should have run twice (fail then retry); try_number="
+ f"{completed_run.try_number('load')!r}, ti:
{completed_run.ti_attrs.get('load')}"
+ )
+ assert "Please fail" in completed_run.logs("load", try_number=1), (
+ "load's first attempt should log its failure message 'Please fail'"
+ )
+ assert "Recovered on retry" in completed_run.logs("load", try_number=2), (
+ "load's retry attempt should log 'Recovered on retry'"
+ )
+
+
def test_python_task_1_pushes_xcom(completed_run: _CompletedRun):
"""The upstream Python task's XCom is available (Python -> XCom)."""
assert completed_run.xcom("python_task_1") == "value_from_python_task_1"
@@ -230,10 +261,3 @@ def test_transform_logs_show_variable_read(completed_run:
_CompletedRun):
assert "Obtained variable" in completed_run.logs("transform"), (
"transform task should log 'Obtained variable'"
)
-
-
-def test_load_logs_show_failure(completed_run: _CompletedRun):
- """The Go 'load' task's error surfaces in its task log."""
- assert "Please fail" in completed_run.logs("load"), (
- "load task log should contain its failure message 'Please fail'"
- )
diff --git a/go-sdk/dags/go_examples.py b/go-sdk/dags/go_examples.py
index 523b6b8bf23..6c9ecf7b455 100644
--- a/go-sdk/dags/go_examples.py
+++ b/go-sdk/dags/go_examples.py
@@ -29,9 +29,10 @@ exercises XCom across the language boundary, the same way
routed to the ``ExecutableCoordinator``, which locates the bundle by dag_id
and
runs the binary in coordinator mode. ``extract`` returns a map (pushed as its
``return_value`` XCom); ``transform`` reads the ``my_variable`` variable.
-* ``load`` returns an error on purpose. It is a leaf (not upstream of
- ``python_task_2``) so its failure is observable while leaving the Go ->
Python
- XCom hop intact.
+* ``load`` (``retries=1``) returns an error on its first attempt and succeeds
+ on the retry, exercising the UP_FOR_RETRY path through the Go coordinator. It
+ is a leaf (not upstream of ``python_task_2``) so its retry is observable
+ while leaving the Go -> Python XCom hop intact.
* ``python_task_2`` (Python) pulls the Go ``extract`` task's XCom and re-emits
it, demonstrating the Go -> Python direction end-to-end.
@@ -43,6 +44,8 @@ default Python executor and are independent of the bundle.
from __future__ import annotations
+from datetime import timedelta
+
from airflow.sdk import dag, task
@@ -61,7 +64,10 @@ def extract(): ...
def transform(): ...
[email protected](queue="golang")
+# ``load`` fails on its first attempt and succeeds on the retry, exercising the
+# UP_FOR_RETRY path through the Go coordinator. The short ``retry_delay`` keeps
+# the end-to-end run fast.
[email protected](queue="golang", retries=1, retry_delay=timedelta(seconds=5))
def load(): ...
@@ -78,9 +84,9 @@ def simple_dag():
extracted = extract()
transformed = transform()
python_task_1() >> extracted >> transformed
- # ``load`` fails on purpose; keep it a leaf (not upstream of python_task_2)
- # so the failure is observable without skipping the Python task that pulls
- # the Go XCom.
+ # ``load`` fails once then succeeds on retry; keep it a leaf (not upstream
+ # of python_task_2) so its retry is observable without affecting the Python
+ # task that pulls the Go XCom.
transformed >> [load(), python_task_2(extracted)]
diff --git a/go-sdk/example/bundle/main.go b/go-sdk/example/bundle/main.go
index 8f7513dc355..7f4f1d22dcf 100644
--- a/go-sdk/example/bundle/main.go
+++ b/go-sdk/example/bundle/main.go
@@ -137,6 +137,17 @@ func transform(ctx sdk.TIRunContext, client
sdk.VariableClient, log *slog.Logger
return nil
}
-func load() error {
- return fmt.Errorf("Please fail")
+// load fails on its first attempt and succeeds on the retry. With retries
+// configured on the stub task, the first failure makes the supervisor mark the
+// task UP_FOR_RETRY -- which only works because the Go SDK now emits a
+// RetryTask frame (instead of a terminal FAILED) when ti_context.should_retry
+// is set. The retry then runs this task again and it returns nil.
+func load(ctx sdk.TIRunContext, log *slog.Logger) error {
+ tryNumber := ctx.TaskInstance().TryNumber
+ if tryNumber == 1 {
+ log.InfoContext(ctx, "Please fail", "try_number", tryNumber)
+ return fmt.Errorf("Please fail")
+ }
+ log.InfoContext(ctx, "Recovered on retry", "try_number", tryNumber)
+ return nil
}
diff --git a/go-sdk/pkg/execution/frames.go b/go-sdk/pkg/execution/frames.go
index 5d7d7ca3649..a8f65770fe2 100644
--- a/go-sdk/pkg/execution/frames.go
+++ b/go-sdk/pkg/execution/frames.go
@@ -242,6 +242,19 @@ func mapStringOr(m map[string]any, key string, def string)
string {
return s
}
+// mapBoolOr extracts a bool value from a map, returning the default if
missing.
+func mapBoolOr(m map[string]any, key string, def bool) bool {
+ v, ok := m[key]
+ if !ok {
+ return def
+ }
+ b, ok := v.(bool)
+ if !ok {
+ return def
+ }
+ return b
+}
+
// mapStringPtr extracts a nullable string value from a map. It returns nil
// when the key is missing or the value is nil (i.e. JSON null / Python None),
// and a pointer to the string when a value is present. Use this for fields
diff --git a/go-sdk/pkg/execution/integration_test.go
b/go-sdk/pkg/execution/integration_test.go
index 47f63b6d26e..90870449f5c 100644
--- a/go-sdk/pkg/execution/integration_test.go
+++ b/go-sdk/pkg/execution/integration_test.go
@@ -106,6 +106,34 @@ func TestTaskRunnerFailure(t *testing.T) {
assert.Equal(t, "failed", result["state"])
}
+func TestTaskRunnerRetry(t *testing.T) {
+ bundle := buildBundle(t, func(r bundlev1.Registry) {
+ r.AddDag("test_dag").AddTask(failingTask)
+ })
+
+ details := &StartupDetails{
+ TI: TaskInstanceInfo{
+ ID: "550e8400-e29b-41d4-a716-446655440000",
+ DagID: "test_dag",
+ TaskID: "failingTask",
+ RunID: "run1",
+ MapIndex: -1,
+ },
+ BundleInfo: BundleInfoMsg{Name: "test", Version: "1.0"},
+ TIContext: TIRunContext{
+ ShouldRetry: true,
+ MaxTries: 3,
+ },
+ }
+
+ logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+ comm := NewCoordinatorComm(bytes.NewReader(nil), io.Discard, logger)
+
+ result := RunTask(context.Background(), bundle, details, comm, logger)
+ assert.Equal(t, "RetryTask", result["type"])
+ assert.Equal(t, "task failed intentionally", result["retry_reason"])
+}
+
func TestTaskRunnerTaskNotFound(t *testing.T) {
bundle := buildBundle(t, func(r bundlev1.Registry) {
r.AddDag("test_dag").AddTask(simpleTask)
@@ -153,6 +181,34 @@ func TestTaskRunnerPanic(t *testing.T) {
assert.Equal(t, "failed", result["state"])
}
+func TestTaskRunnerPanicRetry(t *testing.T) {
+ bundle := buildBundle(t, func(r bundlev1.Registry) {
+ r.AddDag("test_dag").AddTask(panicTask)
+ })
+
+ details := &StartupDetails{
+ TI: TaskInstanceInfo{
+ ID: "550e8400-e29b-41d4-a716-446655440000",
+ DagID: "test_dag",
+ TaskID: "panicTask",
+ RunID: "run1",
+ MapIndex: -1,
+ },
+ BundleInfo: BundleInfoMsg{Name: "test", Version: "1.0"},
+ TIContext: TIRunContext{
+ ShouldRetry: true,
+ MaxTries: 3,
+ },
+ }
+
+ logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+ comm := NewCoordinatorComm(bytes.NewReader(nil), io.Discard, logger)
+
+ result := RunTask(context.Background(), bundle, details, comm, logger)
+ assert.Equal(t, "RetryTask", result["type"])
+ assert.Contains(t, result["retry_reason"], "panic: something went
wrong")
+}
+
func TestRunTaskHonorsContextCancellation(t *testing.T) {
bundle := buildBundle(t, func(r bundlev1.Registry) {
r.AddDag("test_dag").AddTaskWithName("ctxcheck",
diff --git a/go-sdk/pkg/execution/messages.go b/go-sdk/pkg/execution/messages.go
index 4cb85cebbb6..352215ac92f 100644
--- a/go-sdk/pkg/execution/messages.go
+++ b/go-sdk/pkg/execution/messages.go
@@ -104,21 +104,29 @@ type TIRunContext struct {
LogicalDate *time.Time
DataIntervalStart *time.Time
DataIntervalEnd *time.Time
+ MaxTries int
+ ShouldRetry bool
}
func decodeTIRunContext(m map[string]any) (TIRunContext, error) {
if m == nil {
return TIRunContext{}, nil
}
+ // max_tries / should_retry live at the top level of ti_context and
drive
+ // the retry decision, so they must be read regardless of whether the
+ // nested dag_run object is present.
+ ctx := TIRunContext{
+ MaxTries: mapIntOr(m, "max_tries", 0),
+ ShouldRetry: mapBoolOr(m, "should_retry", false),
+ }
// The scheduling timestamps live on the nested dag_run object in the
// supervisor's TIRunContext schema (ti_context.dag_run.logical_date,
...),
// not at the top level of ti_context. See task-sdk's
// airflow.sdk.api.datamodels._generated.{TIRunContext,DagRun}.
dagRun := mapMap(m, "dag_run")
if dagRun == nil {
- return TIRunContext{}, nil
+ return ctx, nil
}
- ctx := TIRunContext{}
for _, f := range []struct {
key string
dst **time.Time
@@ -388,6 +396,20 @@ func (m TaskStateMsg) toMap() map[string]any {
}
}
+// RetryTaskMsg is sent as a terminal message when a task fails but has
retries.
+type RetryTaskMsg struct {
+ EndDate time.Time
+ Reason string
+}
+
+func (m RetryTaskMsg) toMap() map[string]any {
+ return map[string]any{
+ "type": "RetryTask",
+ "end_date": m.EndDate.UTC().Format(time.RFC3339Nano),
+ "retry_reason": m.Reason,
+ }
+}
+
// Message dispatch.
// decodeIncomingBody dispatches decoding of a body map based on its "type"
field.
diff --git a/go-sdk/pkg/execution/messages_test.go
b/go-sdk/pkg/execution/messages_test.go
index e1e1e3efce2..44b93ce1fed 100644
--- a/go-sdk/pkg/execution/messages_test.go
+++ b/go-sdk/pkg/execution/messages_test.go
@@ -54,6 +54,8 @@ func TestDecodeStartupDetails(t *testing.T) {
"data_interval_start": "2024-01-14T00:00:00Z",
"data_interval_end": "2024-01-15T00:00:00Z",
},
+ "max_tries": int64(3),
+ "should_retry": true,
},
}
@@ -83,6 +85,8 @@ func TestDecodeStartupDetails(t *testing.T) {
time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC),
*details.TIContext.DataIntervalEnd,
)
+ assert.Equal(t, 3, details.TIContext.MaxTries)
+ assert.True(t, details.TIContext.ShouldRetry)
}
func TestDecodeStartupDetails_MalformedStartDate(t *testing.T) {
@@ -418,6 +422,15 @@ func TestTaskStateMsgToMap_PreservesSubsecondPrecision(t
*testing.T) {
assert.Equal(t, "2024-01-15T10:30:00.123456789Z", m["end_date"])
}
+func TestRetryTaskMsgToMap(t *testing.T) {
+ endDate := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+ msg := RetryTaskMsg{EndDate: endDate, Reason: "test error"}
+ m := msg.toMap()
+ assert.Equal(t, "RetryTask", m["type"])
+ assert.Equal(t, "test error", m["retry_reason"])
+ assert.Equal(t, "2024-01-15T10:30:00Z", m["end_date"])
+}
+
func TestTaskStateConstants_WireValues(t *testing.T) {
// Pin each enum constant to the exact wire string Python's
// TaskInstanceState expects. Renaming these constants is fine;
diff --git a/go-sdk/pkg/execution/task_runner.go
b/go-sdk/pkg/execution/task_runner.go
index f8bbd539751..46510955a8b 100644
--- a/go-sdk/pkg/execution/task_runner.go
+++ b/go-sdk/pkg/execution/task_runner.go
@@ -19,6 +19,7 @@ package execution
import (
"context"
+ "fmt"
"log/slog"
"runtime/debug"
"time"
@@ -114,7 +115,7 @@ func RunTask(
ctx = context.WithValue(ctx, sdkcontext.SdkClientContextKey,
sdk.Client(client))
ctx = context.WithValue(ctx, sdkcontext.RuntimeContextKey,
runtimeContext)
- return executeTask(ctx, task, logger)
+ return executeTask(ctx, task, details.TIContext.ShouldRetry, logger)
}
// mapIndexPtr converts the supervisor's map_index (which uses -1 as the
@@ -131,6 +132,7 @@ func mapIndexPtr(mapIndex int) *int {
func executeTask(
ctx context.Context,
task bundlev1.Task,
+ shouldRetry bool,
logger *slog.Logger,
) (result map[string]any) {
defer func() {
@@ -139,19 +141,28 @@ func executeTask(
"error", r,
"stack", string(debug.Stack()),
)
- result = TaskStateMsg{
- State: TaskStateFailed,
- EndDate: time.Now().UTC(),
- }.toMap()
+ if shouldRetry {
+ result = RetryTaskMsg{
+ EndDate: time.Now().UTC(),
+ Reason: fmt.Sprintf("panic: %v", r),
+ }.toMap()
+ } else {
+ result = TaskStateMsg{
+ State: TaskStateFailed,
+ EndDate: time.Now().UTC(),
+ }.toMap()
+ }
}
}()
if err := task.Execute(ctx, logger); err != nil {
logger.ErrorContext(ctx, "Task failed", "error", err)
- // TODO(https://github.com/apache/airflow/issues/67797): emit
RetryTask
- // (UP_FOR_RETRY) when ti_context.should_retry is set. Today
every
- // failure maps to terminal FAILED because the supervisor
honors this
- // frame on exit 0 and we never send RetryTask, so retries are
lost.
+ if shouldRetry {
+ return RetryTaskMsg{
+ EndDate: time.Now().UTC(),
+ Reason: err.Error(),
+ }.toMap()
+ }
return TaskStateMsg{
State: TaskStateFailed,
EndDate: time.Now().UTC(),