This is an automated email from the ASF dual-hosted git repository.

jason810496 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d411f55c7e3 Go-SDK: Inject task runtime context as sdk.TIRunContext parameter (#68349)
d411f55c7e3 is described below

commit d411f55c7e36ac9a87b41e1d7b57f89bbfa75376
Author: Jason(Zhe-You) Liu <[email protected]>
AuthorDate: Thu Jun 11 19:14:45 2026 +0800

    Go-SDK: Inject task runtime context as sdk.TIRunContext parameter (#68349)
    
    * Go-SDK: Inject task runtime context as sdk.TIRunContext parameter
    
    Replace the sdk.CurrentContext(ctx) accessor and sdk.RuntimeContext type
    with sdk.TIRunContext, which embeds context.Context and carries the task
    instance identifiers and Dag run scheduling timestamps. Task authors now
    declare it as the task's context parameter and the runtime binds it by
    type, so they no longer have to pass a context to sdk.CurrentContext just
    to get the runtime context back:
    
        func extract(ctx sdk.TIRunContext, log *slog.Logger) (any, error) {
            log.Info("run", "task_id", ctx.TaskInstance().TaskID, "run_id", ctx.DagRun().RunID)
            return nil, nil
        }
    
    Because TIRunContext is itself a context.Context, it can be passed
    straight to client calls and downstream helpers, and ctx.Done() works as
    usual. Plain context.Context binding is retained for the Edge Worker
    path, which does not yet populate the runtime context.
    
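    * Refactor: Make sdk.TIRunContext an interface instead of a struct
    
    The task instance and Dag run are now read through the TaskInstance()
    and DagRun() accessor methods, and values (e.g. in unit tests) are built
    with sdk.NewTIRunContext.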
---
 go-sdk/bundle/bundlev1/task.go           | 26 ++++++++--
 go-sdk/bundle/bundlev1/task_test.go      | 77 +++++++++++++++++++++++++++++
 go-sdk/example/bundle/main.go            | 37 +++++++-------
 go-sdk/example/bundle/main_test.go       |  4 +-
 go-sdk/pkg/execution/integration_test.go | 51 +++++++++++---------
 go-sdk/pkg/execution/task_runner.go      | 12 +++--
 go-sdk/pkg/sdkcontext/keys.go            | 10 ++--
 go-sdk/sdk/context.go                    | 83 ++++++++++++++++++++++----------
 go-sdk/sdk/context_test.go               | 49 ++++++++++---------
 9 files changed, 247 insertions(+), 102 deletions(-)

diff --git a/go-sdk/bundle/bundlev1/task.go b/go-sdk/bundle/bundlev1/task.go
index 0a3aca70666..9ac3cbfc3cf 100644
--- a/go-sdk/bundle/bundlev1/task.go
+++ b/go-sdk/bundle/bundlev1/task.go
@@ -61,7 +61,22 @@ func (f *taskFunction) Execute(ctx context.Context, logger 
*slog.Logger) error {
                in := fnType.In(i)
 
                switch {
+               case isTIRunContext(in):
+                       // sdk.TIRunContext embeds context.Context, so it also 
satisfies
+                       // isContext - this case must come first. The runtime 
stores the
+                       // identifiers/timestamps under RuntimeContextKey; 
rebuild the
+                       // value around the live task context here.
+                       var ti sdk.TaskInstance
+                       var dagRun sdk.DagRun
+                       if stored, ok := 
ctx.Value(sdkcontext.RuntimeContextKey).(sdk.TIRunContext); ok {
+                               ti, dagRun = stored.TaskInstance(), 
stored.DagRun()
+                       }
+                       reflectArgs[i] = 
reflect.ValueOf(sdk.NewTIRunContext(ctx, ti, dagRun))
                case isContext(in):
+                       // Plain context.Context injection is retained for the 
Edge Worker
+                       // runtime path, which does not populate the task 
runtime context
+                       // (TI/DagRun) that sdk.TIRunContext carries. New tasks 
should
+                       // declare sdk.TIRunContext instead.
                        reflectArgs[i] = reflect.ValueOf(ctx)
                case isLogger(in):
                        reflectArgs[i] = reflect.ValueOf(logger)
@@ -148,9 +163,10 @@ func isValidResultType(inType reflect.Type) bool {
 }
 
 var (
-       errorType      = reflect.TypeFor[error]()
-       contextType    = reflect.TypeFor[context.Context]()
-       slogLoggerType = reflect.TypeFor[*slog.Logger]()
+       errorType        = reflect.TypeFor[error]()
+       contextType      = reflect.TypeFor[context.Context]()
+       tiRunContextType = reflect.TypeFor[sdk.TIRunContext]()
+       slogLoggerType   = reflect.TypeFor[*slog.Logger]()
 
        connClientType = reflect.TypeFor[sdk.ConnectionClient]()
        varClientType  = reflect.TypeFor[sdk.VariableClient]()
@@ -165,6 +181,10 @@ func isContext(inType reflect.Type) bool {
        return inType != nil && inType.Implements(contextType)
 }
 
+func isTIRunContext(inType reflect.Type) bool {
+       return inType == tiRunContextType
+}
+
 func isLogger(inType reflect.Type) bool {
        return inType != nil && inType.AssignableTo(slogLoggerType)
 }
diff --git a/go-sdk/bundle/bundlev1/task_test.go 
b/go-sdk/bundle/bundlev1/task_test.go
index d5f95c8251e..0caf8154f69 100644
--- a/go-sdk/bundle/bundlev1/task_test.go
+++ b/go-sdk/bundle/bundlev1/task_test.go
@@ -25,6 +25,7 @@ import (
        "github.com/stretchr/testify/suite"
 
        "github.com/apache/airflow/go-sdk/pkg/logging"
+       "github.com/apache/airflow/go-sdk/pkg/sdkcontext"
        "github.com/apache/airflow/go-sdk/sdk"
 )
 
@@ -117,3 +118,79 @@ func (s *TaskSuite) TestArgumentBinding() {
                })
        }
 }
+
+// probeKey is an unexported context key used to confirm the live task context
+// (not a freshly built one) backs the injected sdk.TIRunContext.
+type probeKeyType struct{}
+
+var probeKey probeKeyType
+
+// TestTIRunContextInjection verifies a task declaring sdk.TIRunContext 
receives
+// the TaskInstance/DagRun stored on the context, backed by the live task
+// context so it is usable as a context.Context. It must take precedence over
+// the plain context.Context binding, which sdk.TIRunContext also satisfies.
+func (s *TaskSuite) TestTIRunContextInjection() {
+       mapIndex := 3
+       ti := sdk.TaskInstance{
+               DagID:     "dag1",
+               RunID:     "run1",
+               TaskID:    "task1",
+               MapIndex:  &mapIndex,
+               TryNumber: 2,
+       }
+       dagRun := sdk.DagRun{DagID: "dag1", RunID: "run1"}
+       stored := sdk.NewTIRunContext(context.Background(), ti, dagRun)
+
+       var got sdk.TIRunContext
+       task, err := NewTaskFunction(func(ctx sdk.TIRunContext) error {
+               got = ctx
+               return nil
+       })
+       s.Require().NoError(err)
+
+       ctx := context.WithValue(context.Background(), 
sdkcontext.RuntimeContextKey, stored)
+       ctx = context.WithValue(ctx, probeKey, "probe-value")
+       s.Require().NoError(task.Execute(ctx, slog.New(logging.NewTeeLogger())))
+
+       s.Require().NotNil(got, "the task must receive a non-nil TIRunContext")
+       s.Equal(ti, got.TaskInstance())
+       s.Equal(dagRun, got.DagRun())
+       s.Equal(
+               "probe-value",
+               got.Value(probeKey),
+               "the injected context must be backed by the one passed to 
Execute",
+       )
+}
+
+// TestTIRunContextInjectionWithoutRuntimeContext covers the Edge Worker path:
+// the runtime does not populate RuntimeContextKey, so a task declaring
+// sdk.TIRunContext gets zero TaskInstance/DagRun but is still backed by the
+// live task context, leaving it usable as a context.Context.
+func (s *TaskSuite) TestTIRunContextInjectionWithoutRuntimeContext() {
+       var got sdk.TIRunContext
+       task, err := NewTaskFunction(func(ctx sdk.TIRunContext) error {
+               got = ctx
+               return nil
+       })
+       s.Require().NoError(err)
+
+       ctx := context.WithValue(context.Background(), probeKey, "probe-value")
+       s.Require().NoError(task.Execute(ctx, slog.New(logging.NewTeeLogger())))
+
+       s.Require().NotNil(got, "the task must receive a non-nil TIRunContext")
+       s.Equal(
+               sdk.TaskInstance{},
+               got.TaskInstance(),
+               "TaskInstance must be zero when no runtime context is present",
+       )
+       s.Equal(
+               sdk.DagRun{},
+               got.DagRun(),
+               "DagRun must be zero when no runtime context is present",
+       )
+       s.Equal(
+               "probe-value",
+               got.Value(probeKey),
+               "the injected context must be backed by the one passed to 
Execute",
+       )
+}
diff --git a/go-sdk/example/bundle/main.go b/go-sdk/example/bundle/main.go
index 284439741f3..8f7513dc355 100644
--- a/go-sdk/example/bundle/main.go
+++ b/go-sdk/example/bundle/main.go
@@ -18,7 +18,6 @@
 package main
 
 import (
-       "context"
        "fmt"
        "log"
        "log/slog"
@@ -60,29 +59,31 @@ func main() {
        }
 }
 
-func extract(ctx context.Context, client sdk.Client, log *slog.Logger) (any, 
error) {
+func extract(ctx sdk.TIRunContext, client sdk.Client, log *slog.Logger) (any, 
error) {
        log.Info("Hello from task")
 
-       // Log every field the runtime context exposes. The fields are 
namespaced
-       // under a "context" group (so they serialise as context.ti.* /
-       // context.dag_run.* dotted keys) to avoid colliding with the reserved
-       // task_id/run_id/etc. keys the supervisor strips from its log view.
-       rc := sdk.CurrentContext(ctx)
+       // ctx behaves as a context.Context and also carries the task instance
+       // identifiers and the Dag run's scheduling timestamps. Log every field 
the
+       // runtime context exposes. The fields are namespaced under a "context"
+       // group (so they serialise as context.ti.* / context.dag_run.* dotted
+       // keys) to avoid colliding with the reserved task_id/run_id/etc. keys 
the
+       // supervisor strips from its log view.
+       ti, dagRun := ctx.TaskInstance(), ctx.DagRun()
        log.InfoContext(ctx, "task runtime context",
                slog.Group("context",
                        slog.Group("ti",
-                               "dag_id", rc.TI.DagID,
-                               "run_id", rc.TI.RunID,
-                               "task_id", rc.TI.TaskID,
-                               "map_index", rc.TI.MapIndex,
-                               "try_number", rc.TI.TryNumber,
+                               "dag_id", ti.DagID,
+                               "run_id", ti.RunID,
+                               "task_id", ti.TaskID,
+                               "map_index", ti.MapIndex,
+                               "try_number", ti.TryNumber,
                        ),
                        slog.Group("dag_run",
-                               "dag_id", rc.DagRun.DagID,
-                               "run_id", rc.DagRun.RunID,
-                               "logical_date", rc.DagRun.LogicalDate,
-                               "data_interval_start", 
rc.DagRun.DataIntervalStart,
-                               "data_interval_end", rc.DagRun.DataIntervalEnd,
+                               "dag_id", dagRun.DagID,
+                               "run_id", dagRun.RunID,
+                               "logical_date", dagRun.LogicalDate,
+                               "data_interval_start", dagRun.DataIntervalStart,
+                               "data_interval_end", dagRun.DataIntervalEnd,
                        ),
                ),
        )
@@ -121,7 +122,7 @@ func extract(ctx context.Context, client sdk.Client, log 
*slog.Logger) (any, err
        return ret, nil
 }
 
-func transform(ctx context.Context, client sdk.VariableClient, log 
*slog.Logger) error {
+func transform(ctx sdk.TIRunContext, client sdk.VariableClient, log 
*slog.Logger) error {
        // This function takes a VariableClient and not a Client to make unit 
testing it easier. See
        // `./main_test.go` for an example unit of this task fn. Functionally 
taking a `sdk.Client` is the same (as
        // Client includes VariableClient) but by using the dedicated type it 
can be easier to write unit tests.
diff --git a/go-sdk/example/bundle/main_test.go 
b/go-sdk/example/bundle/main_test.go
index 54ff86c1044..474a8406d62 100644
--- a/go-sdk/example/bundle/main_test.go
+++ b/go-sdk/example/bundle/main_test.go
@@ -52,6 +52,8 @@ var _ sdk.VariableClient = (*mockVars)(nil)
 func Test_transform(t *testing.T) {
        log := slog.Default()
        // This is not the best test, but it is a good proof of concept -- you 
can just call the function.
-       err := transform(context.Background(), &mockVars{}, log)
+       // sdk.NewTIRunContext wraps any context to build a TIRunContext in a 
test.
+       ctx := sdk.NewTIRunContext(context.Background(), sdk.TaskInstance{}, 
sdk.DagRun{})
+       err := transform(ctx, &mockVars{}, log)
        assert.NoError(t, err)
 }
diff --git a/go-sdk/pkg/execution/integration_test.go 
b/go-sdk/pkg/execution/integration_test.go
index 0520e922995..47f63b6d26e 100644
--- a/go-sdk/pkg/execution/integration_test.go
+++ b/go-sdk/pkg/execution/integration_test.go
@@ -188,11 +188,11 @@ func TestRunTaskInjectsRuntimeContext(t *testing.T) {
        start := logical
        end := logical.Add(time.Hour)
 
-       var got sdk.RuntimeContext
+       var got sdk.TIRunContext
        bundle := buildBundle(t, func(r bundlev1.Registry) {
                r.AddDag("test_dag").AddTaskWithName("ctxgrab",
-                       func(ctx context.Context) error {
-                               got = sdk.CurrentContext(ctx)
+                       func(ctx sdk.TIRunContext) error {
+                               got = ctx
                                return nil
                        })
        })
@@ -220,28 +220,35 @@ func TestRunTaskInjectsRuntimeContext(t *testing.T) {
        result := RunTask(context.Background(), bundle, details, comm, logger)
        require.Equal(t, "SucceedTask", result["type"])
 
-       assert.Equal(t, "test_dag", got.TI.DagID)
-       assert.Equal(t, "run1", got.TI.RunID)
-       assert.Equal(t, "ctxgrab", got.TI.TaskID)
-       assert.Equal(t, 2, got.TI.TryNumber)
-       assert.Nil(t, got.TI.MapIndex, "an unmapped task (map_index -1) must 
surface as nil")
-
-       assert.Equal(t, "test_dag", got.DagRun.DagID)
-       assert.Equal(t, "run1", got.DagRun.RunID)
-       require.NotNil(t, got.DagRun.LogicalDate)
-       assert.Equal(t, logical, *got.DagRun.LogicalDate)
-       require.NotNil(t, got.DagRun.DataIntervalStart)
-       assert.Equal(t, start, *got.DagRun.DataIntervalStart)
-       require.NotNil(t, got.DagRun.DataIntervalEnd)
-       assert.Equal(t, end, *got.DagRun.DataIntervalEnd)
+       require.NotNil(
+               t,
+               got,
+               "the task must receive a TIRunContext backed by the live task 
context",
+       )
+       ti := got.TaskInstance()
+       assert.Equal(t, "test_dag", ti.DagID)
+       assert.Equal(t, "run1", ti.RunID)
+       assert.Equal(t, "ctxgrab", ti.TaskID)
+       assert.Equal(t, 2, ti.TryNumber)
+       assert.Nil(t, ti.MapIndex, "an unmapped task (map_index -1) must 
surface as nil")
+
+       dagRun := got.DagRun()
+       assert.Equal(t, "test_dag", dagRun.DagID)
+       assert.Equal(t, "run1", dagRun.RunID)
+       require.NotNil(t, dagRun.LogicalDate)
+       assert.Equal(t, logical, *dagRun.LogicalDate)
+       require.NotNil(t, dagRun.DataIntervalStart)
+       assert.Equal(t, start, *dagRun.DataIntervalStart)
+       require.NotNil(t, dagRun.DataIntervalEnd)
+       assert.Equal(t, end, *dagRun.DataIntervalEnd)
 }
 
 func TestRunTaskRuntimeContextMappedIndex(t *testing.T) {
-       var got sdk.RuntimeContext
+       var got sdk.TIRunContext
        bundle := buildBundle(t, func(r bundlev1.Registry) {
                r.AddDag("test_dag").AddTaskWithName("ctxgrab",
-                       func(ctx context.Context) error {
-                               got = sdk.CurrentContext(ctx)
+                       func(ctx sdk.TIRunContext) error {
+                               got = ctx
                                return nil
                        })
        })
@@ -263,8 +270,8 @@ func TestRunTaskRuntimeContextMappedIndex(t *testing.T) {
        result := RunTask(context.Background(), bundle, details, comm, logger)
        require.Equal(t, "SucceedTask", result["type"])
 
-       require.NotNil(t, got.TI.MapIndex, "a mapped task must surface its 
index")
-       assert.Equal(t, 5, *got.TI.MapIndex)
+       require.NotNil(t, got.TaskInstance().MapIndex, "a mapped task must 
surface its index")
+       assert.Equal(t, 5, *got.TaskInstance().MapIndex)
 }
 
 // --- End-to-end Serve test against a fake supervisor ---
diff --git a/go-sdk/pkg/execution/task_runner.go 
b/go-sdk/pkg/execution/task_runner.go
index 9a0e1eca2fa..f8bbd539751 100644
--- a/go-sdk/pkg/execution/task_runner.go
+++ b/go-sdk/pkg/execution/task_runner.go
@@ -89,22 +89,26 @@ func RunTask(
                },
        }
 
-       runtimeContext := sdk.RuntimeContext{
-               TI: sdk.TaskInstance{
+       // Carries the task runtime context for sdk.TIRunContext injection. The
+       // base context is a placeholder; bundlev1.Execute rebuilds the value
+       // around the live task context when binding the parameter.
+       runtimeContext := sdk.NewTIRunContext(
+               context.Background(),
+               sdk.TaskInstance{
                        DagID:     details.TI.DagID,
                        RunID:     details.TI.RunID,
                        TaskID:    details.TI.TaskID,
                        MapIndex:  mapIndexPtr(details.TI.MapIndex),
                        TryNumber: details.TI.TryNumber,
                },
-               DagRun: sdk.DagRun{
+               sdk.DagRun{
                        DagID:             details.TI.DagID,
                        RunID:             details.TI.RunID,
                        LogicalDate:       details.TIContext.LogicalDate,
                        DataIntervalStart: details.TIContext.DataIntervalStart,
                        DataIntervalEnd:   details.TIContext.DataIntervalEnd,
                },
-       }
+       )
 
        ctx = context.WithValue(ctx, sdkcontext.WorkloadContextKey, workload)
        ctx = context.WithValue(ctx, sdkcontext.SdkClientContextKey, 
sdk.Client(client))
diff --git a/go-sdk/pkg/sdkcontext/keys.go b/go-sdk/pkg/sdkcontext/keys.go
index 8c64e78a03a..677952e9a9f 100644
--- a/go-sdk/pkg/sdkcontext/keys.go
+++ b/go-sdk/pkg/sdkcontext/keys.go
@@ -37,10 +37,12 @@ var (
 
        // RuntimeContextKey stores the public, task-facing runtime context
        // (task instance identifiers and the Dag run's scheduling timestamps).
-       // The coordinator-mode runtime populates it from StartupDetails; task
-       // functions read it through sdk.CurrentContext rather than
-       // touching this key directly. Its value type is sdk.RuntimeContext, but
-       // this package does not import sdk to avoid an import cycle.
+       // The coordinator-mode runtime populates it from StartupDetails; the
+       // bundle runtime reads it to inject an sdk.TIRunContext parameter into
+       // task functions rather than exposing this key directly. Its value type
+       // is sdk.TIRunContext (built over a placeholder base context; the 
bundle
+       // runtime rebuilds it around the live task context at injection time),
+       // but this package does not import sdk to avoid an import cycle.
        RuntimeContextKey = runtimeContextKey{}
 
        // SdkClientContextKey, when present, holds an sdk.Client implementation
diff --git a/go-sdk/sdk/context.go b/go-sdk/sdk/context.go
index 5e56d188da2..afd0dd37050 100644
--- a/go-sdk/sdk/context.go
+++ b/go-sdk/sdk/context.go
@@ -20,29 +20,72 @@ package sdk
 import (
        "context"
        "time"
-
-       "github.com/apache/airflow/go-sdk/pkg/sdkcontext"
 )
 
-// RuntimeContext carries the identifiers and scheduling timestamps of the task
-// instance that is currently executing, along with the Dag run it belongs to.
-// It is the Go equivalent of the execution context the Python and Java SDKs
-// expose to task authors.
+// TIRunContext is the execution context handed to a task. It behaves as the
+// standard context.Context (cancellation, deadline, request-scoped values) and
+// additionally exposes the identifiers and scheduling timestamps of the task
+// instance that is executing, along with the Dag run it belongs to. It is the
+// Go equivalent of the execution context the Python and Java SDKs expose to
+// task authors.
 //
-// Retrieve it inside a task function with CurrentContext:
+// The runtime injects it into a task function by parameter type, so declare it
+// as the task's context argument:
 //
-//     func myTask(ctx context.Context, log *slog.Logger) error {
-//             rc := sdk.CurrentContext(ctx)
-//             log.Info("running", "task_id", rc.TI.TaskID, "run_id", 
rc.DagRun.RunID)
+//     func myTask(ctx sdk.TIRunContext, log *slog.Logger) error {
+//             log.Info("running",
+//                     "task_id", ctx.TaskInstance().TaskID,
+//                     "run_id", ctx.DagRun().RunID,
+//             )
 //             return nil
 //     }
-type RuntimeContext struct {
-       // TI identifies the task instance that is executing.
-       TI TaskInstance
+//
+// Because it embeds context.Context it is usable wherever one is expected:
+// pass it straight to client calls, select on ctx.Done(), or hand it to
+// downstream helpers that take a context.Context.
+//
+// It is an interface rather than a struct holding a context.Context, which
+// the context package advises against 
(https://pkg.go.dev/context#hdr-Contexts_and_structs):
+// the runtime constructs a fresh value around the live task context for each
+// invocation, and task code cannot end up with a half-initialised value. Build
+// one in tests with NewTIRunContext.
+type TIRunContext interface {
+       context.Context
+
+       // TaskInstance identifies the task instance that is executing.
+       TaskInstance() TaskInstance
        // DagRun identifies the Dag run the task instance belongs to.
-       DagRun DagRun
+       DagRun() DagRun
 }
 
+// NewTIRunContext returns a TIRunContext that delegates context behaviour to
+// ctx and exposes ti and dagRun. It panics on a nil ctx, mirroring the context
+// package's own constructors. The runtime calls it when binding a task's
+// TIRunContext parameter; in unit tests, use it to hand-build the argument:
+//
+//     ctx := sdk.NewTIRunContext(context.Background(), 
sdk.TaskInstance{TaskID: "t1"}, sdk.DagRun{})
+func NewTIRunContext(ctx context.Context, ti TaskInstance, dagRun DagRun) 
TIRunContext {
+       if ctx == nil {
+               // This cannot happen from the runtime: taskFunction.Execute 
always
+               // binds the live task context. A nil ctx is a programming 
error in
+               // the caller, so fail loudly instead of masking it.
+               panic("sdk.NewTIRunContext: cannot create TIRunContext from nil 
context.Context")
+       }
+       return tiRunContext{Context: ctx, ti: ti, dagRun: dagRun}
+}
+
+// tiRunContext is the runtime implementation of TIRunContext.
+type tiRunContext struct {
+       context.Context
+
+       ti     TaskInstance
+       dagRun DagRun
+}
+
+func (c tiRunContext) TaskInstance() TaskInstance { return c.ti }
+
+func (c tiRunContext) DagRun() DagRun { return c.dagRun }
+
 // TaskInstance identifies the currently executing task instance.
 type TaskInstance struct {
        DagID  string
@@ -65,15 +108,3 @@ type DagRun struct {
        DataIntervalStart *time.Time
        DataIntervalEnd   *time.Time
 }
-
-// CurrentContext returns the RuntimeContext the runtime stored on ctx for the
-// executing task. When ctx carries no RuntimeContext (for example when called
-// outside of a running task) it returns the zero value.
-//
-// It takes ctx explicitly rather than reading task-local state like Python's
-// get_current_context, because Go has no goroutine-local storage and the
-// worker path runs multiple tasks concurrently in one process.
-func CurrentContext(ctx context.Context) RuntimeContext {
-       rc, _ := ctx.Value(sdkcontext.RuntimeContextKey).(RuntimeContext)
-       return rc
-}
diff --git a/go-sdk/sdk/context_test.go b/go-sdk/sdk/context_test.go
index c63e79a470d..cc6f379abf4 100644
--- a/go-sdk/sdk/context_test.go
+++ b/go-sdk/sdk/context_test.go
@@ -20,34 +20,35 @@ package sdk
 import (
        "context"
        "testing"
-       "time"
 
        "github.com/stretchr/testify/assert"
-
-       "github.com/apache/airflow/go-sdk/pkg/sdkcontext"
 )
 
-func TestCurrentContext(t *testing.T) {
-       logical := time.Date(2026, 6, 9, 12, 0, 0, 0, time.UTC)
-       want := RuntimeContext{
-               TI: TaskInstance{
-                       DagID:     "dag1",
-                       RunID:     "run1",
-                       TaskID:    "task1",
-                       MapIndex:  ptr(3),
-                       TryNumber: 2,
-               },
-               DagRun: DagRun{
-                       DagID:       "dag1",
-                       RunID:       "run1",
-                       LogicalDate: &logical,
-               },
-       }
-
-       ctx := context.WithValue(context.Background(), 
sdkcontext.RuntimeContextKey, want)
-       assert.Equal(t, want, CurrentContext(ctx))
+type ctxKey struct{}
+
+func TestNewTIRunContext(t *testing.T) {
+       ti := TaskInstance{DagID: "dag1", RunID: "run1", TaskID: "task1", 
TryNumber: 2}
+       dagRun := DagRun{DagID: "dag1", RunID: "run1"}
+
+       base := context.WithValue(context.Background(), ctxKey{}, "probe-value")
+       ctx := NewTIRunContext(base, ti, dagRun)
+
+       assert.Equal(t, ti, ctx.TaskInstance())
+       assert.Equal(t, dagRun, ctx.DagRun())
+       assert.Equal(
+               t,
+               "probe-value",
+               ctx.Value(ctxKey{}),
+               "context behaviour must delegate to the base context",
+       )
 }
 
-func TestCurrentContextAbsentReturnsZero(t *testing.T) {
-       assert.Equal(t, RuntimeContext{}, CurrentContext(context.Background()))
+// A nil base context is a programming error: the runtime always binds the
+// live task context, so the constructor must fail loudly rather than hand out
+// a value that panics later.
+func TestNewTIRunContextNilBase(t *testing.T) {
+       var nilBase context.Context
+       assert.Panics(t, func() {
+               NewTIRunContext(nilBase, TaskInstance{}, DagRun{})
+       })
 }

Reply via email to