This is an automated email from the ASF dual-hosted git repository.

jason810496 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 129d03ed7a9 Go-SDK: Add coordinator-mode protocol primitives and SDK surface hooks (#67315)
129d03ed7a9 is described below

commit 129d03ed7a94f55d69f3040c5111dda2980b9ae1
Author: Jason(Zhe-You) Liu <[email protected]>
AuthorDate: Wed May 27 13:05:04 2026 +0800

    Go-SDK: Add coordinator-mode protocol primitives and SDK surface hooks (#67315)
    
    * go-sdk: Add coordinator-mode protocol primitives and SDK surface hooks
    
    First step toward landing the Go SDK coordinator-mode runtime
    (ADR 0003, msgpack-over-IPC). Scaffolding only -- no entry point is wired
    here, so go-plugin / Edge Worker behaviour is unchanged.
    
    Adds the length-prefixed msgpack frame codec and the typed message
    envelopes the runtime will exchange with the supervisor, the
    sdkcontext.SdkClientContextKey injection hook on bundlev1.Task so a
    follow-up PR can swap in a comm-socket-backed sdk.Client, and the small
    sdk surface tweaks (ConnFromAPIResponse export, VariableClient interface
    docs, secret-masking TODOs) the comm-socket client will rely on. Pulls
    in github.com/vmihailenco/msgpack/v5 -- the encoding the supervisor
    speaks.
    
    * Fix: Update max frame size constant and enhance payload validation
    
    * Fix: Tighten coordinator-protocol message decoding
    
    - Add strict mapInt for required fields; require try_number on
      StartupDetails so supervisor/runtime version-drift surfaces as a
      decode error instead of a silent default.
    - Use RFC3339Nano for SucceedTask and TaskState end_date so the
      sub-second precision asTime parses round-trips on the wire.
    - Change SetXComMsg.MapIndex to *int and omit when nil, matching
      Python's SetXCom.map_index = None semantics and GetXComMsg.MapIndex.
    
    * Add regression test for writeFrame oversize-payload guard
    
    Pins the guard at the top of writeFrame so a future refactor cannot
    silently drop the uint32-overflow protection. Uses unsafe.Slice to build
    a fake-length payload without allocating multi-GiB buffers, since the
    guard only reads len(payload) before any allocation or byte access.
    
    The matching read-side guard in readFrame is dead code with MaxFrameSize
    pinned at the uint32 maximum and cannot be exercised without modifying
    production code; documented inline.
---
 go-sdk/bundle/bundlev1/task.go        |   7 +-
 go-sdk/go.mod                         |   2 +
 go-sdk/go.sum                         |   4 +
 go-sdk/pkg/execution/frames.go        | 286 +++++++++++++++++++++++
 go-sdk/pkg/execution/frames_test.go   | 258 +++++++++++++++++++++
 go-sdk/pkg/execution/messages.go      | 412 ++++++++++++++++++++++++++++++++++
 go-sdk/pkg/execution/messages_test.go | 375 +++++++++++++++++++++++++++++++
 go-sdk/pkg/sdkcontext/keys.go         |   8 +
 go-sdk/sdk/client.go                  |  15 +-
 go-sdk/sdk/connection.go              |   6 +-
 go-sdk/sdk/sdk.go                     |  16 ++
 11 files changed, 1386 insertions(+), 3 deletions(-)

diff --git a/go-sdk/bundle/bundlev1/task.go b/go-sdk/bundle/bundlev1/task.go
index 4271f4892bb..5277f40681c 100644
--- a/go-sdk/bundle/bundlev1/task.go
+++ b/go-sdk/bundle/bundlev1/task.go
@@ -45,7 +45,12 @@ func NewTaskFunction(fn any) (Task, error) {
 
 func (f *taskFunction) Execute(ctx context.Context, logger *slog.Logger) error 
{
        fnType := f.fn.Type()
-       sdkClient := sdk.NewClient()
+       var sdkClient sdk.Client
+       if injected, ok := 
ctx.Value(sdkcontext.SdkClientContextKey).(sdk.Client); ok {
+               sdkClient = injected
+       } else {
+               sdkClient = sdk.NewClient()
+       }
 
        reflectArgs := make([]reflect.Value, fnType.NumIn())
        for i := range reflectArgs {
diff --git a/go-sdk/go.mod b/go-sdk/go.mod
index bfb400eee94..f3bfcd4b0f6 100644
--- a/go-sdk/go.mod
+++ b/go-sdk/go.mod
@@ -16,6 +16,7 @@ require (
        github.com/spf13/pflag v1.0.10
        github.com/spf13/viper v1.20.1
        github.com/stretchr/testify v1.11.1
+       github.com/vmihailenco/msgpack/v5 v5.4.1
        google.golang.org/grpc v1.79.3
        google.golang.org/protobuf v1.36.10
        resty.dev/v3 v3.0.0-beta.2
@@ -38,6 +39,7 @@ require (
        github.com/spf13/afero v1.12.0 // indirect
        github.com/stretchr/objx v0.5.2 // indirect
        github.com/subosito/gotenv v1.6.0 // indirect
+       github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
        go.opentelemetry.io/otel v1.41.0 // indirect
        go.opentelemetry.io/otel/trace v1.41.0 // indirect
        go.uber.org/multierr v1.10.0 // indirect
diff --git a/go-sdk/go.sum b/go-sdk/go.sum
index 5b7940672b1..a275d6b63c8 100644
--- a/go-sdk/go.sum
+++ b/go-sdk/go.sum
@@ -114,6 +114,10 @@ github.com/stretchr/testify v1.11.1 
h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
 github.com/stretchr/testify v1.11.1/go.mod 
h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
 github.com/subosito/gotenv v1.6.0 
h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
 github.com/subosito/gotenv v1.6.0/go.mod 
h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
+github.com/vmihailenco/msgpack/v5 v5.4.1 
h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
+github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod 
h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
+github.com/vmihailenco/tagparser/v2 v2.0.0 
h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
+github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod 
h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
 go.opentelemetry.io/auto/sdk v1.2.1 
h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
 go.opentelemetry.io/auto/sdk v1.2.1/go.mod 
h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
 go.opentelemetry.io/otel v1.41.0 
h1:YlEwVsGAlCvczDILpUXpIpPSL/VPugt7zHThEMLce1c=
diff --git a/go-sdk/pkg/execution/frames.go b/go-sdk/pkg/execution/frames.go
new file mode 100644
index 00000000000..1d8f4671013
--- /dev/null
+++ b/go-sdk/pkg/execution/frames.go
@@ -0,0 +1,286 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package execution
+
+import (
+       "bytes"
+       "encoding/binary"
+       "fmt"
+       "io"
+
+       "github.com/vmihailenco/msgpack/v5"
+)
+
// MaxFrameSize is the largest payload, in bytes, that a single frame may
// carry. The length prefix is a 4-byte unsigned integer, so 2^32 - 1 is the
// hard ceiling; task-sdk comms.py:_FrameMixin.as_bytes enforces the same
// limit (it raises OverflowError for n >= 2**32).
//
// The constant is untyped and its value does not fit in int on 32-bit
// platforms, so lengths should be compared against it as uint64.
const MaxFrameSize = 1<<32 - 1

// IncomingFrame is one decoded message read from the comm socket.
type IncomingFrame struct {
	// ID is the frame identifier taken from the first array element.
	ID int
	// Body is the second array element as a map; nil when the peer sent nil.
	Body map[string]any
	// Err is the third element of a 3-element response frame; nil for
	// 2-element request frames or when the peer sent nil.
	Err map[string]any
}
+
+// encodeRequest encodes a request frame (2-element msgpack array: [id, body]).
+func encodeRequest(id int, body map[string]any) ([]byte, error) {
+       var buf bytes.Buffer
+       enc := msgpack.NewEncoder(&buf)
+       enc.UseCompactInts(true)
+
+       if err := enc.EncodeArrayLen(2); err != nil {
+               return nil, err
+       }
+       if err := enc.EncodeInt(int64(id)); err != nil {
+               return nil, err
+       }
+       if err := enc.Encode(body); err != nil {
+               return nil, err
+       }
+       return buf.Bytes(), nil
+}
+
+// writeFrame writes a length-prefixed msgpack payload to the writer.
+// Format: [4-byte big-endian length][payload bytes].
+//
+// The prefix and payload are concatenated into a single buffer and written
+// in one Write call so we never leave a half-framed message on the wire if
+// an io.Writer implementation does a short write between the two halves.
+func writeFrame(w io.Writer, payload []byte) error {
+       // Refuse to send a payload the peer would refuse to read. Without this
+       // guard, lengths >= 4 GiB would silently wrap in the uint32 conversion
+       // below and put a corrupt length prefix on the wire, desynchronising
+       // the peer instead of failing loudly here. Mirrors the OverflowError
+       // raised by task-sdk comms.py:_FrameMixin.as_bytes.
+       if len(payload) > MaxFrameSize {
+               return fmt.Errorf(
+                       "frame payload length %d exceeds max %d",
+                       len(payload),
+                       MaxFrameSize,
+               )
+       }
+       buf := make([]byte, 4+len(payload))
+       binary.BigEndian.PutUint32(buf[:4], uint32(len(payload)))
+       copy(buf[4:], payload)
+       n, err := w.Write(buf)
+       if err != nil {
+               return fmt.Errorf("writing frame: %w", err)
+       }
+       if n < len(buf) {
+               return fmt.Errorf("writing frame: %w", io.ErrShortWrite)
+       }
+       return nil
+}
+
+// readFrame reads one length-prefixed msgpack frame from the reader and 
decodes it.
+func readFrame(r io.Reader) (IncomingFrame, error) {
+       // Read 4-byte big-endian length prefix.
+       prefix := make([]byte, 4)
+       if _, err := io.ReadFull(r, prefix); err != nil {
+               return IncomingFrame{}, fmt.Errorf("reading length prefix: %w", 
err)
+       }
+       payloadLen := binary.BigEndian.Uint32(prefix)
+       // Reject oversized frames defensively. A non-Python sender (or a
+       // MaxFrameSize lowered for memory-budget reasons) might violate the cap
+       // the reader is willing to allocate, so fail loudly here rather than
+       // trusting the peer.
+       if payloadLen > MaxFrameSize {
+               return IncomingFrame{}, fmt.Errorf(
+                       "frame payload length %d exceeds max %d",
+                       payloadLen,
+                       MaxFrameSize,
+               )
+       }
+       payload := make([]byte, int(payloadLen))
+       if _, err := io.ReadFull(r, payload); err != nil {
+               return IncomingFrame{}, fmt.Errorf("reading payload (%d bytes): 
%w", payloadLen, err)
+       }
+
+       return decodeFrame(payload)
+}
+
+// decodeFrame decodes a msgpack payload into an IncomingFrame.
+func decodeFrame(data []byte) (IncomingFrame, error) {
+       dec := msgpack.NewDecoder(bytes.NewReader(data))
+
+       arrLen, err := dec.DecodeArrayLen()
+       if err != nil {
+               return IncomingFrame{}, fmt.Errorf("decoding array header: %w", 
err)
+       }
+       if arrLen < 2 {
+               return IncomingFrame{}, fmt.Errorf("unexpected frame arity %d, 
need at least 2", arrLen)
+       }
+
+       id64, err := dec.DecodeInt64()
+       if err != nil {
+               return IncomingFrame{}, fmt.Errorf("decoding frame id: %w", err)
+       }
+
+       // Decode the body element.
+       bodyRaw, err := dec.DecodeInterface()
+       if err != nil {
+               return IncomingFrame{}, fmt.Errorf("decoding body: %w", err)
+       }
+       body, ok := toStringMap(bodyRaw)
+       if bodyRaw != nil && !ok {
+               return IncomingFrame{}, fmt.Errorf("body element: expected map, 
got %T", bodyRaw)
+       }
+
+       // For response frames (3-element), decode the error element.
+       var errMap map[string]any
+       if arrLen >= 3 {
+               errRaw, err := dec.DecodeInterface()
+               if err != nil {
+                       return IncomingFrame{}, fmt.Errorf("decoding error 
element: %w", err)
+               }
+               errMap, ok = toStringMap(errRaw)
+               if errRaw != nil && !ok {
+                       return IncomingFrame{}, fmt.Errorf("error element: 
expected map, got %T", errRaw)
+               }
+       }
+
+       return IncomingFrame{
+               ID:   int(id64),
+               Body: body,
+               Err:  errMap,
+       }, nil
+}
+
// toStringMap normalises a decoded msgpack value to map[string]any.
//
// A map[string]any is returned as-is. A map[any]any is copied, with each key
// rendered via fmt.Sprint. A nil interface or any non-map value yields
// (nil, false).
func toStringMap(v any) (map[string]any, bool) {
	switch m := v.(type) {
	case nil:
		return nil, false
	case map[string]any:
		return m, true
	case map[any]any:
		converted := make(map[string]any, len(m))
		for key, val := range m {
			converted[fmt.Sprint(key)] = val
		}
		return converted, true
	}
	return nil, false
}
+
// mapString returns m[key] as a string. It fails when the key is absent or
// when the stored value is not a string.
func mapString(m map[string]any, key string) (string, error) {
	raw, found := m[key]
	if !found {
		return "", fmt.Errorf("missing key %q", key)
	}
	if s, isString := raw.(string); isString {
		return s, nil
	}
	return "", fmt.Errorf("key %q: expected string, got %T", key, raw)
}
+
+// mapInt extracts an int value from a map. Returns an error if the key is
+// missing or the value is not a numeric type. Use this for fields the
+// supervisor is contractually required to send (e.g. try_number); a silent
+// default would mask supervisor/runtime version-drift bugs.
+func mapInt(m map[string]any, key string) (int, error) {
+       v, ok := m[key]
+       if !ok {
+               return 0, fmt.Errorf("missing key %q", key)
+       }
+       n, err := toInt(v)
+       if err != nil {
+               return 0, fmt.Errorf("key %q: %w", key, err)
+       }
+       return n, nil
+}
+
+// mapIntOr extracts an int value from a map, returning the default when the
+// key is missing OR the value is not a numeric type. Use this only for
+// genuinely optional fields where any decoding hiccup should fall back to
+// the default; for required fields, use mapInt.
+func mapIntOr(m map[string]any, key string, def int) int {
+       v, ok := m[key]
+       if !ok {
+               return def
+       }
+       n, err := toInt(v)
+       if err != nil {
+               return def
+       }
+       return n
+}
+
// mapStringOr returns m[key] when it holds a string. Otherwise, whether the
// key is absent or holds another type, it returns def.
func mapStringOr(m map[string]any, key string, def string) string {
	// A missing key yields a nil interface, which fails the assertion below.
	if s, isString := m[key].(string); isString {
		return s
	}
	return def
}
+
+// mapMap extracts a nested map from a map.
+func mapMap(m map[string]any, key string) map[string]any {
+       v, ok := m[key]
+       if !ok || v == nil {
+               return nil
+       }
+       sub, ok := toStringMap(v)
+       if !ok {
+               return nil
+       }
+       return sub
+}
+
// toInt converts a numeric value produced by msgpack decoding into an int.
//
// Integer values are accepted only when they fit in int on the current
// platform. Floating-point values are accepted only when they are integral
// and within int's range. A malformed field (for example try_number = 1.5,
// or a uint64 above the int maximum) is therefore reported as an error
// instead of being silently truncated or wrapped. Non-numeric values,
// including nil, are rejected.
func toInt(v any) (int, error) {
	// Width and bounds of int on this platform, derived without importing
	// math so the same code is correct on 32- and 64-bit targets.
	const (
		intBits = 32 << (^uint(0) >> 63)
		maxInt  = 1<<(intBits-1) - 1
		minInt  = -1 << (intBits - 1)
	)

	fromUnsigned := func(n uint64) (int, error) {
		if n > maxInt {
			return 0, fmt.Errorf("integer %d overflows int", n)
		}
		return int(n), nil
	}

	fromFloat := func(f float64) (int, error) {
		// ±2^(intBits-1) are exactly representable as float64, whereas
		// maxInt is not, so the range check uses a half-open interval.
		const limit = float64(1 << (intBits - 1))
		if f < -limit || f >= limit {
			return 0, fmt.Errorf("number %v overflows int", f)
		}
		i := int(f)
		// Also rejects NaN, which compares unequal to everything.
		if float64(i) != f {
			return 0, fmt.Errorf("expected integral number, got %v", f)
		}
		return i, nil
	}

	switch n := v.(type) {
	case int:
		return n, nil
	case int8:
		return int(n), nil
	case int16:
		return int(n), nil
	case int32:
		return int(n), nil
	case int64:
		if n < minInt || n > maxInt {
			return 0, fmt.Errorf("integer %d overflows int", n)
		}
		return int(n), nil
	case uint:
		return fromUnsigned(uint64(n))
	case uint8:
		return int(n), nil
	case uint16:
		return int(n), nil
	case uint32:
		return fromUnsigned(uint64(n))
	case uint64:
		return fromUnsigned(n)
	case float32:
		return fromFloat(float64(n))
	case float64:
		return fromFloat(n)
	default:
		return 0, fmt.Errorf("expected numeric, got %T", v)
	}
}
diff --git a/go-sdk/pkg/execution/frames_test.go 
b/go-sdk/pkg/execution/frames_test.go
new file mode 100644
index 00000000000..18d247bf362
--- /dev/null
+++ b/go-sdk/pkg/execution/frames_test.go
@@ -0,0 +1,258 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package execution
+
+import (
+       "bytes"
+       "encoding/binary"
+       "strconv"
+       "testing"
+       "unsafe"
+
+       "github.com/stretchr/testify/assert"
+       "github.com/stretchr/testify/require"
+       "github.com/vmihailenco/msgpack/v5"
+)
+
+func TestEncodeRequest(t *testing.T) {
+       body := map[string]any{
+               "type": "GetVariable",
+               "key":  "my_var",
+       }
+
+       data, err := encodeRequest(42, body)
+       require.NoError(t, err)
+
+       // Decode and verify structure.
+       dec := msgpack.NewDecoder(bytes.NewReader(data))
+       arrLen, err := dec.DecodeArrayLen()
+       require.NoError(t, err)
+       assert.Equal(t, 2, arrLen, "request frame should be 2-element array")
+
+       id, err := dec.DecodeInt64()
+       require.NoError(t, err)
+       assert.Equal(t, int64(42), id)
+
+       var decodedBody map[string]any
+       err = dec.Decode(&decodedBody)
+       require.NoError(t, err)
+       assert.Equal(t, "GetVariable", decodedBody["type"])
+       assert.Equal(t, "my_var", decodedBody["key"])
+}
+
+func TestWriteAndReadFrame(t *testing.T) {
+       body := map[string]any{
+               "type":    "GetConnection",
+               "conn_id": "my_db",
+       }
+
+       payload, err := encodeRequest(7, body)
+       require.NoError(t, err)
+
+       // Write to buffer with length prefix.
+       var buf bytes.Buffer
+       err = writeFrame(&buf, payload)
+       require.NoError(t, err)
+
+       // Verify length prefix.
+       prefix := buf.Bytes()[:4]
+       expectedLen := uint32(len(payload))
+       assert.Equal(t, expectedLen, binary.BigEndian.Uint32(prefix))
+
+       // Read back.
+       frame, err := readFrame(&buf)
+       require.NoError(t, err)
+       assert.Equal(t, 7, frame.ID)
+       assert.Equal(t, "GetConnection", frame.Body["type"])
+       assert.Equal(t, "my_db", frame.Body["conn_id"])
+       assert.Nil(t, frame.Err)
+}
+
+func TestDecodeResponseFrame(t *testing.T) {
+       // Encode a 3-element response frame: [id, body, error]
+       var buf bytes.Buffer
+       enc := msgpack.NewEncoder(&buf)
+       enc.UseCompactInts(true)
+
+       require.NoError(t, enc.EncodeArrayLen(3))
+       require.NoError(t, enc.EncodeInt(5))
+       require.NoError(t, enc.Encode(map[string]any{
+               "type":    "ConnectionResult",
+               "conn_id": "test_conn",
+               "host":    "localhost",
+       }))
+       require.NoError(t, enc.Encode(nil)) // no error
+
+       frame, err := decodeFrame(buf.Bytes())
+       require.NoError(t, err)
+       assert.Equal(t, 5, frame.ID)
+       assert.Equal(t, "ConnectionResult", frame.Body["type"])
+       assert.Equal(t, "localhost", frame.Body["host"])
+       assert.Nil(t, frame.Err)
+}
+
+func TestDecodeResponseFrameWithError(t *testing.T) {
+       var buf bytes.Buffer
+       enc := msgpack.NewEncoder(&buf)
+       enc.UseCompactInts(true)
+
+       require.NoError(t, enc.EncodeArrayLen(3))
+       require.NoError(t, enc.EncodeInt(3))
+       require.NoError(t, enc.Encode(nil)) // nil body
+       require.NoError(t, enc.Encode(map[string]any{
+               "type":   "ErrorResponse",
+               "error":  "not_found",
+               "detail": "Variable 'x' not found",
+       }))
+
+       frame, err := decodeFrame(buf.Bytes())
+       require.NoError(t, err)
+       assert.Equal(t, 3, frame.ID)
+       assert.Nil(t, frame.Body)
+       assert.NotNil(t, frame.Err)
+       assert.Equal(t, "not_found", frame.Err["error"])
+}
+
+func TestDecodeFrameRejectsNonMapBody(t *testing.T) {
+       // A non-nil, non-map body element is a protocol violation; the decoder
+       // must surface it instead of silently turning the body into nil.
+       var buf bytes.Buffer
+       enc := msgpack.NewEncoder(&buf)
+       enc.UseCompactInts(true)
+
+       require.NoError(t, enc.EncodeArrayLen(2))
+       require.NoError(t, enc.EncodeInt(1))
+       require.NoError(t, enc.EncodeString("not a map"))
+
+       _, err := decodeFrame(buf.Bytes())
+       require.Error(t, err)
+       assert.Contains(t, err.Error(), "body element: expected map")
+}
+
+func TestDecodeFrameRejectsNonMapError(t *testing.T) {
+       // Same rule applies to the error element of a 3-tuple response frame.
+       var buf bytes.Buffer
+       enc := msgpack.NewEncoder(&buf)
+       enc.UseCompactInts(true)
+
+       require.NoError(t, enc.EncodeArrayLen(3))
+       require.NoError(t, enc.EncodeInt(2))
+       require.NoError(t, enc.Encode(nil))
+       require.NoError(t, enc.EncodeString("not a map"))
+
+       _, err := decodeFrame(buf.Bytes())
+       require.Error(t, err)
+       assert.Contains(t, err.Error(), "error element: expected map")
+}
+
+// TestWriteFrameRejectsOversizedPayload pins the guard at the top of
+// writeFrame against the rename/refactor that previously dropped its
+// coverage. The guard only inspects len(payload) before doing any allocation
+// or read of payload bytes, so we hand it a fake-length slice built with
+// unsafe.Slice (one real byte of backing storage, length > MaxFrameSize)
+// rather than allocating 4 GiB of real memory.
+//
+// The matching read-side guard at the top of readFrame is dead code with
+// MaxFrameSize pinned at the uint32 maximum (payloadLen is uint32, so
+// payloadLen > MaxFrameSize is never true) and cannot be exercised without
+// modifying production code; it remains as defense-in-depth in case
+// MaxFrameSize is ever lowered.
+func TestWriteFrameRejectsOversizedPayload(t *testing.T) {
+       if strconv.IntSize < 64 {
+               t.Skip("requires 64-bit int to construct a slice longer than 
MaxFrameSize")
+       }
+       var backing byte
+       payload := unsafe.Slice(&backing, uint64(MaxFrameSize)+1)
+
+       err := writeFrame(&bytes.Buffer{}, payload)
+       require.Error(t, err)
+       assert.Contains(t, err.Error(), "exceeds max")
+}
+
+func TestRoundTripMultipleFrames(t *testing.T) {
+       var buf bytes.Buffer
+
+       // Write two frames.
+       bodies := []map[string]any{
+               {"type": "GetVariable", "key": "v1"},
+               {"type": "GetVariable", "key": "v2"},
+       }
+       for i, body := range bodies {
+               payload, err := encodeRequest(i, body)
+               require.NoError(t, err)
+               require.NoError(t, writeFrame(&buf, payload))
+       }
+
+       // Read them back.
+       for i, expected := range bodies {
+               frame, err := readFrame(&buf)
+               require.NoError(t, err)
+               assert.Equal(t, i, frame.ID)
+               assert.Equal(t, expected["key"], frame.Body["key"])
+       }
+}
+
+func TestToStringMap(t *testing.T) {
+       tests := []struct {
+               name  string
+               input any
+               want  map[string]any
+               ok    bool
+       }{
+               {"nil", nil, nil, false},
+               {"string map", map[string]any{"a": 1}, map[string]any{"a": 1}, 
true},
+               {"any key map", map[any]any{"b": 2}, map[string]any{"b": 2}, 
true},
+               {"not a map", "hello", nil, false},
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       got, ok := toStringMap(tt.input)
+                       assert.Equal(t, tt.ok, ok)
+                       if tt.ok {
+                               assert.Equal(t, tt.want, got)
+                       }
+               })
+       }
+}
+
+func TestToInt(t *testing.T) {
+       tests := []struct {
+               input any
+               want  int
+       }{
+               {int8(42), 42},
+               {int16(42), 42},
+               {int32(42), 42},
+               {int64(42), 42},
+               {uint8(42), 42},
+               {uint16(42), 42},
+               {uint32(42), 42},
+               {uint64(42), 42},
+               {float32(42.0), 42},
+               {float64(42.0), 42},
+               {int(42), 42},
+       }
+       for _, tt := range tests {
+               got, err := toInt(tt.input)
+               require.NoError(t, err)
+               assert.Equal(t, tt.want, got)
+       }
+
+       _, err := toInt("not a number")
+       assert.Error(t, err)
+}
diff --git a/go-sdk/pkg/execution/messages.go b/go-sdk/pkg/execution/messages.go
new file mode 100644
index 00000000000..f7d9ae6cc27
--- /dev/null
+++ b/go-sdk/pkg/execution/messages.go
@@ -0,0 +1,412 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package execution
+
+import (
+       "fmt"
+       "time"
+)
+
+// Inbound messages (Supervisor -> Runtime).
+
+// TaskInstanceInfo holds task instance details from StartupDetails.
+type TaskInstanceInfo struct {
+       ID             string
+       TaskID         string
+       DagID          string
+       RunID          string
+       TryNumber      int
+       DagVersionID   string
+       MapIndex       int
+       ContextCarrier map[string]any
+}
+
+func decodeTaskInstanceInfo(m map[string]any) (TaskInstanceInfo, error) {
+       if m == nil {
+               return TaskInstanceInfo{}, fmt.Errorf("nil task instance map")
+       }
+       id, err := mapString(m, "id")
+       if err != nil {
+               return TaskInstanceInfo{}, fmt.Errorf("ti.id: %w", err)
+       }
+       taskID, err := mapString(m, "task_id")
+       if err != nil {
+               return TaskInstanceInfo{}, fmt.Errorf("ti.task_id: %w", err)
+       }
+       dagID, err := mapString(m, "dag_id")
+       if err != nil {
+               return TaskInstanceInfo{}, fmt.Errorf("ti.dag_id: %w", err)
+       }
+       runID, err := mapString(m, "run_id")
+       if err != nil {
+               return TaskInstanceInfo{}, fmt.Errorf("ti.run_id: %w", err)
+       }
+       tryNumber, err := mapInt(m, "try_number")
+       if err != nil {
+               return TaskInstanceInfo{}, fmt.Errorf("ti.try_number: %w", err)
+       }
+       dagVersionID := mapStringOr(m, "dag_version_id", "")
+       mapIndex := mapIntOr(m, "map_index", -1)
+       contextCarrier := mapMap(m, "context_carrier")
+
+       return TaskInstanceInfo{
+               ID:             id,
+               TaskID:         taskID,
+               DagID:          dagID,
+               RunID:          runID,
+               TryNumber:      tryNumber,
+               DagVersionID:   dagVersionID,
+               MapIndex:       mapIndex,
+               ContextCarrier: contextCarrier,
+       }, nil
+}
+
+// BundleInfoMsg holds bundle identification from StartupDetails.
+type BundleInfoMsg struct {
+       Name    string
+       Version string
+}
+
+func decodeBundleInfo(m map[string]any) BundleInfoMsg {
+       if m == nil {
+               return BundleInfoMsg{}
+       }
+       return BundleInfoMsg{
+               Name:    mapStringOr(m, "name", ""),
+               Version: mapStringOr(m, "version", ""),
+       }
+}
+
+// TIRunContext holds the runtime context for a task instance.
+type TIRunContext struct {
+       LogicalDate       *time.Time
+       DataIntervalStart *time.Time
+       DataIntervalEnd   *time.Time
+}
+
+func decodeTIRunContext(m map[string]any) (TIRunContext, error) {
+       if m == nil {
+               return TIRunContext{}, nil
+       }
+       ctx := TIRunContext{}
+       for _, f := range []struct {
+               key string
+               dst **time.Time
+       }{
+               {"logical_date", &ctx.LogicalDate},
+               {"data_interval_start", &ctx.DataIntervalStart},
+               {"data_interval_end", &ctx.DataIntervalEnd},
+       } {
+               raw, present := m[f.key]
+               if !present || raw == nil {
+                       continue
+               }
+               t, err := asTime(raw)
+               if err != nil {
+                       return TIRunContext{}, fmt.Errorf("ti_context.%s: %w", 
f.key, err)
+               }
+               *f.dst = &t
+       }
+       return ctx, nil
+}
+
+// StartupDetails is sent by the supervisor to initiate task execution.
+type StartupDetails struct {
+       TI                TaskInstanceInfo
+       DagRelPath        string
+       BundleInfo        BundleInfoMsg
+       StartDate         time.Time
+       TIContext         TIRunContext
+       SentryIntegration string
+}
+
+func decodeStartupDetails(m map[string]any) (*StartupDetails, error) {
+       tiMap := mapMap(m, "ti")
+       ti, err := decodeTaskInstanceInfo(tiMap)
+       if err != nil {
+               return nil, fmt.Errorf("decoding ti: %w", err)
+       }
+
+       dagRelPath := mapStringOr(m, "dag_rel_path", "")
+       bundleInfo := decodeBundleInfo(mapMap(m, "bundle_info"))
+
+       var startDate time.Time
+       if raw, present := m["start_date"]; present && raw != nil {
+               startDate, err = asTime(raw)
+               if err != nil {
+                       return nil, fmt.Errorf("start_date: %w", err)
+               }
+       }
+
+       tiContext, err := decodeTIRunContext(mapMap(m, "ti_context"))
+       if err != nil {
+               return nil, fmt.Errorf("decoding ti_context: %w", err)
+       }
+       sentryIntegration := mapStringOr(m, "sentry_integration", "")
+
+       return &StartupDetails{
+               TI:                ti,
+               DagRelPath:        dagRelPath,
+               BundleInfo:        bundleInfo,
+               StartDate:         startDate,
+               TIContext:         tiContext,
+               SentryIntegration: sentryIntegration,
+       }, nil
+}
+
+// Response types (for runtime-initiated requests).
+
+// ConnectionResult is the response to GetConnection.
+type ConnectionResult struct {
+       ConnID   string
+       ConnType string
+       Host     string
+       Schema   string
+       Login    string
+       Password string
+       Port     int
+       Extra    string
+}
+
+// decodeConnectionResult decodes a "ConnectionResult" body. Every field is
+// read through mapStringOr/mapIntOr with "" and 0 as the fallbacks. The error
+// result is currently always nil; presumably it is kept so every decoder
+// shares the same shape.
+func decodeConnectionResult(m map[string]any) (*ConnectionResult, error) {
+       str := func(key string) string { return mapStringOr(m, key, "") }
+
+       conn := &ConnectionResult{
+               ConnID:   str("conn_id"),
+               ConnType: str("conn_type"),
+               Host:     str("host"),
+               Schema:   str("schema"),
+               Login:    str("login"),
+               Password: str("password"),
+               Port:     mapIntOr(m, "port", 0),
+               Extra:    str("extra"),
+       }
+       return conn, nil
+}
+
+// VariableResult is the response to GetVariable.
+type VariableResult struct {
+       // Key is read from "key" (default "").
+       Key   string
+       // Value is the "value" entry passed through without conversion; it is
+       // nil when the key is absent.
+       Value any
+}
+
+// decodeVariableResult decodes a "VariableResult" body. The value is
+// returned untouched, so its dynamic type is whatever the wire decoder put
+// in the map. The error result is always nil.
+func decodeVariableResult(m map[string]any) (*VariableResult, error) {
+       var res VariableResult
+       res.Key = mapStringOr(m, "key", "")
+       res.Value = m["value"]
+       return &res, nil
+}
+
+// XComResult is the response to GetXCom.
+type XComResult struct {
+       // Key is read from "key" (default "").
+       Key   string
+       // Value is the "value" entry passed through without conversion; it is
+       // nil when the key is absent.
+       Value any
+}
+
+// decodeXComResult decodes an "XComResult" body, leaving the value exactly
+// as decoded (maps, lists and scalars alike). The error result is always nil.
+func decodeXComResult(m map[string]any) (*XComResult, error) {
+       key := mapStringOr(m, "key", "")
+       value := m["value"]
+       return &XComResult{Key: key, Value: value}, nil
+}
+
+// ErrorResponse represents an error returned by the supervisor.
+type ErrorResponse struct {
+       // Error is read from "error" (default "").
+       Error  string
+       // Detail is the raw "detail" entry passed through unmodified; its shape
+       // is whatever the supervisor sent.
+       Detail any
+}
+
+// decodeErrorResponse decodes an "ErrorResponse" body. A nil map yields a
+// nil *ErrorResponse. Otherwise "error" falls back to "" and "detail" is
+// passed through untouched.
+func decodeErrorResponse(m map[string]any) *ErrorResponse {
+       var resp *ErrorResponse
+       if m != nil {
+               resp = &ErrorResponse{
+                       Error:  mapStringOr(m, "error", ""),
+                       Detail: m["detail"],
+               }
+       }
+       return resp
+}
+
+// Outbound messages (Runtime -> Supervisor).
+
+// GetConnectionMsg is sent to request a connection from the supervisor.
+// The matching reply is decoded as a ConnectionResult.
+type GetConnectionMsg struct {
+       // ConnID is sent as "conn_id".
+       ConnID string
+}
+
+// toMap renders the request as the "GetConnection" wire body.
+func (m GetConnectionMsg) toMap() map[string]any {
+       body := make(map[string]any, 2)
+       body["type"] = "GetConnection"
+       body["conn_id"] = m.ConnID
+       return body
+}
+
+// GetVariableMsg is sent to request a variable from the supervisor.
+// The matching reply is decoded as a VariableResult.
+type GetVariableMsg struct {
+       // Key is sent as "key".
+       Key string
+}
+
+// toMap renders the request as the "GetVariable" wire body.
+func (m GetVariableMsg) toMap() map[string]any {
+       body := map[string]any{}
+       body["type"] = "GetVariable"
+       body["key"] = m.Key
+       return body
+}
+
+// GetXComMsg is sent to request an XCom value from the supervisor.
+// The matching reply is decoded as an XComResult. Fields are sent as key,
+// dag_id, task_id, run_id, map_index and include_prior_dates.
+type GetXComMsg struct {
+       Key               string
+       DagID             string
+       TaskID            string
+       RunID             string
+       // MapIndex is sent as "map_index" only when non-nil; nil omits the key
+       // from the payload instead of sending a sentinel.
+       MapIndex          *int
+       IncludePriorDates bool
+}
+
+// toMap renders the request as the "GetXCom" wire body.
+func (m GetXComMsg) toMap() map[string]any {
+       body := map[string]any{"type": "GetXCom"}
+       body["key"] = m.Key
+       body["dag_id"] = m.DagID
+       body["task_id"] = m.TaskID
+       body["run_id"] = m.RunID
+       body["include_prior_dates"] = m.IncludePriorDates
+       // A nil MapIndex leaves map_index out of the payload entirely.
+       if idx := m.MapIndex; idx != nil {
+               body["map_index"] = *idx
+       }
+       return body
+}
+
+// SetXComMsg is sent to set an XCom value. MapIndex mirrors Python's
+// SetXCom.map_index (int | None): nil means "unmapped task", and is omitted
+// from the wire payload rather than encoded as a -1 sentinel.
+type SetXComMsg struct {
+       Key          string
+       // Value is placed in the payload as-is; presumably it must be
+       // encodable by the msgpack wire encoder — verify at call sites.
+       Value        any
+       DagID        string
+       TaskID       string
+       RunID        string
+       MapIndex     *int
+       // MappedLength, like MapIndex, is emitted as "mapped_length" only when
+       // non-nil.
+       MappedLength *int
+}
+
+// toMap renders the request as the "SetXCom" wire body. Optional integer
+// fields are added only when set, so nil never appears on the wire.
+func (m SetXComMsg) toMap() map[string]any {
+       body := map[string]any{
+               "type":    "SetXCom",
+               "key":     m.Key,
+               "value":   m.Value,
+               "dag_id":  m.DagID,
+               "task_id": m.TaskID,
+               "run_id":  m.RunID,
+       }
+       putIfSet := func(key string, v *int) {
+               if v != nil {
+                       body[key] = *v
+               }
+       }
+       putIfSet("map_index", m.MapIndex)
+       putIfSet("mapped_length", m.MappedLength)
+       return body
+}
+
+// SucceedTaskMsg is sent as a terminal message when a task succeeds.
+type SucceedTaskMsg struct {
+       // EndDate is converted to UTC and sent as an RFC 3339 string with
+       // nanosecond precision.
+       EndDate      time.Time
+       // TaskOutlets and OutletEvents are sent as-is, except that a nil slice
+       // is replaced by an empty list in the payload.
+       TaskOutlets  []any
+       OutletEvents []any
+}
+
+// toMap renders the terminal "SucceedTask" wire body with end_date in UTC
+// RFC 3339 (nanosecond) form.
+func (m SucceedTaskMsg) toMap() map[string]any {
+       // Unset outlet slices are sent as empty lists rather than nil
+       // (presumably so the supervisor sees [] instead of a msgpack nil).
+       orEmpty := func(s []any) []any {
+               if s == nil {
+                       return []any{}
+               }
+               return s
+       }
+       return map[string]any{
+               "type":          "SucceedTask",
+               "end_date":      m.EndDate.UTC().Format(time.RFC3339Nano),
+               "task_outlets":  orEmpty(m.TaskOutlets),
+               "outlet_events": orEmpty(m.OutletEvents),
+       }
+}
+
+// TaskState is the terminal non-success state reported via TaskStateMsg.
+// The wire values match Python's TaskInstanceState enum (and the generated
+// api.TerminalStateNonSuccess); we define a local typed string so call
+// sites get compile-time checking and don't have to import pkg/api just
+// for the constants.
+type TaskState string
+
+// The supported terminal non-success states. TaskStateMsg.toMap sends State
+// verbatim without validating it, so callers should stick to these values.
+const (
+       TaskStateFailed  TaskState = "failed"
+       TaskStateRemoved TaskState = "removed"
+       TaskStateSkipped TaskState = "skipped"
+)
+
+// TaskStateMsg is sent as a terminal message for failed/removed/skipped tasks.
+type TaskStateMsg struct {
+       // State is sent verbatim as the "state" string.
+       State   TaskState
+       // EndDate is converted to UTC and sent as an RFC 3339 string with
+       // nanosecond precision.
+       EndDate time.Time
+}
+
+// toMap renders the terminal "TaskState" wire body with end_date in UTC
+// RFC 3339 (nanosecond) form.
+func (m TaskStateMsg) toMap() map[string]any {
+       endDate := m.EndDate.UTC().Format(time.RFC3339Nano)
+       return map[string]any{
+               "type":     "TaskState",
+               "state":    string(m.State),
+               "end_date": endDate,
+       }
+}
+
+// Message dispatch.
+
+// decodeIncomingBody dispatches decoding of a body map based on its "type"
+// field.
+//
+// A nil body decodes to (nil, nil). A missing, non-string or unrecognised
+// "type" is reported as an unknown message type. When a decoder fails, the
+// returned value is a literal nil interface, never a typed nil pointer
+// wrapped in `any`, so callers can safely compare the result against nil.
+func decodeIncomingBody(m map[string]any) (any, error) {
+       if m == nil {
+               return nil, nil
+       }
+       // A missing or non-string "type" yields "" and falls through to the
+       // unknown-type error below.
+       typ, _ := m["type"].(string)
+       switch typ {
+       case "StartupDetails":
+               return decodedBody(decodeStartupDetails(m))
+       case "ConnectionResult":
+               return decodedBody(decodeConnectionResult(m))
+       case "VariableResult":
+               return decodedBody(decodeVariableResult(m))
+       case "XComResult":
+               return decodedBody(decodeXComResult(m))
+       case "ErrorResponse":
+               // m is non-nil here, so decodeErrorResponse returns a non-nil pointer.
+               return decodeErrorResponse(m), nil
+       default:
+               return nil, fmt.Errorf("unknown message type %q", typ)
+       }
+}
+
+// decodedBody adapts a typed (*T, error) decoder result to (any, error). On
+// error it returns a literal nil interface, so a (*T)(nil) never leaks out as
+// a non-nil `any`.
+func decodedBody[T any](v *T, err error) (any, error) {
+       if err != nil {
+               return nil, err
+       }
+       return v, nil
+}
+
+// asTime converts a decoded wire value into a time.Time. It accepts either a
+// time.Time (as produced by the msgpack timestamp extension) or a string in
+// RFC 3339 form, with or without fractional seconds. nil and any other type
+// are errors.
+func asTime(v any) (time.Time, error) {
+       switch t := v.(type) {
+       case nil:
+               return time.Time{}, fmt.Errorf("nil time value")
+       case time.Time:
+               return t, nil
+       case string:
+               return time.Parse(time.RFC3339Nano, t)
+       }
+       return time.Time{}, fmt.Errorf("expected time, got %T", v)
+}
diff --git a/go-sdk/pkg/execution/messages_test.go b/go-sdk/pkg/execution/messages_test.go
new file mode 100644
index 00000000000..ed3951b5ba9
--- /dev/null
+++ b/go-sdk/pkg/execution/messages_test.go
@@ -0,0 +1,375 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package execution
+
+import (
+       "testing"
+       "time"
+
+       "github.com/stretchr/testify/assert"
+       "github.com/stretchr/testify/require"
+)
+
+// TestDecodeStartupDetails decodes a fully populated StartupDetails body and
+// checks that every supplied top-level and nested field lands where expected.
+func TestDecodeStartupDetails(t *testing.T) {
+       m := map[string]any{
+               "type": "StartupDetails",
+               "ti": map[string]any{
+                       "id":             "550e8400-e29b-41d4-a716-446655440000",
+                       "task_id":        "extract",
+                       "dag_id":         "tutorial_dag",
+                       "run_id":         "manual__2024-01-15",
+                       "try_number":     int64(1),
+                       "dag_version_id": "abc-123",
+                       "map_index":      int64(-1),
+               },
+               "dag_rel_path": "dags/tutorial.go",
+               "bundle_info": map[string]any{
+                       "name":    "example_dags",
+                       "version": "1.0.0",
+               },
+               "start_date":         "2024-01-15T10:30:00Z",
+               "sentry_integration": "",
+               "ti_context": map[string]any{
+                       "logical_date":        "2024-01-15T00:00:00Z",
+                       "data_interval_start": "2024-01-14T00:00:00Z",
+                       "data_interval_end":   "2024-01-15T00:00:00Z",
+               },
+       }
+
+       details, err := decodeStartupDetails(m)
+       require.NoError(t, err)
+
+       assert.Equal(t, "550e8400-e29b-41d4-a716-446655440000", details.TI.ID)
+       assert.Equal(t, "extract", details.TI.TaskID)
+       assert.Equal(t, "tutorial_dag", details.TI.DagID)
+       assert.Equal(t, "manual__2024-01-15", details.TI.RunID)
+       assert.Equal(t, 1, details.TI.TryNumber)
+       assert.Equal(t, -1, details.TI.MapIndex)
+       assert.Equal(t, "dags/tutorial.go", details.DagRelPath)
+       assert.Equal(t, "example_dags", details.BundleInfo.Name)
+       assert.Equal(t, "1.0.0", details.BundleInfo.Version)
+       assert.True(t, details.StartDate.Equal(time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)),
+               "start_date decoded as %s", details.StartDate)
+       assert.Empty(t, details.SentryIntegration)
+       assert.NotNil(t, details.TIContext.LogicalDate)
+       assert.NotNil(t, details.TIContext.DataIntervalStart)
+       assert.NotNil(t, details.TIContext.DataIntervalEnd)
+}
+
+func TestDecodeStartupDetails_MalformedStartDate(t *testing.T) {
+       // A start_date that is present but unparsable has to fail the decode:
+       // quietly keeping the zero time would run the task with the wrong
+       // provenance.
+       body := map[string]any{
+               "type": "StartupDetails",
+               "ti": map[string]any{
+                       "id":         "550e8400-e29b-41d4-a716-446655440000",
+                       "task_id":    "t",
+                       "dag_id":     "d",
+                       "run_id":     "r",
+                       "try_number": int64(1),
+               },
+               "start_date": "not-a-timestamp",
+       }
+
+       _, err := decodeStartupDetails(body)
+
+       require.Error(t, err)
+       assert.Contains(t, err.Error(), "start_date")
+}
+
+func TestDecodeStartupDetails_MalformedTIRunContext(t *testing.T) {
+       // An unparsable timestamp inside ti_context must fail the whole decode,
+       // and the error must name the offending field.
+       ti := map[string]any{
+               "id":         "550e8400-e29b-41d4-a716-446655440000",
+               "task_id":    "t",
+               "dag_id":     "d",
+               "run_id":     "r",
+               "try_number": int64(1),
+       }
+       body := map[string]any{
+               "type":       "StartupDetails",
+               "ti":         ti,
+               "ti_context": map[string]any{"logical_date": "garbage"},
+       }
+
+       _, err := decodeStartupDetails(body)
+
+       require.Error(t, err)
+       assert.Contains(t, err.Error(), "logical_date")
+}
+
+func TestDecodeStartupDetails_RequiresTryNumber(t *testing.T) {
+       // try_number is required in Python's TaskInstance model. A missing or
+       // wrongly typed value has to be a decode error, not a silent default
+       // that would hide supervisor/runtime version drift.
+       cases := []struct {
+               name       string
+               includeKey bool
+               tryNumber  any
+       }{
+               {name: "missing"},
+               {name: "wrong type", includeKey: true, tryNumber: "1"},
+       }
+       for _, tc := range cases {
+               t.Run(tc.name, func(t *testing.T) {
+                       ti := map[string]any{
+                               "id":      "550e8400-e29b-41d4-a716-446655440000",
+                               "task_id": "t",
+                               "dag_id":  "d",
+                               "run_id":  "r",
+                       }
+                       if tc.includeKey {
+                               ti["try_number"] = tc.tryNumber
+                       }
+
+                       _, err := decodeStartupDetails(map[string]any{
+                               "type": "StartupDetails",
+                               "ti":   ti,
+                       })
+
+                       require.Error(t, err)
+                       assert.Contains(t, err.Error(), "try_number")
+               })
+       }
+}
+
+func TestDecodeStartupDetails_MissingOptionalTimestamps(t *testing.T) {
+       // Neither start_date nor ti_context is required: leaving both out must
+       // decode without error and yield zero/nil timestamps.
+       body := map[string]any{
+               "type": "StartupDetails",
+               "ti": map[string]any{
+                       "id":         "550e8400-e29b-41d4-a716-446655440000",
+                       "task_id":    "t",
+                       "dag_id":     "d",
+                       "run_id":     "r",
+                       "try_number": int64(1),
+               },
+       }
+
+       details, err := decodeStartupDetails(body)
+       require.NoError(t, err)
+
+       assert.True(t, details.StartDate.IsZero())
+       runCtx := details.TIContext
+       assert.Nil(t, runCtx.LogicalDate)
+       assert.Nil(t, runCtx.DataIntervalStart)
+       assert.Nil(t, runCtx.DataIntervalEnd)
+}
+
+// TestDecodeConnectionResult checks that every ConnectionResult field,
+// including the raw extra string, is copied from its wire key.
+func TestDecodeConnectionResult(t *testing.T) {
+       m := map[string]any{
+               "type":      "ConnectionResult",
+               "conn_id":   "my_db",
+               "conn_type": "postgres",
+               "host":      "db.example.com",
+               "schema":    "mydb",
+               "login":     "user",
+               "password":  "secret",
+               "port":      int64(5432),
+               "extra":     `{"sslmode":"require"}`,
+       }
+
+       result, err := decodeConnectionResult(m)
+       require.NoError(t, err)
+       assert.Equal(t, "my_db", result.ConnID)
+       assert.Equal(t, "postgres", result.ConnType)
+       assert.Equal(t, "db.example.com", result.Host)
+       assert.Equal(t, "mydb", result.Schema)
+       assert.Equal(t, "user", result.Login)
+       assert.Equal(t, "secret", result.Password)
+       assert.Equal(t, 5432, result.Port)
+       // Extra is passed through as the raw string.
+       assert.Equal(t, `{"sslmode":"require"}`, result.Extra)
+}
+
+func TestDecodeVariableResult(t *testing.T) {
+       got, err := decodeVariableResult(map[string]any{
+               "type":  "VariableResult",
+               "key":   "my_var",
+               "value": "hello",
+       })
+       require.NoError(t, err)
+
+       assert.Equal(t, "my_var", got.Key)
+       assert.Equal(t, "hello", got.Value)
+}
+
+func TestDecodeXComResult(t *testing.T) {
+       // Structured values must come back untouched, not flattened or
+       // stringified.
+       got, err := decodeXComResult(map[string]any{
+               "type":  "XComResult",
+               "key":   "return_value",
+               "value": map[string]any{"data": "processed"},
+       })
+       require.NoError(t, err)
+       assert.Equal(t, "return_value", got.Key)
+
+       payload, isMap := got.Value.(map[string]any)
+       require.True(t, isMap, "value should be map[string]any, got %T", got.Value)
+       assert.Equal(t, "processed", payload["data"])
+}
+
+// TestDecodeErrorResponseNil checks that a nil body yields a nil response
+// rather than an empty ErrorResponse.
+func TestDecodeErrorResponseNil(t *testing.T) {
+       assert.Nil(t, decodeErrorResponse(nil))
+}
+
+// TestDecodeErrorResponse checks that "error" is copied into Error and that
+// "detail" is passed through unmodified.
+func TestDecodeErrorResponse(t *testing.T) {
+       detail := map[string]any{"conn_id": "missing"}
+       resp := decodeErrorResponse(map[string]any{
+               "type":   "ErrorResponse",
+               "error":  "CONNECTION_NOT_FOUND",
+               "detail": detail,
+       })
+       require.NotNil(t, resp)
+       assert.Equal(t, "CONNECTION_NOT_FOUND", resp.Error)
+       assert.Equal(t, detail, resp.Detail)
+}
+
+func TestGetConnectionMsgToMap(t *testing.T) {
+       got := GetConnectionMsg{ConnID: "my_db"}.toMap()
+
+       assert.Equal(t, "GetConnection", got["type"])
+       assert.Equal(t, "my_db", got["conn_id"])
+}
+
+func TestGetVariableMsgToMap(t *testing.T) {
+       got := GetVariableMsg{Key: "my_var"}.toMap()
+
+       assert.Equal(t, "GetVariable", got["type"])
+       assert.Equal(t, "my_var", got["key"])
+}
+
+// TestGetXComMsgToMapWithMapIndex checks that every field, including a set
+// map_index, is rendered under its wire key.
+func TestGetXComMsgToMapWithMapIndex(t *testing.T) {
+       mapIdx := 3
+       msg := GetXComMsg{
+               Key:               "result",
+               DagID:             "dag1",
+               TaskID:            "task1",
+               RunID:             "run1",
+               MapIndex:          &mapIdx,
+               IncludePriorDates: true,
+       }
+       m := msg.toMap()
+       assert.Equal(t, "GetXCom", m["type"])
+       assert.Equal(t, "result", m["key"])
+       assert.Equal(t, "dag1", m["dag_id"])
+       assert.Equal(t, "task1", m["task_id"])
+       assert.Equal(t, "run1", m["run_id"])
+       assert.Equal(t, 3, m["map_index"])
+       assert.Equal(t, true, m["include_prior_dates"])
+}
+
+func TestGetXComMsgToMapNilMapIndex(t *testing.T) {
+       // Leaving MapIndex nil must drop map_index from the payload entirely.
+       got := GetXComMsg{Key: "result", DagID: "d", TaskID: "t", RunID: "r"}.toMap()
+       assert.NotContains(t, got, "map_index")
+}
+
+// TestSetXComMsgToMap covers the required fields and the omit-when-nil
+// behaviour of both optional integers (map_index and mapped_length).
+func TestSetXComMsgToMap(t *testing.T) {
+       t.Run("nil map_index is omitted", func(t *testing.T) {
+               // Unmapped tasks must omit map_index entirely, matching Python's
+               // SetXCom.map_index = None semantics; a -1 sentinel would conflate
+               // "unmapped" with "explicit index -1".
+               msg := SetXComMsg{
+                       Key: "output", Value: 42,
+                       DagID: "dag1", TaskID: "task1", RunID: "run1",
+               }
+               m := msg.toMap()
+               assert.Equal(t, "SetXCom", m["type"])
+               assert.Equal(t, "output", m["key"])
+               assert.Equal(t, 42, m["value"])
+               assert.Equal(t, "dag1", m["dag_id"])
+               assert.Equal(t, "task1", m["task_id"])
+               assert.Equal(t, "run1", m["run_id"])
+               _, hasMapIndex := m["map_index"]
+               assert.False(t, hasMapIndex)
+               _, hasMappedLength := m["mapped_length"]
+               assert.False(t, hasMappedLength)
+       })
+
+       t.Run("non-nil map_index is emitted", func(t *testing.T) {
+               idx := 3
+               msg := SetXComMsg{
+                       Key: "output", Value: 42,
+                       DagID: "dag1", TaskID: "task1", RunID: "run1", MapIndex: &idx,
+               }
+               m := msg.toMap()
+               assert.Equal(t, 3, m["map_index"])
+       })
+
+       t.Run("non-nil mapped_length is emitted", func(t *testing.T) {
+               length := 5
+               msg := SetXComMsg{
+                       Key: "output", Value: 42,
+                       DagID: "dag1", TaskID: "task1", RunID: "run1", MappedLength: &length,
+               }
+               m := msg.toMap()
+               assert.Equal(t, 5, m["mapped_length"])
+               _, hasMapIndex := m["map_index"]
+               assert.False(t, hasMapIndex)
+       })
+}
+
+func TestSucceedTaskMsgToMap(t *testing.T) {
+       // Unset outlet slices must be sent as empty lists, not nil.
+       got := SucceedTaskMsg{EndDate: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)}.toMap()
+
+       assert.Equal(t, "SucceedTask", got["type"])
+       assert.Equal(t, "2024-01-15T10:30:00Z", got["end_date"])
+       for _, key := range []string{"task_outlets", "outlet_events"} {
+               assert.Equal(t, []any{}, got[key], key)
+       }
+}
+
+func TestSucceedTaskMsgToMap_PreservesSubsecondPrecision(t *testing.T) {
+       // end_date uses RFC3339Nano, so sub-second precision survives a round
+       // trip through asTime; truncating to whole seconds would lose the
+       // ordering of terminal events that happen close together.
+       got := SucceedTaskMsg{
+               EndDate: time.Date(2024, 1, 15, 10, 30, 0, 123456789, time.UTC),
+       }.toMap()
+       assert.Equal(t, "2024-01-15T10:30:00.123456789Z", got["end_date"])
+}
+
+// TestTaskStateMsgToMap checks the type/state keys and that end_date is
+// rendered in UTC regardless of the input time's zone.
+func TestTaskStateMsgToMap(t *testing.T) {
+       endDate := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+       msg := TaskStateMsg{State: TaskStateFailed, EndDate: endDate}
+       m := msg.toMap()
+       assert.Equal(t, "TaskState", m["type"])
+       assert.Equal(t, "failed", m["state"])
+       assert.Equal(t, "2024-01-15T10:30:00Z", m["end_date"])
+
+       // The same instant expressed in UTC+2 must still be sent as UTC.
+       plusTwo := time.FixedZone("UTC+2", 2*60*60)
+       shifted := TaskStateMsg{State: TaskStateFailed, EndDate: endDate.In(plusTwo)}.toMap()
+       assert.Equal(t, "2024-01-15T10:30:00Z", shifted["end_date"])
+}
+
+func TestTaskStateMsgToMap_PreservesSubsecondPrecision(t *testing.T) {
+       got := TaskStateMsg{
+               State:   TaskStateFailed,
+               EndDate: time.Date(2024, 1, 15, 10, 30, 0, 123456789, time.UTC),
+       }.toMap()
+       assert.Equal(t, "2024-01-15T10:30:00.123456789Z", got["end_date"])
+}
+
+func TestTaskStateConstants_WireValues(t *testing.T) {
+       // Each constant is pinned to the exact string Python's TaskInstanceState
+       // uses on the wire. Constant names may change freely; changing a wire
+       // value would silently break the protocol.
+       expected := []struct {
+               state TaskState
+               wire  string
+       }{
+               {TaskStateFailed, "failed"},
+               {TaskStateRemoved, "removed"},
+               {TaskStateSkipped, "skipped"},
+       }
+       for _, tc := range expected {
+               assert.Equal(t, tc.wire, string(tc.state))
+               body := TaskStateMsg{State: tc.state}.toMap()
+               assert.Equal(t, tc.wire, body["state"], "wire value for %s", tc.state)
+       }
+}
+
+// TestDecodeIncomingBodyDispatch checks that each known "type" routes to its
+// decoder, and that nil, missing-type and unknown-type bodies behave as
+// documented.
+func TestDecodeIncomingBodyDispatch(t *testing.T) {
+       t.Run("ConnectionResult", func(t *testing.T) {
+               body := map[string]any{"type": "ConnectionResult", "conn_id": "x"}
+               result, err := decodeIncomingBody(body)
+               require.NoError(t, err)
+               assert.IsType(t, &ConnectionResult{}, result)
+       })
+
+       t.Run("VariableResult", func(t *testing.T) {
+               body := map[string]any{"type": "VariableResult", "key": "k", "value": "v"}
+               result, err := decodeIncomingBody(body)
+               require.NoError(t, err)
+               assert.IsType(t, &VariableResult{}, result)
+       })
+
+       t.Run("XComResult", func(t *testing.T) {
+               body := map[string]any{"type": "XComResult", "key": "k", "value": int64(1)}
+               result, err := decodeIncomingBody(body)
+               require.NoError(t, err)
+               assert.IsType(t, &XComResult{}, result)
+       })
+
+       t.Run("ErrorResponse", func(t *testing.T) {
+               body := map[string]any{"type": "ErrorResponse", "error": "boom"}
+               result, err := decodeIncomingBody(body)
+               require.NoError(t, err)
+               assert.IsType(t, &ErrorResponse{}, result)
+       })
+
+       t.Run("nil", func(t *testing.T) {
+               result, err := decodeIncomingBody(nil)
+               require.NoError(t, err)
+               assert.Nil(t, result)
+       })
+
+       t.Run("missing type", func(t *testing.T) {
+               _, err := decodeIncomingBody(map[string]any{"key": "k"})
+               require.Error(t, err)
+               assert.Contains(t, err.Error(), "unknown message type")
+       })
+
+       t.Run("unknown type", func(t *testing.T) {
+               _, err := decodeIncomingBody(map[string]any{"type": "UnknownMsg"})
+               require.Error(t, err)
+               assert.Contains(t, err.Error(), "UnknownMsg")
+       })
+}
+
+// TestAsTime covers every accepted input form of asTime plus its rejection
+// paths (nil, wrong type, malformed string).
+func TestAsTime(t *testing.T) {
+       t.Run("from string", func(t *testing.T) {
+               ts, err := asTime("2024-01-15T10:30:00Z")
+               require.NoError(t, err)
+               assert.Equal(t, 2024, ts.Year())
+               assert.Equal(t, time.January, ts.Month())
+       })
+
+       t.Run("fractional seconds with offset", func(t *testing.T) {
+               ts, err := asTime("2024-01-15T10:30:00.5+02:00")
+               require.NoError(t, err)
+               want := time.Date(2024, 1, 15, 8, 30, 0, 500000000, time.UTC)
+               assert.True(t, ts.Equal(want), "got %s, want %s", ts, want)
+       })
+
+       t.Run("from time.Time", func(t *testing.T) {
+               now := time.Now()
+               ts, err := asTime(now)
+               require.NoError(t, err)
+               assert.Equal(t, now, ts)
+       })
+
+       t.Run("nil", func(t *testing.T) {
+               _, err := asTime(nil)
+               assert.Error(t, err)
+       })
+
+       t.Run("wrong type", func(t *testing.T) {
+               _, err := asTime(42)
+               assert.Error(t, err)
+       })
+
+       t.Run("malformed string", func(t *testing.T) {
+               _, err := asTime("not-a-timestamp")
+               assert.Error(t, err)
+       })
+}
diff --git a/go-sdk/pkg/sdkcontext/keys.go b/go-sdk/pkg/sdkcontext/keys.go
index 0dbc2c60194..ad83dfce3bd 100644
--- a/go-sdk/pkg/sdkcontext/keys.go
+++ b/go-sdk/pkg/sdkcontext/keys.go
@@ -22,6 +22,7 @@ type (
        apiClientContextKey struct{}
        workerContextKey    struct{}
        runtimeTIContextKey struct{}
+       sdkClientContextKey struct{}
 )
 
 var (
@@ -32,4 +33,11 @@ var (
        RuntimeTIContextKey = runtimeTIContextKey{}
        ApiClientContextKey = apiClientContextKey{}
        WorkerContextKey    = workerContextKey{}
+
+       // SdkClientContextKey, when present, holds an sdk.Client implementation
+       // that should be injected into task functions instead of constructing a
+       // default HTTP-backed client. The coordinator-mode runtime uses this to
+       // route task SDK calls (GetVariable, GetConnection, ...) over the
+       // supervisor comm socket rather than to the Execution API.
+       SdkClientContextKey = sdkClientContextKey{}
 )
diff --git a/go-sdk/sdk/client.go b/go-sdk/sdk/client.go
index d2aa1f40351..f719fdc4787 100644
--- a/go-sdk/sdk/client.go
+++ b/go-sdk/sdk/client.go
@@ -57,6 +57,12 @@ func (*client) GetVariable(ctx context.Context, key string) (string, error) {
                }
                return "", err
        }
+       // TODO: register secret-named variables with a SecretsMasker so the
+       // returned value is automatically redacted from subsequent task logs,
+       // matching Python's airflow.models.variable.Variable.get behaviour.
+       // Pairs with the "TODO: mask secrets here" hook in
+       // pkg/worker/runner.go's task log handler — both halves are needed
+       // before secret masking actually works end-to-end.
        return *resp.Value, nil
 }
 
@@ -73,6 +79,13 @@ func (c *client) UnmarshalJSONVariable(ctx context.Context, key string, pointer
 func (*client) GetConnection(ctx context.Context, connID string) (Connection, 
error) {
        // TODO: Lookup connection from env var (and handle JSON + URI forms)
 
+       // TODO: register Connection.Password and sensitive-keyed entries of
+       // Connection.Extra with a SecretsMasker so they are auto-redacted from
+       // subsequent task logs, matching Python's
+       // airflow.models.connection.Connection.get behaviour. Pairs with the
+       // "TODO: mask secrets here" hook in pkg/worker/runner.go's task log
+       // handler and the matching TODO on GetVariable above.
+
        httpClient := 
ctx.Value(sdkcontext.ApiClientContextKey).(api.ClientInterface)
 
        resp, err := httpClient.Connections().Get(ctx, connID)
@@ -84,7 +97,7 @@ func (*client) GetConnection(ctx context.Context, connID string) (Connection, er
                return Connection{}, err
        }
 
-       return connFromAPIResponse(resp)
+       return ConnFromAPIResponse(resp)
 }
 
 func (c *client) PushXCom(
diff --git a/go-sdk/sdk/connection.go b/go-sdk/sdk/connection.go
index 35835c2d523..1d0bcf30b6f 100644
--- a/go-sdk/sdk/connection.go
+++ b/go-sdk/sdk/connection.go
@@ -110,7 +110,11 @@ func (c Connection) GetURI() *url.URL {
        return uri
 }
 
-func connFromAPIResponse(resp *api.ConnectionResponse) (Connection, error) {
+// ConnFromAPIResponse converts an Execution-API ConnectionResponse into the
+// SDK's Connection type. It is exported so other internal SDK packages (for
+// example, the coordinator-mode runtime in bundlev1server/impl/coord) can
+// reuse the same conversion.
+func ConnFromAPIResponse(resp *api.ConnectionResponse) (Connection, error) {
        var err error
        conn := Connection{
                ID:       resp.ConnId,
diff --git a/go-sdk/sdk/sdk.go b/go-sdk/sdk/sdk.go
index 679ab22d8f4..e4f5a61a331 100644
--- a/go-sdk/sdk/sdk.go
+++ b/go-sdk/sdk/sdk.go
@@ -28,6 +28,15 @@ const (
        ConnectionEnvPrefix = "AIRFLOW_CONN_"
 )
 
+// VariableClient reads Airflow Variables.
+//
+// Go has no function overloading, so the "give me the raw string" and
+// "give me a decoded struct" cases are split into two methods rather
+// than one polymorphic call: GetVariable returns the raw string,
+// UnmarshalJSONVariable decodes a JSON-encoded variable into a
+// caller-supplied pointer. This mirrors the std-lib split between
+// os.LookupEnv and json.Unmarshal — each method has one job, and the
+// caller picks based on how the variable was stored.
 type VariableClient interface {
        // GetVariable returns the value of an Airflow Variable.
        //
@@ -43,6 +52,13 @@ type VariableClient interface {
        //                              // Other errors here, such as http 
network timeouts etc.
        //              }
        GetVariable(ctx context.Context, key string) (string, error)
+
+       // UnmarshalJSONVariable fetches a variable and unmarshals its value 
into
+       // pointer via json.Unmarshal. Use this when the variable was stored as 
a
+       // JSON object, array, or number; for plain string variables call
+       // GetVariable directly.
+       //
+       // pointer must be a non-nil pointer, as required by encoding/json.
        UnmarshalJSONVariable(ctx context.Context, key string, pointer any) 
error
 }
 

Reply via email to