This is an automated email from the ASF dual-hosted git repository.
jason810496 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3e480942fb0 Go-SDK: Expose task runtime context to coordinator-mode tasks (#68271)
3e480942fb0 is described below
commit 3e480942fb06d501dd6ee6abad2151152f4d0521
Author: Jason(Zhe-You) Liu <[email protected]>
AuthorDate: Wed Jun 10 09:52:25 2026 +0800
Go-SDK: Expose task runtime context to coordinator-mode tasks (#68271)
Go tasks can now read their task-instance identifiers and the Dag run's
scheduling timestamps via sdk.CurrentContext(ctx): TI dag_id/run_id/
task_id/map_index/try_number and DagRun logical_date/data_interval_*.
This brings the Go coordinator path to parity with the Java/Python SDK
execution context.
Also fixes coordinator StartupDetails decoding, which read the
scheduling timestamps from the top level of ti_context instead of the
nested dag_run object the supervisor actually sends, so those dates
were always empty.
---
.../go_sdk_tests/test_go_sdk_dag.py | 59 +++++++++++++--
go-sdk/example/bundle/main.go | 25 +++++++
go-sdk/pkg/execution/integration_test.go | 84 ++++++++++++++++++++++
go-sdk/pkg/execution/messages.go | 12 +++-
go-sdk/pkg/execution/messages_test.go | 31 ++++++--
go-sdk/pkg/execution/task_runner.go | 28 ++++++++
go-sdk/pkg/sdkcontext/keys.go | 9 +++
go-sdk/sdk/context.go | 79 ++++++++++++++++++++
go-sdk/sdk/context_test.go | 53 ++++++++++++++
9 files changed, 369 insertions(+), 11 deletions(-)
diff --git a/airflow-e2e-tests/tests/airflow_e2e_tests/go_sdk_tests/test_go_sdk_dag.py b/airflow-e2e-tests/tests/airflow_e2e_tests/go_sdk_tests/test_go_sdk_dag.py
index 6cd7ff0779d..ad78b2dd3b1 100644
--- a/airflow-e2e-tests/tests/airflow_e2e_tests/go_sdk_tests/test_go_sdk_dag.py
+++ b/airflow-e2e-tests/tests/airflow_e2e_tests/go_sdk_tests/test_go_sdk_dag.py
@@ -53,10 +53,14 @@ end-to-end:
XCom crosses the Python <-> Go boundary in both directions.
3. Structured task logs emitted by the Go binary over the coordinator logs
channel reach Airflow's task-log store.
+4. The runtime context surfaced to Go tasks via ``sdk.CurrentContext`` is fully
+ populated on the coordinator path: the task-instance identifiers and the Dag
+ run's scheduling timestamps (logical_date / data_interval_start/end).
"""
from __future__ import annotations
+import re
import time
from dataclasses import dataclass
from datetime import datetime, timezone
@@ -92,14 +96,18 @@ class _CompletedRun:
"""Return the concatenated task-log records for *task_id*, retrying until present."""
deadline = time.monotonic() + _LOG_FETCH_TIMEOUT
while True:
- resp = self.client.get_task_logs(
- dag_id=_DAG_ID, run_id=self.run_id, task_id=task_id, try_number=try_number
- )
- text = "\n".join(str(entry) for entry in resp.get("content", []))
+ text = "\n".join(str(entry) for entry in self.log_records(task_id, try_number))
if text.strip() or time.monotonic() > deadline:
return text
time.sleep(3)
+ def log_records(self, task_id: str, try_number: int = 1) -> list[dict]:
+ """Return the structured task-log records (parsed JSON dicts) for *task_id*."""
+ resp = self.client.get_task_logs(
+ dag_id=_DAG_ID, run_id=self.run_id, task_id=task_id, try_number=try_number
+ )
+ return [entry for entry in resp.get("content", []) if isinstance(entry, dict)]
+
@pytest.fixture(scope="module")
def completed_run() -> _CompletedRun:
@@ -174,6 +182,49 @@ def test_extract_logs_show_beep_loop(completed_run: _CompletedRun):
assert "Goodbye from task" in logs, "extract task should log 'Goodbye from task'"
+def test_extract_logs_show_runtime_context(completed_run: _CompletedRun):
+ """The Go 'extract' task logs every field surfaced by ``sdk.CurrentContext``.
+
+ ``extract`` (main.go) emits one ``task runtime context`` record whose fields
+ are grouped under ``context.ti.*`` and ``context.dag_run.*``. This confirms
+ the coordinator path populates the full runtime context end-to-end -- the
+ task-instance identifiers and the Dag run's scheduling timestamps (the
+ ti_context.dag_run.* dates the supervisor sends over msgpack).
+ """
+ # The fields are emitted as structured attributes on a single log record,
+ # so read them from the parsed record rather than the rendered text.
+ records = completed_run.log_records("extract")
+ record = next((r for r in records if r.get("event") == "task runtime context"), None)
+ assert record is not None, (
+ f"extract should emit a 'task runtime context' record; events seen: "
+ f"{[r.get('event') for r in records]}"
+ )
+
+ run_id = completed_run.run_id
+
+ # Task-instance identifiers come straight from the task instance.
+ assert record.get("context.ti.dag_id") == _DAG_ID, record
+ assert record.get("context.ti.task_id") == "extract", record
+ assert record.get("context.ti.run_id") == run_id, record
+ assert str(record.get("context.ti.try_number")) == "1", record
+ # Unmapped task -> nil *int -> logged as null.
+ assert record.get("context.ti.map_index") is None, record
+
+ # The Dag run mirrors the same ids and carries the scheduling timestamps.
+ assert record.get("context.dag_run.dag_id") == _DAG_ID, record
+ assert record.get("context.dag_run.run_id") == run_id, record
+
+ iso_prefix = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
+ for ts_field in (
+ "context.dag_run.logical_date",
+ "context.dag_run.data_interval_start",
+ "context.dag_run.data_interval_end",
+ ):
+ value = record.get(ts_field)
+ assert value, f"{ts_field} should be present and non-empty; record: {record}"
+ assert iso_prefix.match(str(value)), f"{ts_field} should be an ISO-8601 timestamp, got {value!r}"
+
+
def test_transform_logs_show_variable_read(completed_run: _CompletedRun):
"""The Go 'transform' task logs the variable it read."""
assert "Obtained variable" in completed_run.logs("transform"), (
diff --git a/go-sdk/example/bundle/main.go b/go-sdk/example/bundle/main.go
index f57ec0b67f3..284439741f3 100644
--- a/go-sdk/example/bundle/main.go
+++ b/go-sdk/example/bundle/main.go
@@ -62,6 +62,31 @@ func main() {
func extract(ctx context.Context, client sdk.Client, log *slog.Logger) (any, error) {
log.Info("Hello from task")
+
+ // Log every field the runtime context exposes. The fields are namespaced
+ // under a "context" group (so they serialise as context.ti.* /
+ // context.dag_run.* dotted keys) to avoid colliding with the reserved
+ // task_id/run_id/etc. keys the supervisor strips from its log view.
+ rc := sdk.CurrentContext(ctx)
+ log.InfoContext(ctx, "task runtime context",
+ slog.Group("context",
+ slog.Group("ti",
+ "dag_id", rc.TI.DagID,
+ "run_id", rc.TI.RunID,
+ "task_id", rc.TI.TaskID,
+ "map_index", rc.TI.MapIndex,
+ "try_number", rc.TI.TryNumber,
+ ),
+ slog.Group("dag_run",
+ "dag_id", rc.DagRun.DagID,
+ "run_id", rc.DagRun.RunID,
+ "logical_date", rc.DagRun.LogicalDate,
+ "data_interval_start", rc.DagRun.DataIntervalStart,
+ "data_interval_end", rc.DagRun.DataIntervalEnd,
+ ),
+ ),
+ )
+
conn, err := client.GetConnection(ctx, "test_http")
if err != nil {
log.ErrorContext(ctx, "unable to get conn", "error", err)
diff --git a/go-sdk/pkg/execution/integration_test.go b/go-sdk/pkg/execution/integration_test.go
index ed0f437cff6..0520e922995 100644
--- a/go-sdk/pkg/execution/integration_test.go
+++ b/go-sdk/pkg/execution/integration_test.go
@@ -183,6 +183,90 @@ func TestRunTaskHonorsContextCancellation(t *testing.T) {
assert.Equal(t, "failed", result["state"])
}
+func TestRunTaskInjectsRuntimeContext(t *testing.T) {
+ logical := time.Date(2026, 6, 9, 12, 0, 0, 0, time.UTC)
+ start := logical
+ end := logical.Add(time.Hour)
+
+ var got sdk.RuntimeContext
+ bundle := buildBundle(t, func(r bundlev1.Registry) {
+ r.AddDag("test_dag").AddTaskWithName("ctxgrab",
+ func(ctx context.Context) error {
+ got = sdk.CurrentContext(ctx)
+ return nil
+ })
+ })
+
+ details := &StartupDetails{
+ TI: TaskInstanceInfo{
+ ID: "550e8400-e29b-41d4-a716-446655440000",
+ DagID: "test_dag",
+ TaskID: "ctxgrab",
+ RunID: "run1",
+ TryNumber: 2,
+ MapIndex: -1,
+ },
+ BundleInfo: BundleInfoMsg{Name: "test", Version: "1.0"},
+ TIContext: TIRunContext{
+ LogicalDate: &logical,
+ DataIntervalStart: &start,
+ DataIntervalEnd: &end,
+ },
+ }
+
+ logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+ comm := NewCoordinatorComm(bytes.NewReader(nil), io.Discard, logger)
+
+ result := RunTask(context.Background(), bundle, details, comm, logger)
+ require.Equal(t, "SucceedTask", result["type"])
+
+ assert.Equal(t, "test_dag", got.TI.DagID)
+ assert.Equal(t, "run1", got.TI.RunID)
+ assert.Equal(t, "ctxgrab", got.TI.TaskID)
+ assert.Equal(t, 2, got.TI.TryNumber)
+ assert.Nil(t, got.TI.MapIndex, "an unmapped task (map_index -1) must surface as nil")
+
+ assert.Equal(t, "test_dag", got.DagRun.DagID)
+ assert.Equal(t, "run1", got.DagRun.RunID)
+ require.NotNil(t, got.DagRun.LogicalDate)
+ assert.Equal(t, logical, *got.DagRun.LogicalDate)
+ require.NotNil(t, got.DagRun.DataIntervalStart)
+ assert.Equal(t, start, *got.DagRun.DataIntervalStart)
+ require.NotNil(t, got.DagRun.DataIntervalEnd)
+ assert.Equal(t, end, *got.DagRun.DataIntervalEnd)
+}
+
+func TestRunTaskRuntimeContextMappedIndex(t *testing.T) {
+ var got sdk.RuntimeContext
+ bundle := buildBundle(t, func(r bundlev1.Registry) {
+ r.AddDag("test_dag").AddTaskWithName("ctxgrab",
+ func(ctx context.Context) error {
+ got = sdk.CurrentContext(ctx)
+ return nil
+ })
+ })
+
+ details := &StartupDetails{
+ TI: TaskInstanceInfo{
+ ID: "550e8400-e29b-41d4-a716-446655440000",
+ DagID: "test_dag",
+ TaskID: "ctxgrab",
+ RunID: "run1",
+ MapIndex: 5,
+ },
+ BundleInfo: BundleInfoMsg{Name: "test", Version: "1.0"},
+ }
+
+ logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+ comm := NewCoordinatorComm(bytes.NewReader(nil), io.Discard, logger)
+
+ result := RunTask(context.Background(), bundle, details, comm, logger)
+ require.Equal(t, "SucceedTask", result["type"])
+
+ require.NotNil(t, got.TI.MapIndex, "a mapped task must surface its index")
+ assert.Equal(t, 5, *got.TI.MapIndex)
+}
+
// --- End-to-end Serve test against a fake supervisor ---
// fakeProvider implements bundlev1.BundleProvider; it lets a test inject the
diff --git a/go-sdk/pkg/execution/messages.go b/go-sdk/pkg/execution/messages.go
index b41a2e9adf6..4cb85cebbb6 100644
--- a/go-sdk/pkg/execution/messages.go
+++ b/go-sdk/pkg/execution/messages.go
@@ -110,6 +110,14 @@ func decodeTIRunContext(m map[string]any) (TIRunContext, error) {
if m == nil {
return TIRunContext{}, nil
}
+ // The scheduling timestamps live on the nested dag_run object in the
+ // supervisor's TIRunContext schema (ti_context.dag_run.logical_date, ...),
+ // not at the top level of ti_context. See task-sdk's
+ // airflow.sdk.api.datamodels._generated.{TIRunContext,DagRun}.
+ dagRun := mapMap(m, "dag_run")
+ if dagRun == nil {
+ return TIRunContext{}, nil
+ }
ctx := TIRunContext{}
for _, f := range []struct {
key string
@@ -119,13 +127,13 @@ func decodeTIRunContext(m map[string]any) (TIRunContext, error) {
{"data_interval_start", &ctx.DataIntervalStart},
{"data_interval_end", &ctx.DataIntervalEnd},
} {
- raw, present := m[f.key]
+ raw, present := dagRun[f.key]
if !present || raw == nil {
continue
}
t, err := asTime(raw)
if err != nil {
- return TIRunContext{}, fmt.Errorf("ti_context.%s: %w", f.key, err)
+ return TIRunContext{}, fmt.Errorf("ti_context.dag_run.%s: %w", f.key, err)
}
*f.dst = &t
}
diff --git a/go-sdk/pkg/execution/messages_test.go b/go-sdk/pkg/execution/messages_test.go
index 289accee257..e1e1e3efce2 100644
--- a/go-sdk/pkg/execution/messages_test.go
+++ b/go-sdk/pkg/execution/messages_test.go
@@ -44,10 +44,16 @@ func TestDecodeStartupDetails(t *testing.T) {
},
"start_date": "2024-01-15T10:30:00Z",
"sentry_integration": "",
+ // The supervisor nests the scheduling timestamps under dag_run, matching
+ // task-sdk's TIRunContext / DagRun schema.
"ti_context": map[string]any{
- "logical_date": "2024-01-15T00:00:00Z",
- "data_interval_start": "2024-01-14T00:00:00Z",
- "data_interval_end": "2024-01-15T00:00:00Z",
+ "dag_run": map[string]any{
+ "dag_id": "tutorial_dag",
+ "run_id": "manual__2024-01-15",
+ "logical_date": "2024-01-15T00:00:00Z",
+ "data_interval_start": "2024-01-14T00:00:00Z",
+ "data_interval_end": "2024-01-15T00:00:00Z",
+ },
},
}
@@ -63,7 +69,20 @@ func TestDecodeStartupDetails(t *testing.T) {
assert.Equal(t, "dags/tutorial.go", details.DagRelPath)
assert.Equal(t, "example_dags", details.BundleInfo.Name)
assert.Equal(t, "1.0.0", details.BundleInfo.Version)
- assert.NotNil(t, details.TIContext.LogicalDate)
+ require.NotNil(t, details.TIContext.LogicalDate)
+ assert.Equal(t, time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC), *details.TIContext.LogicalDate)
+ require.NotNil(t, details.TIContext.DataIntervalStart)
+ assert.Equal(
+ t,
+ time.Date(2024, 1, 14, 0, 0, 0, 0, time.UTC),
+ *details.TIContext.DataIntervalStart,
+ )
+ require.NotNil(t, details.TIContext.DataIntervalEnd)
+ assert.Equal(
+ t,
+ time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC),
+ *details.TIContext.DataIntervalEnd,
+ )
}
func TestDecodeStartupDetails_MalformedStartDate(t *testing.T) {
@@ -93,7 +112,9 @@ func TestDecodeStartupDetails_MalformedTIRunContext(t *testing.T) {
"try_number": int64(1),
},
"ti_context": map[string]any{
- "logical_date": "garbage",
+ "dag_run": map[string]any{
+ "logical_date": "garbage",
+ },
},
}
_, err := decodeStartupDetails(m)
diff --git a/go-sdk/pkg/execution/task_runner.go b/go-sdk/pkg/execution/task_runner.go
index dba184b588e..9a0e1eca2fa 100644
--- a/go-sdk/pkg/execution/task_runner.go
+++ b/go-sdk/pkg/execution/task_runner.go
@@ -89,12 +89,40 @@ func RunTask(
},
}
+ runtimeContext := sdk.RuntimeContext{
+ TI: sdk.TaskInstance{
+ DagID: details.TI.DagID,
+ RunID: details.TI.RunID,
+ TaskID: details.TI.TaskID,
+ MapIndex: mapIndexPtr(details.TI.MapIndex),
+ TryNumber: details.TI.TryNumber,
+ },
+ DagRun: sdk.DagRun{
+ DagID: details.TI.DagID,
+ RunID: details.TI.RunID,
+ LogicalDate: details.TIContext.LogicalDate,
+ DataIntervalStart: details.TIContext.DataIntervalStart,
+ DataIntervalEnd: details.TIContext.DataIntervalEnd,
+ },
+ }
+
ctx = context.WithValue(ctx, sdkcontext.WorkloadContextKey, workload)
ctx = context.WithValue(ctx, sdkcontext.SdkClientContextKey, sdk.Client(client))
+ ctx = context.WithValue(ctx, sdkcontext.RuntimeContextKey, runtimeContext)
return executeTask(ctx, task, logger)
}
+// mapIndexPtr converts the supervisor's map_index (which uses -1 as the
+// sentinel for an unmapped task) into the optional form exposed on
+// sdk.TaskInstance: nil for an unmapped task, otherwise a pointer to the index.
+func mapIndexPtr(mapIndex int) *int {
+ if mapIndex < 0 {
+ return nil
+ }
+ return &mapIndex
+}
+
// executeTask runs the task and handles success, failure, and panics.
func executeTask(
ctx context.Context,
diff --git a/go-sdk/pkg/sdkcontext/keys.go b/go-sdk/pkg/sdkcontext/keys.go
index ad83dfce3bd..8c64e78a03a 100644
--- a/go-sdk/pkg/sdkcontext/keys.go
+++ b/go-sdk/pkg/sdkcontext/keys.go
@@ -22,6 +22,7 @@ type (
apiClientContextKey struct{}
workerContextKey struct{}
runtimeTIContextKey struct{}
+ runtimeContextKey struct{}
sdkClientContextKey struct{}
)
@@ -34,6 +35,14 @@ var (
ApiClientContextKey = apiClientContextKey{}
WorkerContextKey = workerContextKey{}
+ // RuntimeContextKey stores the public, task-facing runtime context
+ // (task instance identifiers and the Dag run's scheduling timestamps).
+ // The coordinator-mode runtime populates it from StartupDetails; task
+ // functions read it through sdk.CurrentContext rather than
+ // touching this key directly. Its value type is sdk.RuntimeContext, but
+ // this package does not import sdk to avoid an import cycle.
+ RuntimeContextKey = runtimeContextKey{}
+
// SdkClientContextKey, when present, holds an sdk.Client implementation
// that should be injected into task functions instead of constructing a
// default HTTP-backed client. The coordinator-mode runtime uses this to
diff --git a/go-sdk/sdk/context.go b/go-sdk/sdk/context.go
new file mode 100644
index 00000000000..5e56d188da2
--- /dev/null
+++ b/go-sdk/sdk/context.go
@@ -0,0 +1,79 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package sdk
+
+import (
+ "context"
+ "time"
+
+ "github.com/apache/airflow/go-sdk/pkg/sdkcontext"
+)
+
+// RuntimeContext carries the identifiers and scheduling timestamps of the task
+// instance that is currently executing, along with the Dag run it belongs to.
+// It is the Go equivalent of the execution context the Python and Java SDKs
+// expose to task authors.
+//
+// Retrieve it inside a task function with CurrentContext:
+//
+// func myTask(ctx context.Context, log *slog.Logger) error {
+// rc := sdk.CurrentContext(ctx)
+// log.Info("running", "task_id", rc.TI.TaskID, "run_id", rc.DagRun.RunID)
+// return nil
+// }
+type RuntimeContext struct {
+ // TI identifies the task instance that is executing.
+ TI TaskInstance
+ // DagRun identifies the Dag run the task instance belongs to.
+ DagRun DagRun
+}
+
+// TaskInstance identifies the currently executing task instance.
+type TaskInstance struct {
+ DagID string
+ RunID string
+ TaskID string
+ // MapIndex is the index within a dynamically mapped task, or nil for an
+ // unmapped (regular) task instance.
+ MapIndex *int
+ TryNumber int
+}
+
+// DagRun identifies the Dag run the current task instance belongs to and
+// carries its scheduling timestamps. The *time.Time fields are nil when the
+// supervisor did not provide a value (for example, a manually triggered run
+// without a logical date).
+type DagRun struct {
+ DagID string
+ RunID string
+ LogicalDate *time.Time
+ DataIntervalStart *time.Time
+ DataIntervalEnd *time.Time
+}
+
+// CurrentContext returns the RuntimeContext the runtime stored on ctx for the
+// executing task. When ctx carries no RuntimeContext (for example when called
+// outside of a running task) it returns the zero value.
+//
+// It takes ctx explicitly rather than reading task-local state like Python's
+// get_current_context, because Go has no goroutine-local storage and the
+// worker path runs multiple tasks concurrently in one process.
+func CurrentContext(ctx context.Context) RuntimeContext {
+ rc, _ := ctx.Value(sdkcontext.RuntimeContextKey).(RuntimeContext)
+ return rc
+}
diff --git a/go-sdk/sdk/context_test.go b/go-sdk/sdk/context_test.go
new file mode 100644
index 00000000000..c63e79a470d
--- /dev/null
+++ b/go-sdk/sdk/context_test.go
@@ -0,0 +1,53 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package sdk
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+
+ "github.com/apache/airflow/go-sdk/pkg/sdkcontext"
+)
+
+func TestCurrentContext(t *testing.T) {
+ logical := time.Date(2026, 6, 9, 12, 0, 0, 0, time.UTC)
+ want := RuntimeContext{
+ TI: TaskInstance{
+ DagID: "dag1",
+ RunID: "run1",
+ TaskID: "task1",
+ MapIndex: ptr(3),
+ TryNumber: 2,
+ },
+ DagRun: DagRun{
+ DagID: "dag1",
+ RunID: "run1",
+ LogicalDate: &logical,
+ },
+ }
+
+ ctx := context.WithValue(context.Background(), sdkcontext.RuntimeContextKey, want)
+ assert.Equal(t, want, CurrentContext(ctx))
+}
+
+func TestCurrentContextAbsentReturnsZero(t *testing.T) {
+ assert.Equal(t, RuntimeContext{}, CurrentContext(context.Background()))
+}