This is an automated email from the ASF dual-hosted git repository.
jason810496 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 129d03ed7a9 Go-SDK: Add coordinator-mode protocol primitives and SDK
surface hooks (#67315)
129d03ed7a9 is described below
commit 129d03ed7a94f55d69f3040c5111dda2980b9ae1
Author: Jason(Zhe-You) Liu <[email protected]>
AuthorDate: Wed May 27 13:05:04 2026 +0800
Go-SDK: Add coordinator-mode protocol primitives and SDK surface hooks
(#67315)
* go-sdk: Add coordinator-mode protocol primitives and SDK surface hooks
First step toward landing the Go SDK coordinator-mode runtime
(ADR 0003, msgpack-over-IPC). Scaffolding only -- no entry point is wired
here, so go-plugin / Edge Worker behaviour is unchanged.
Adds the length-prefixed msgpack frame codec and the typed message
envelopes the runtime will exchange with the supervisor, the
sdkcontext.SdkClientContextKey injection hook on bundlev1.Task so a
follow-up PR can swap in a comm-socket-backed sdk.Client, and the small
sdk surface tweaks (ConnFromAPIResponse export, VariableClient interface
docs, secret-masking TODOs) the comm-socket client will rely on. Pulls
in github.com/vmihailenco/msgpack/v5 -- the encoding the supervisor
speaks.
* Fix: Update max frame size constant and enhance payload validation
* Fix: Tighten coordinator-protocol message decoding
- Add strict mapInt for required fields; require try_number on
StartupDetails so supervisor/runtime version-drift surfaces as a
decode error instead of a silent default.
- Use RFC3339Nano for SucceedTask and TaskState end_date so that the
sub-second precision asTime can parse also round-trips on the wire.
- Change SetXComMsg.MapIndex to *int and omit when nil, matching
Python's SetXCom.map_index = None semantics and GetXComMsg.MapIndex.
* Add regression test for writeFrame oversize-payload guard
Pins the guard at the top of writeFrame so a future refactor cannot
silently drop the uint32-overflow protection. Uses unsafe.Slice to build
a fake-length payload without allocating multi-GiB buffers, since the
guard only reads len(payload) before any allocation or byte access.
The matching read-side guard in readFrame is dead code with MaxFrameSize
pinned at the uint32 maximum and cannot be exercised without modifying
production code; documented inline.
---
go-sdk/bundle/bundlev1/task.go | 7 +-
go-sdk/go.mod | 2 +
go-sdk/go.sum | 4 +
go-sdk/pkg/execution/frames.go | 286 +++++++++++++++++++++++
go-sdk/pkg/execution/frames_test.go | 258 +++++++++++++++++++++
go-sdk/pkg/execution/messages.go | 412 ++++++++++++++++++++++++++++++++++
go-sdk/pkg/execution/messages_test.go | 375 +++++++++++++++++++++++++++++++
go-sdk/pkg/sdkcontext/keys.go | 8 +
go-sdk/sdk/client.go | 15 +-
go-sdk/sdk/connection.go | 6 +-
go-sdk/sdk/sdk.go | 16 ++
11 files changed, 1386 insertions(+), 3 deletions(-)
diff --git a/go-sdk/bundle/bundlev1/task.go b/go-sdk/bundle/bundlev1/task.go
index 4271f4892bb..5277f40681c 100644
--- a/go-sdk/bundle/bundlev1/task.go
+++ b/go-sdk/bundle/bundlev1/task.go
@@ -45,7 +45,12 @@ func NewTaskFunction(fn any) (Task, error) {
func (f *taskFunction) Execute(ctx context.Context, logger *slog.Logger) error
{
fnType := f.fn.Type()
- sdkClient := sdk.NewClient()
+ var sdkClient sdk.Client
+ if injected, ok :=
ctx.Value(sdkcontext.SdkClientContextKey).(sdk.Client); ok {
+ sdkClient = injected
+ } else {
+ sdkClient = sdk.NewClient()
+ }
reflectArgs := make([]reflect.Value, fnType.NumIn())
for i := range reflectArgs {
diff --git a/go-sdk/go.mod b/go-sdk/go.mod
index bfb400eee94..f3bfcd4b0f6 100644
--- a/go-sdk/go.mod
+++ b/go-sdk/go.mod
@@ -16,6 +16,7 @@ require (
github.com/spf13/pflag v1.0.10
github.com/spf13/viper v1.20.1
github.com/stretchr/testify v1.11.1
+ github.com/vmihailenco/msgpack/v5 v5.4.1
google.golang.org/grpc v1.79.3
google.golang.org/protobuf v1.36.10
resty.dev/v3 v3.0.0-beta.2
@@ -38,6 +39,7 @@ require (
github.com/spf13/afero v1.12.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
+ github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
go.opentelemetry.io/otel v1.41.0 // indirect
go.opentelemetry.io/otel/trace v1.41.0 // indirect
go.uber.org/multierr v1.10.0 // indirect
diff --git a/go-sdk/go.sum b/go-sdk/go.sum
index 5b7940672b1..a275d6b63c8 100644
--- a/go-sdk/go.sum
+++ b/go-sdk/go.sum
@@ -114,6 +114,10 @@ github.com/stretchr/testify v1.11.1
h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
github.com/stretchr/testify v1.11.1/go.mod
h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0
h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod
h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
+github.com/vmihailenco/msgpack/v5 v5.4.1
h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
+github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod
h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
+github.com/vmihailenco/tagparser/v2 v2.0.0
h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
+github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod
h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
go.opentelemetry.io/auto/sdk v1.2.1
h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod
h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/otel v1.41.0
h1:YlEwVsGAlCvczDILpUXpIpPSL/VPugt7zHThEMLce1c=
diff --git a/go-sdk/pkg/execution/frames.go b/go-sdk/pkg/execution/frames.go
new file mode 100644
index 00000000000..1d8f4671013
--- /dev/null
+++ b/go-sdk/pkg/execution/frames.go
@@ -0,0 +1,286 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package execution
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "io"
+
+ "github.com/vmihailenco/msgpack/v5"
+)
+
+// MaxFrameSize is the maximum payload length of a single frame, in bytes.
+// The 4-byte length prefix bounds this to 2^32 - 1; matches Python's cap in
+// task-sdk comms.py:_FrameMixin.as_bytes (n >= 2**32 raises OverflowError).
+//
+// It is an untyped constant: in a comparison it takes the type of the other
+// operand, and its value (4294967295) does not fit in a 32-bit int.
+const MaxFrameSize = 1<<32 - 1
+
+// IncomingFrame represents a decoded frame received from the comm socket.
+type IncomingFrame struct {
+	// ID is the frame id, the first element of the msgpack array.
+	ID int
+	// Body is the second array element; nil when the sender encoded nil.
+	Body map[string]any
+	// Err is the third array element of a response frame (3-element
+	// array); nil for 2-element frames and when the sender encoded nil.
+	Err map[string]any
+}
+
+// encodeRequest serialises a request as the 2-element msgpack array
+// [id, body] using compact integer encoding. The first encoder error aborts
+// the encoding and is returned unchanged.
+func encodeRequest(id int, body map[string]any) ([]byte, error) {
+	var out bytes.Buffer
+	encoder := msgpack.NewEncoder(&out)
+	encoder.UseCompactInts(true)
+
+	// Each step runs only if every earlier step succeeded.
+	err := encoder.EncodeArrayLen(2)
+	if err == nil {
+		err = encoder.EncodeInt(int64(id))
+	}
+	if err == nil {
+		err = encoder.Encode(body)
+	}
+	if err != nil {
+		return nil, err
+	}
+	return out.Bytes(), nil
+}
+
+// writeFrame writes a length-prefixed msgpack payload to the writer.
+// Format: [4-byte big-endian length][payload bytes].
+//
+// The prefix and payload are concatenated into a single buffer and written
+// in one Write call so we never leave a half-framed message on the wire if
+// an io.Writer implementation does a short write between the two halves.
+//
+// It returns an error when the payload is longer than MaxFrameSize, when
+// the underlying Write fails, or when Write reports fewer bytes than the
+// frame length without an error (io.ErrShortWrite).
+func writeFrame(w io.Writer, payload []byte) error {
+	// Refuse to send a payload the peer would refuse to read. Without this
+	// guard, lengths >= 4 GiB would silently wrap in the uint32 conversion
+	// below and put a corrupt length prefix on the wire, desynchronising
+	// the peer instead of failing loudly here. Mirrors the OverflowError
+	// raised by task-sdk comms.py:_FrameMixin.as_bytes.
+	//
+	// The comparison and the formatted value are done in uint64: converting
+	// the untyped MaxFrameSize constant to int overflows on 32-bit targets
+	// and would stop the package from compiling there.
+	if uint64(len(payload)) > MaxFrameSize {
+		return fmt.Errorf(
+			"frame payload length %d exceeds max %d",
+			len(payload),
+			uint64(MaxFrameSize),
+		)
+	}
+	buf := make([]byte, 4+len(payload))
+	binary.BigEndian.PutUint32(buf[:4], uint32(len(payload)))
+	copy(buf[4:], payload)
+	n, err := w.Write(buf)
+	if err != nil {
+		return fmt.Errorf("writing frame: %w", err)
+	}
+	// Defensive: a well-behaved io.Writer returns an error on a short write.
+	if n < len(buf) {
+		return fmt.Errorf("writing frame: %w", io.ErrShortWrite)
+	}
+	return nil
+}
+
+// readFrame reads one length-prefixed msgpack frame from the reader and
+// decodes it.
+//
+// It returns an error when the 4-byte prefix or the payload cannot be read
+// in full, when the announced length exceeds MaxFrameSize or cannot be
+// represented as an int on this platform, or when decodeFrame rejects the
+// payload.
+func readFrame(r io.Reader) (IncomingFrame, error) {
+	// Read 4-byte big-endian length prefix.
+	var prefix [4]byte
+	if _, err := io.ReadFull(r, prefix[:]); err != nil {
+		return IncomingFrame{}, fmt.Errorf("reading length prefix: %w", err)
+	}
+	payloadLen := binary.BigEndian.Uint32(prefix[:])
+	// Reject oversized frames defensively. A non-Python sender (or a
+	// MaxFrameSize lowered for memory-budget reasons) might violate the cap
+	// the reader is willing to allocate, so fail loudly here rather than
+	// trusting the peer. MaxFrameSize is formatted as uint64 because the
+	// untyped constant does not fit in a 32-bit int.
+	if payloadLen > MaxFrameSize {
+		return IncomingFrame{}, fmt.Errorf(
+			"frame payload length %d exceeds max %d",
+			payloadLen,
+			uint64(MaxFrameSize),
+		)
+	}
+	// Where int is 32 bits wide, lengths above 2^31 - 1 wrap to a negative
+	// value and make would panic; report an error instead.
+	size := int(payloadLen)
+	if size < 0 {
+		return IncomingFrame{}, fmt.Errorf(
+			"frame payload length %d is not addressable on this platform",
+			payloadLen,
+		)
+	}
+	payload := make([]byte, size)
+	if _, err := io.ReadFull(r, payload); err != nil {
+		return IncomingFrame{}, fmt.Errorf("reading payload (%d bytes): %w", payloadLen, err)
+	}
+
+	return decodeFrame(payload)
+}
+
+// decodeFrame parses a msgpack payload shaped as [id, body] or
+// [id, body, error] into an IncomingFrame. body and error must each be a
+// map or nil; any other value is reported as an error. Array elements past
+// the third are not read.
+func decodeFrame(data []byte) (IncomingFrame, error) {
+	var none IncomingFrame
+	decoder := msgpack.NewDecoder(bytes.NewReader(data))
+
+	arity, err := decoder.DecodeArrayLen()
+	if err != nil {
+		return none, fmt.Errorf("decoding array header: %w", err)
+	}
+	if arity < 2 {
+		return none, fmt.Errorf("unexpected frame arity %d, need at least 2", arity)
+	}
+
+	rawID, err := decoder.DecodeInt64()
+	if err != nil {
+		return none, fmt.Errorf("decoding frame id: %w", err)
+	}
+
+	// Second element: the body.
+	rawBody, err := decoder.DecodeInterface()
+	if err != nil {
+		return none, fmt.Errorf("decoding body: %w", err)
+	}
+	body, isMap := toStringMap(rawBody)
+	if !isMap && rawBody != nil {
+		return none, fmt.Errorf("body element: expected map, got %T", rawBody)
+	}
+
+	frame := IncomingFrame{ID: int(rawID), Body: body}
+	if arity < 3 {
+		// Request frame: there is no error element to read.
+		return frame, nil
+	}
+
+	// Third element (response frames): the error.
+	rawErr, err := decoder.DecodeInterface()
+	if err != nil {
+		return none, fmt.Errorf("decoding error element: %w", err)
+	}
+	errBody, isMap := toStringMap(rawErr)
+	if !isMap && rawErr != nil {
+		return none, fmt.Errorf("error element: expected map, got %T", rawErr)
+	}
+	frame.Err = errBody
+	return frame, nil
+}
+
+// toStringMap converts a decoded interface{} to map[string]any.
+// Returns nil, false if the input is nil or not a map.
+func toStringMap(v any) (map[string]any, bool) {
+ if v == nil {
+ return nil, false
+ }
+ switch m := v.(type) {
+ case map[string]any:
+ return m, true
+ case map[any]any:
+ result := make(map[string]any, len(m))
+ for k, val := range m {
+ result[fmt.Sprint(k)] = val
+ }
+ return result, true
+ default:
+ return nil, false
+ }
+}
+
+// mapString returns the string stored under key. It returns an error when
+// the key is absent or the stored value is not a string.
+func mapString(m map[string]any, key string) (string, error) {
+	raw, present := m[key]
+	if !present {
+		return "", fmt.Errorf("missing key %q", key)
+	}
+	if str, isString := raw.(string); isString {
+		return str, nil
+	}
+	return "", fmt.Errorf("key %q: expected string, got %T", key, raw)
+}
+
+// mapInt returns the integer stored under key, converted with toInt. It
+// returns an error when the key is absent or toInt rejects the value. Use
+// it for fields the supervisor is contractually required to send (e.g.
+// try_number); a silent default would mask supervisor/runtime
+// version-drift bugs.
+func mapInt(m map[string]any, key string) (int, error) {
+	if raw, present := m[key]; present {
+		n, err := toInt(raw)
+		if err != nil {
+			return 0, fmt.Errorf("key %q: %w", key, err)
+		}
+		return n, nil
+	}
+	return 0, fmt.Errorf("missing key %q", key)
+}
+
+// mapIntOr returns the integer stored under key, or def when the key is
+// absent OR toInt rejects the stored value. Use it only for genuinely
+// optional fields where any decoding hiccup should fall back to the
+// default; for required fields, use mapInt.
+func mapIntOr(m map[string]any, key string, def int) int {
+	if raw, present := m[key]; present {
+		if n, err := toInt(raw); err == nil {
+			return n
+		}
+	}
+	return def
+}
+
+// mapStringOr returns the string stored under key, or def when the key is
+// absent or the stored value is not a string.
+func mapStringOr(m map[string]any, key string, def string) string {
+	// Indexing a missing key yields a nil interface, which fails the type
+	// assertion exactly like a non-string value does.
+	if str, isString := m[key].(string); isString {
+		return str
+	}
+	return def
+}
+
+// mapMap returns the nested map stored under key, normalised with
+// toStringMap. It returns nil when the key is absent, the value is nil, or
+// the value is not a map.
+func mapMap(m map[string]any, key string) map[string]any {
+	// A missing key indexes to nil, which toStringMap reports as "not a map".
+	nested, isMap := toStringMap(m[key])
+	if !isMap {
+		return nil
+	}
+	return nested
+}
+
+// toInt converts the numeric types produced by msgpack decoding to int.
+//
+// Signed and unsigned integers are accepted when the value fits in int.
+// Floats are accepted only when they hold an exact integer inside the int
+// range; NaN, infinities, fractional values and out-of-range values are
+// rejected rather than silently truncated or wrapped. Any other type is
+// an error.
+func toInt(v any) (int, error) {
+	switch n := v.(type) {
+	case int:
+		return n, nil
+	case int8:
+		return int(n), nil
+	case int16:
+		return int(n), nil
+	case int32:
+		return int(n), nil
+	case int64:
+		// Only lossy where int is 32 bits wide.
+		if int64(int(n)) != n {
+			return 0, fmt.Errorf("value %d overflows int", n)
+		}
+		return int(n), nil
+	case uint:
+		return uint64ToInt(uint64(n))
+	case uint8:
+		return int(n), nil
+	case uint16:
+		return int(n), nil
+	case uint32:
+		return uint64ToInt(uint64(n))
+	case uint64:
+		return uint64ToInt(n)
+	case float32:
+		return float64ToInt(float64(n))
+	case float64:
+		return float64ToInt(n)
+	default:
+		return 0, fmt.Errorf("expected numeric, got %T", v)
+	}
+}
+
+// uint64ToInt converts u to int, failing when u exceeds the largest int.
+func uint64ToInt(u uint64) (int, error) {
+	const maxInt = int(^uint(0) >> 1)
+	if u > uint64(maxInt) {
+		return 0, fmt.Errorf("value %d overflows int", u)
+	}
+	return int(u), nil
+}
+
+// float64ToInt converts f to int, failing unless f is an exact integer
+// inside the int range.
+func float64ToInt(f float64) (int, error) {
+	const (
+		maxInt = int(^uint(0) >> 1)
+		minInt = -maxInt - 1
+	)
+	// minInt is a power of two, so both bounds are exact in float64; the
+	// upper bound is exclusive. f != f is true only for NaN, and the range
+	// test also rejects +Inf and -Inf.
+	if f != f || f < float64(minInt) || f >= -float64(minInt) {
+		return 0, fmt.Errorf("value %v is outside the int range", f)
+	}
+	i := int(f)
+	if float64(i) != f {
+		return 0, fmt.Errorf("expected integer, got non-integral %v", f)
+	}
+	return i, nil
+}
diff --git a/go-sdk/pkg/execution/frames_test.go
b/go-sdk/pkg/execution/frames_test.go
new file mode 100644
index 00000000000..18d247bf362
--- /dev/null
+++ b/go-sdk/pkg/execution/frames_test.go
@@ -0,0 +1,258 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package execution
+
+import (
+ "bytes"
+ "encoding/binary"
+ "strconv"
+ "testing"
+ "unsafe"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/vmihailenco/msgpack/v5"
+)
+
+// TestEncodeRequest checks that encodeRequest emits a 2-element msgpack
+// array holding the id followed by the body map.
+func TestEncodeRequest(t *testing.T) {
+	body := map[string]any{
+		"type": "GetVariable",
+		"key":  "my_var",
+	}
+
+	data, err := encodeRequest(42, body)
+	require.NoError(t, err)
+
+	// Decode and verify structure.
+	dec := msgpack.NewDecoder(bytes.NewReader(data))
+	arrLen, err := dec.DecodeArrayLen()
+	require.NoError(t, err)
+	assert.Equal(t, 2, arrLen, "request frame should be 2-element array")
+
+	id, err := dec.DecodeInt64()
+	require.NoError(t, err)
+	assert.Equal(t, int64(42), id)
+
+	var decodedBody map[string]any
+	err = dec.Decode(&decodedBody)
+	require.NoError(t, err)
+	assert.Equal(t, "GetVariable", decodedBody["type"])
+	assert.Equal(t, "my_var", decodedBody["key"])
+}
+
+// TestWriteAndReadFrame round-trips one request through writeFrame and
+// readFrame, also checking the big-endian length prefix written to the
+// buffer.
+func TestWriteAndReadFrame(t *testing.T) {
+	body := map[string]any{
+		"type":    "GetConnection",
+		"conn_id": "my_db",
+	}
+
+	payload, err := encodeRequest(7, body)
+	require.NoError(t, err)
+
+	// Write to buffer with length prefix.
+	var buf bytes.Buffer
+	err = writeFrame(&buf, payload)
+	require.NoError(t, err)
+
+	// Verify length prefix.
+	prefix := buf.Bytes()[:4]
+	expectedLen := uint32(len(payload))
+	assert.Equal(t, expectedLen, binary.BigEndian.Uint32(prefix))
+
+	// Read back.
+	frame, err := readFrame(&buf)
+	require.NoError(t, err)
+	assert.Equal(t, 7, frame.ID)
+	assert.Equal(t, "GetConnection", frame.Body["type"])
+	assert.Equal(t, "my_db", frame.Body["conn_id"])
+	// A 2-element request frame carries no error element.
+	assert.Nil(t, frame.Err)
+}
+
+// TestDecodeResponseFrame decodes a 3-element response frame whose error
+// element is nil and expects Body to be populated and Err to stay nil.
+func TestDecodeResponseFrame(t *testing.T) {
+	// Encode a 3-element response frame: [id, body, error]
+	var buf bytes.Buffer
+	enc := msgpack.NewEncoder(&buf)
+	enc.UseCompactInts(true)
+
+	require.NoError(t, enc.EncodeArrayLen(3))
+	require.NoError(t, enc.EncodeInt(5))
+	require.NoError(t, enc.Encode(map[string]any{
+		"type":    "ConnectionResult",
+		"conn_id": "test_conn",
+		"host":    "localhost",
+	}))
+	require.NoError(t, enc.Encode(nil)) // no error
+
+	frame, err := decodeFrame(buf.Bytes())
+	require.NoError(t, err)
+	assert.Equal(t, 5, frame.ID)
+	assert.Equal(t, "ConnectionResult", frame.Body["type"])
+	assert.Equal(t, "localhost", frame.Body["host"])
+	assert.Nil(t, frame.Err)
+}
+
+// TestDecodeResponseFrameWithError decodes a 3-element response frame with
+// a nil body and a populated error map, and expects the map in frame.Err.
+func TestDecodeResponseFrameWithError(t *testing.T) {
+	var buf bytes.Buffer
+	enc := msgpack.NewEncoder(&buf)
+	enc.UseCompactInts(true)
+
+	require.NoError(t, enc.EncodeArrayLen(3))
+	require.NoError(t, enc.EncodeInt(3))
+	require.NoError(t, enc.Encode(nil)) // nil body
+	require.NoError(t, enc.Encode(map[string]any{
+		"type":   "ErrorResponse",
+		"error":  "not_found",
+		"detail": "Variable 'x' not found",
+	}))
+
+	frame, err := decodeFrame(buf.Bytes())
+	require.NoError(t, err)
+	assert.Equal(t, 3, frame.ID)
+	assert.Nil(t, frame.Body)
+	assert.NotNil(t, frame.Err)
+	assert.Equal(t, "not_found", frame.Err["error"])
+}
+
+// TestDecodeFrameRejectsNonMapBody feeds decodeFrame a 2-element frame
+// whose body is a string and expects a "body element: expected map" error.
+func TestDecodeFrameRejectsNonMapBody(t *testing.T) {
+	// A non-nil, non-map body element is a protocol violation; the decoder
+	// must surface it instead of silently turning the body into nil.
+	var buf bytes.Buffer
+	enc := msgpack.NewEncoder(&buf)
+	enc.UseCompactInts(true)
+
+	require.NoError(t, enc.EncodeArrayLen(2))
+	require.NoError(t, enc.EncodeInt(1))
+	require.NoError(t, enc.EncodeString("not a map"))
+
+	_, err := decodeFrame(buf.Bytes())
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "body element: expected map")
+}
+
+// TestDecodeFrameRejectsNonMapError feeds decodeFrame a 3-element frame
+// whose error element is a string and expects an
+// "error element: expected map" error.
+func TestDecodeFrameRejectsNonMapError(t *testing.T) {
+	// Same rule applies to the error element of a 3-tuple response frame.
+	var buf bytes.Buffer
+	enc := msgpack.NewEncoder(&buf)
+	enc.UseCompactInts(true)
+
+	require.NoError(t, enc.EncodeArrayLen(3))
+	require.NoError(t, enc.EncodeInt(2))
+	require.NoError(t, enc.Encode(nil))
+	require.NoError(t, enc.EncodeString("not a map"))
+
+	_, err := decodeFrame(buf.Bytes())
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "error element: expected map")
+}
+
+// TestWriteFrameRejectsOversizedPayload pins the guard at the top of
+// writeFrame against the rename/refactor that previously dropped its
+// coverage. The guard only inspects len(payload) before doing any allocation
+// or read of payload bytes, so we hand it a fake-length slice built with
+// unsafe.Slice (one real byte of backing storage, length > MaxFrameSize)
+// rather than allocating 4 GiB of real memory.
+//
+// The matching read-side guard at the top of readFrame is dead code with
+// MaxFrameSize pinned at the uint32 maximum (payloadLen is uint32, so
+// payloadLen > MaxFrameSize is never true) and cannot be exercised without
+// modifying production code; it remains as defense-in-depth in case
+// MaxFrameSize is ever lowered.
+//
+// NOTE(review): the slice deliberately claims far more memory than its
+// one-byte backing store. The runtime's checkptr instrumentation (enabled
+// by -race) validates unsafe.Slice results, so this test may abort the test
+// binary when run with -race -- confirm against the CI flags.
+func TestWriteFrameRejectsOversizedPayload(t *testing.T) {
+	// On 32-bit platforms no slice can be longer than MaxFrameSize.
+	if strconv.IntSize < 64 {
+		t.Skip("requires 64-bit int to construct a slice longer than MaxFrameSize")
+	}
+	var backing byte
+	payload := unsafe.Slice(&backing, uint64(MaxFrameSize)+1)
+
+	err := writeFrame(&bytes.Buffer{}, payload)
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "exceeds max")
+}
+
+// TestRoundTripMultipleFrames writes two frames back to back into one
+// buffer and checks that readFrame returns them in order, i.e. that each
+// read consumes exactly one frame.
+func TestRoundTripMultipleFrames(t *testing.T) {
+	var buf bytes.Buffer
+
+	// Write two frames.
+	bodies := []map[string]any{
+		{"type": "GetVariable", "key": "v1"},
+		{"type": "GetVariable", "key": "v2"},
+	}
+	for i, body := range bodies {
+		// The slice index doubles as the frame id.
+		payload, err := encodeRequest(i, body)
+		require.NoError(t, err)
+		require.NoError(t, writeFrame(&buf, payload))
+	}
+
+	// Read them back.
+	for i, expected := range bodies {
+		frame, err := readFrame(&buf)
+		require.NoError(t, err)
+		assert.Equal(t, i, frame.ID)
+		assert.Equal(t, expected["key"], frame.Body["key"])
+	}
+}
+
+// TestToStringMap is a table test for toStringMap covering nil, a
+// string-keyed map, an any-keyed map and a non-map value.
+func TestToStringMap(t *testing.T) {
+	tests := []struct {
+		name  string
+		input any
+		want  map[string]any
+		ok    bool
+	}{
+		{"nil", nil, nil, false},
+		{"string map", map[string]any{"a": 1}, map[string]any{"a": 1}, true},
+		{"any key map", map[any]any{"b": 2}, map[string]any{"b": 2}, true},
+		{"not a map", "hello", nil, false},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, ok := toStringMap(tt.input)
+			assert.Equal(t, tt.ok, ok)
+			// The returned map is only compared for successful conversions.
+			if tt.ok {
+				assert.Equal(t, tt.want, got)
+			}
+		})
+	}
+}
+
+// TestToInt checks that toInt converts the value 42 from every supported
+// integer and float type, and that a string input is rejected.
+func TestToInt(t *testing.T) {
+	tests := []struct {
+		input any
+		want  int
+	}{
+		{int8(42), 42},
+		{int16(42), 42},
+		{int32(42), 42},
+		{int64(42), 42},
+		{uint8(42), 42},
+		{uint16(42), 42},
+		{uint32(42), 42},
+		{uint64(42), 42},
+		{float32(42.0), 42},
+		{float64(42.0), 42},
+		{int(42), 42},
+	}
+	for _, tt := range tests {
+		got, err := toInt(tt.input)
+		require.NoError(t, err)
+		assert.Equal(t, tt.want, got)
+	}
+
+	// Non-numeric input must produce an error.
+	_, err := toInt("not a number")
+	assert.Error(t, err)
+}
diff --git a/go-sdk/pkg/execution/messages.go b/go-sdk/pkg/execution/messages.go
new file mode 100644
index 00000000000..f7d9ae6cc27
--- /dev/null
+++ b/go-sdk/pkg/execution/messages.go
@@ -0,0 +1,412 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package execution
+
+import (
+ "fmt"
+ "time"
+)
+
+// Inbound messages (Supervisor -> Runtime).
+
+// TaskInstanceInfo holds task instance details from StartupDetails.
+// Field comments name the wire key each field is decoded from in
+// decodeTaskInstanceInfo.
+type TaskInstanceInfo struct {
+	ID             string         // "id" (required)
+	TaskID         string         // "task_id" (required)
+	DagID          string         // "dag_id" (required)
+	RunID          string         // "run_id" (required)
+	TryNumber      int            // "try_number" (required)
+	DagVersionID   string         // "dag_version_id"; "" when absent
+	MapIndex       int            // "map_index"; -1 when absent or not numeric
+	ContextCarrier map[string]any // "context_carrier"; nil when absent or not a map
+}
+
+func decodeTaskInstanceInfo(m map[string]any) (TaskInstanceInfo, error) {
+ if m == nil {
+ return TaskInstanceInfo{}, fmt.Errorf("nil task instance map")
+ }
+ id, err := mapString(m, "id")
+ if err != nil {
+ return TaskInstanceInfo{}, fmt.Errorf("ti.id: %w", err)
+ }
+ taskID, err := mapString(m, "task_id")
+ if err != nil {
+ return TaskInstanceInfo{}, fmt.Errorf("ti.task_id: %w", err)
+ }
+ dagID, err := mapString(m, "dag_id")
+ if err != nil {
+ return TaskInstanceInfo{}, fmt.Errorf("ti.dag_id: %w", err)
+ }
+ runID, err := mapString(m, "run_id")
+ if err != nil {
+ return TaskInstanceInfo{}, fmt.Errorf("ti.run_id: %w", err)
+ }
+ tryNumber, err := mapInt(m, "try_number")
+ if err != nil {
+ return TaskInstanceInfo{}, fmt.Errorf("ti.try_number: %w", err)
+ }
+ dagVersionID := mapStringOr(m, "dag_version_id", "")
+ mapIndex := mapIntOr(m, "map_index", -1)
+ contextCarrier := mapMap(m, "context_carrier")
+
+ return TaskInstanceInfo{
+ ID: id,
+ TaskID: taskID,
+ DagID: dagID,
+ RunID: runID,
+ TryNumber: tryNumber,
+ DagVersionID: dagVersionID,
+ MapIndex: mapIndex,
+ ContextCarrier: contextCarrier,
+ }, nil
+}
+
+// BundleInfoMsg holds bundle identification from StartupDetails.
+type BundleInfoMsg struct {
+	Name    string
+	Version string
+}
+
+// decodeBundleInfo reads "name" and "version" from m. A nil map, a missing
+// key or a non-string value leaves the corresponding field empty.
+func decodeBundleInfo(m map[string]any) BundleInfoMsg {
+	var info BundleInfoMsg
+	if m != nil {
+		info.Name = mapStringOr(m, "name", "")
+		info.Version = mapStringOr(m, "version", "")
+	}
+	return info
+}
+
+// TIRunContext holds the runtime context for a task instance. Each field
+// is nil when the corresponding key was absent or nil on the wire (see
+// decodeTIRunContext).
+type TIRunContext struct {
+	LogicalDate       *time.Time // "logical_date"
+	DataIntervalStart *time.Time // "data_interval_start"
+	DataIntervalEnd   *time.Time // "data_interval_end"
+}
+
+func decodeTIRunContext(m map[string]any) (TIRunContext, error) {
+ if m == nil {
+ return TIRunContext{}, nil
+ }
+ ctx := TIRunContext{}
+ for _, f := range []struct {
+ key string
+ dst **time.Time
+ }{
+ {"logical_date", &ctx.LogicalDate},
+ {"data_interval_start", &ctx.DataIntervalStart},
+ {"data_interval_end", &ctx.DataIntervalEnd},
+ } {
+ raw, present := m[f.key]
+ if !present || raw == nil {
+ continue
+ }
+ t, err := asTime(raw)
+ if err != nil {
+ return TIRunContext{}, fmt.Errorf("ti_context.%s: %w",
f.key, err)
+ }
+ *f.dst = &t
+ }
+ return ctx, nil
+}
+
+// StartupDetails is sent by the supervisor to initiate task execution.
+// Field comments name the wire key each field is decoded from in
+// decodeStartupDetails.
+type StartupDetails struct {
+	TI                TaskInstanceInfo // "ti" (required)
+	DagRelPath        string           // "dag_rel_path"; "" when absent
+	BundleInfo        BundleInfoMsg    // "bundle_info"
+	StartDate         time.Time        // "start_date"; zero time when absent or nil
+	TIContext         TIRunContext     // "ti_context"
+	SentryIntegration string           // "sentry_integration"; "" when absent
+}
+
+func decodeStartupDetails(m map[string]any) (*StartupDetails, error) {
+ tiMap := mapMap(m, "ti")
+ ti, err := decodeTaskInstanceInfo(tiMap)
+ if err != nil {
+ return nil, fmt.Errorf("decoding ti: %w", err)
+ }
+
+ dagRelPath := mapStringOr(m, "dag_rel_path", "")
+ bundleInfo := decodeBundleInfo(mapMap(m, "bundle_info"))
+
+ var startDate time.Time
+ if raw, present := m["start_date"]; present && raw != nil {
+ startDate, err = asTime(raw)
+ if err != nil {
+ return nil, fmt.Errorf("start_date: %w", err)
+ }
+ }
+
+ tiContext, err := decodeTIRunContext(mapMap(m, "ti_context"))
+ if err != nil {
+ return nil, fmt.Errorf("decoding ti_context: %w", err)
+ }
+ sentryIntegration := mapStringOr(m, "sentry_integration", "")
+
+ return &StartupDetails{
+ TI: ti,
+ DagRelPath: dagRelPath,
+ BundleInfo: bundleInfo,
+ StartDate: startDate,
+ TIContext: tiContext,
+ SentryIntegration: sentryIntegration,
+ }, nil
+}
+
+// Response types (for runtime-initiated requests).
+
+// ConnectionResult is the response to GetConnection.
+type ConnectionResult struct {
+	ConnID   string
+	ConnType string
+	Host     string
+	Schema   string
+	Login    string
+	Password string
+	Port     int
+	Extra    string
+}
+
+// decodeConnectionResult copies the connection fields out of m. Every
+// field is optional: missing or wrongly typed strings become "" and a
+// missing or non-numeric port becomes 0. The returned error is always nil.
+func decodeConnectionResult(m map[string]any) (*ConnectionResult, error) {
+	text := func(key string) string { return mapStringOr(m, key, "") }
+
+	conn := new(ConnectionResult)
+	conn.ConnID = text("conn_id")
+	conn.ConnType = text("conn_type")
+	conn.Host = text("host")
+	conn.Schema = text("schema")
+	conn.Login = text("login")
+	conn.Password = text("password")
+	conn.Port = mapIntOr(m, "port", 0)
+	conn.Extra = text("extra")
+	return conn, nil
+}
+
+// VariableResult is the response to GetVariable.
+type VariableResult struct {
+	Key   string
+	Value any
+}
+
+// decodeVariableResult reads "key" ("" when absent or not a string) and
+// passes "value" through undecoded. The returned error is always nil.
+func decodeVariableResult(m map[string]any) (*VariableResult, error) {
+	result := new(VariableResult)
+	result.Key = mapStringOr(m, "key", "")
+	result.Value = m["value"]
+	return result, nil
+}
+
+// XComResult is the response to GetXCom.
+type XComResult struct {
+	Key   string
+	Value any
+}
+
+// decodeXComResult reads "key" ("" when absent or not a string) and passes
+// "value" through undecoded. The returned error is always nil.
+func decodeXComResult(m map[string]any) (*XComResult, error) {
+	result := new(XComResult)
+	result.Key = mapStringOr(m, "key", "")
+	result.Value = m["value"]
+	return result, nil
+}
+
+// ErrorResponse represents an error returned by the supervisor.
+type ErrorResponse struct {
+	Error  string
+	Detail any
+}
+
+// decodeErrorResponse converts an error map into an ErrorResponse. It
+// returns nil for a nil map; otherwise "error" defaults to "" and "detail"
+// is passed through undecoded.
+func decodeErrorResponse(m map[string]any) *ErrorResponse {
+	if m == nil {
+		return nil
+	}
+	resp := new(ErrorResponse)
+	resp.Error = mapStringOr(m, "error", "")
+	resp.Detail = m["detail"]
+	return resp
+}
+
+// Outbound messages (Runtime -> Supervisor).
+
+// GetConnectionMsg is sent to request a connection from the supervisor.
+type GetConnectionMsg struct {
+	ConnID string
+}
+
+// toMap renders the message as its wire body: "type" plus "conn_id".
+func (m GetConnectionMsg) toMap() map[string]any {
+	body := make(map[string]any, 2)
+	body["type"] = "GetConnection"
+	body["conn_id"] = m.ConnID
+	return body
+}
+
+// GetVariableMsg is sent to request a variable from the supervisor.
+type GetVariableMsg struct {
+	Key string
+}
+
+// toMap renders the message as its wire body: "type" plus "key".
+func (m GetVariableMsg) toMap() map[string]any {
+	body := make(map[string]any, 2)
+	body["type"] = "GetVariable"
+	body["key"] = m.Key
+	return body
+}
+
+// GetXComMsg is sent to request an XCom value from the supervisor.
+// A nil MapIndex is omitted from the wire body.
+type GetXComMsg struct {
+	Key               string
+	DagID             string
+	TaskID            string
+	RunID             string
+	MapIndex          *int
+	IncludePriorDates bool
+}
+
+// toMap renders the message as its wire body. "map_index" is included only
+// when MapIndex is non-nil.
+func (m GetXComMsg) toMap() map[string]any {
+	body := make(map[string]any, 7)
+	body["type"] = "GetXCom"
+	body["key"] = m.Key
+	body["dag_id"] = m.DagID
+	body["task_id"] = m.TaskID
+	body["run_id"] = m.RunID
+	body["include_prior_dates"] = m.IncludePriorDates
+	if idx := m.MapIndex; idx != nil {
+		body["map_index"] = *idx
+	}
+	return body
+}
+
+// SetXComMsg is sent to set an XCom value. MapIndex mirrors Python's
+// SetXCom.map_index (int | None): nil means "unmapped task", and is omitted
+// from the wire payload rather than encoded as a -1 sentinel.
+type SetXComMsg struct {
+	Key          string
+	Value        any
+	DagID        string
+	TaskID       string
+	RunID        string
+	MapIndex     *int
+	MappedLength *int
+}
+
+// toMap renders the message as its wire body. "map_index" and
+// "mapped_length" are included only when the matching pointer is non-nil.
+func (m SetXComMsg) toMap() map[string]any {
+	body := make(map[string]any, 8)
+	body["type"] = "SetXCom"
+	body["key"] = m.Key
+	body["value"] = m.Value
+	body["dag_id"] = m.DagID
+	body["task_id"] = m.TaskID
+	body["run_id"] = m.RunID
+
+	optional := []struct {
+		name string
+		ptr  *int
+	}{
+		{"map_index", m.MapIndex},
+		{"mapped_length", m.MappedLength},
+	}
+	for _, field := range optional {
+		if field.ptr != nil {
+			body[field.name] = *field.ptr
+		}
+	}
+	return body
+}
+
+// SucceedTaskMsg is sent as a terminal message when a task succeeds.
+type SucceedTaskMsg struct {
+	EndDate      time.Time
+	TaskOutlets  []any
+	OutletEvents []any
+}
+
+// toMap renders the message as its wire body. EndDate is converted to UTC
+// and formatted as RFC3339Nano; nil outlet slices are replaced by empty
+// non-nil slices.
+func (m SucceedTaskMsg) toMap() map[string]any {
+	// orEmpty swaps a nil slice for an empty, non-nil one.
+	orEmpty := func(items []any) []any {
+		if items == nil {
+			return []any{}
+		}
+		return items
+	}
+
+	body := make(map[string]any, 4)
+	body["type"] = "SucceedTask"
+	body["end_date"] = m.EndDate.UTC().Format(time.RFC3339Nano)
+	body["task_outlets"] = orEmpty(m.TaskOutlets)
+	body["outlet_events"] = orEmpty(m.OutletEvents)
+	return body
+}
+
+// TaskState is the terminal non-success state reported via TaskStateMsg.
+// The wire values match Python's TaskInstanceState enum (and the generated
+// api.TerminalStateNonSuccess); we define a local typed string so call
+// sites get compile-time checking and don't have to import pkg/api just
+// for the constants.
+type TaskState string
+
+// Terminal non-success states; each constant's value is the string sent in
+// the "state" field of TaskStateMsg.
+const (
+	TaskStateFailed  TaskState = "failed"
+	TaskStateRemoved TaskState = "removed"
+	TaskStateSkipped TaskState = "skipped"
+)
+
+// TaskStateMsg is sent as a terminal message for failed/removed/skipped tasks.
+type TaskStateMsg struct {
+	State   TaskState
+	EndDate time.Time
+}
+
+// toMap renders the message as its wire body. EndDate is converted to UTC
+// and formatted as RFC3339Nano.
+func (m TaskStateMsg) toMap() map[string]any {
+	body := make(map[string]any, 3)
+	body["type"] = "TaskState"
+	body["state"] = string(m.State)
+	body["end_date"] = m.EndDate.UTC().Format(time.RFC3339Nano)
+	return body
+}
+
+// Message dispatch.
+
+// decodeIncomingBody picks the decoder for a body map from its "type"
+// field and returns the decoded message. A nil map decodes to (nil, nil).
+// A missing, non-string or unrecognised type is reported as an
+// "unknown message type" error.
+func decodeIncomingBody(m map[string]any) (any, error) {
+	if m == nil {
+		return nil, nil
+	}
+	// A missing or non-string "type" leaves kind empty, which falls through
+	// to the default branch.
+	switch kind, _ := m["type"].(string); kind {
+	case "StartupDetails":
+		return decodeStartupDetails(m)
+	case "ConnectionResult":
+		return decodeConnectionResult(m)
+	case "VariableResult":
+		return decodeVariableResult(m)
+	case "XComResult":
+		return decodeXComResult(m)
+	case "ErrorResponse":
+		return decodeErrorResponse(m), nil
+	default:
+		return nil, fmt.Errorf("unknown message type %q", kind)
+	}
+}
+
+// asTime converts a decoded value to time.Time. A time.Time (from the
+// msgpack timestamp ext) is returned unchanged; a string is parsed with
+// the RFC3339Nano layout. nil and every other type yield an error.
+func asTime(v any) (time.Time, error) {
+	switch value := v.(type) {
+	case nil:
+		return time.Time{}, fmt.Errorf("nil time value")
+	case time.Time:
+		return value, nil
+	case string:
+		return time.Parse(time.RFC3339Nano, value)
+	}
+	return time.Time{}, fmt.Errorf("expected time, got %T", v)
+}
diff --git a/go-sdk/pkg/execution/messages_test.go
b/go-sdk/pkg/execution/messages_test.go
new file mode 100644
index 00000000000..ed3951b5ba9
--- /dev/null
+++ b/go-sdk/pkg/execution/messages_test.go
@@ -0,0 +1,375 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package execution
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// TestDecodeStartupDetails decodes a fully populated StartupDetails body and
+// checks that the task-instance, bundle and run-context fields are carried
+// over to the decoded value.
+func TestDecodeStartupDetails(t *testing.T) {
+	m := map[string]any{
+		"type": "StartupDetails",
+		"ti": map[string]any{
+			"id":             "550e8400-e29b-41d4-a716-446655440000",
+			"task_id":        "extract",
+			"dag_id":         "tutorial_dag",
+			"run_id":         "manual__2024-01-15",
+			"try_number":     int64(1),
+			"dag_version_id": "abc-123",
+			"map_index":      int64(-1),
+		},
+		"dag_rel_path": "dags/tutorial.go",
+		"bundle_info": map[string]any{
+			"name":    "example_dags",
+			"version": "1.0.0",
+		},
+		"start_date":         "2024-01-15T10:30:00Z",
+		"sentry_integration": "",
+		"ti_context": map[string]any{
+			"logical_date":        "2024-01-15T00:00:00Z",
+			"data_interval_start": "2024-01-14T00:00:00Z",
+			"data_interval_end":   "2024-01-15T00:00:00Z",
+		},
+	}
+
+	details, err := decodeStartupDetails(m)
+	require.NoError(t, err)
+
+	assert.Equal(t, "550e8400-e29b-41d4-a716-446655440000", details.TI.ID)
+	assert.Equal(t, "extract", details.TI.TaskID)
+	assert.Equal(t, "tutorial_dag", details.TI.DagID)
+	assert.Equal(t, "manual__2024-01-15", details.TI.RunID)
+	assert.Equal(t, 1, details.TI.TryNumber)
+	assert.Equal(t, -1, details.TI.MapIndex)
+	assert.Equal(t, "dags/tutorial.go", details.DagRelPath)
+	assert.Equal(t, "example_dags", details.BundleInfo.Name)
+	assert.Equal(t, "1.0.0", details.BundleInfo.Version)
+	assert.NotNil(t, details.TIContext.LogicalDate)
+}
+
+// TestDecodeStartupDetails_MalformedStartDate checks that a start_date that
+// is present but unparseable is rejected with an error naming the field.
+// Leaving the start date at the zero time instead would let tasks run with
+// incorrect provenance.
+func TestDecodeStartupDetails_MalformedStartDate(t *testing.T) {
+	ti := map[string]any{
+		"id":         "550e8400-e29b-41d4-a716-446655440000",
+		"task_id":    "t",
+		"dag_id":     "d",
+		"run_id":     "r",
+		"try_number": int64(1),
+	}
+	body := map[string]any{
+		"type":       "StartupDetails",
+		"ti":         ti,
+		"start_date": "not-a-timestamp",
+	}
+
+	_, decodeErr := decodeStartupDetails(body)
+	require.Error(t, decodeErr)
+	assert.Contains(t, decodeErr.Error(), "start_date")
+}
+
+// TestDecodeStartupDetails_MalformedTIRunContext checks that an unparseable
+// logical_date inside ti_context is rejected with an error naming the field.
+func TestDecodeStartupDetails_MalformedTIRunContext(t *testing.T) {
+	ti := map[string]any{
+		"id":         "550e8400-e29b-41d4-a716-446655440000",
+		"task_id":    "t",
+		"dag_id":     "d",
+		"run_id":     "r",
+		"try_number": int64(1),
+	}
+	runCtx := map[string]any{
+		"logical_date": "garbage",
+	}
+	body := map[string]any{
+		"type":       "StartupDetails",
+		"ti":         ti,
+		"ti_context": runCtx,
+	}
+
+	_, decodeErr := decodeStartupDetails(body)
+	require.Error(t, decodeErr)
+	assert.Contains(t, decodeErr.Error(), "logical_date")
+}
+
+// TestDecodeStartupDetails_RequiresTryNumber pins try_number as a required
+// field. try_number is required in Python's TaskInstance model; a missing or
+// wrong-typed value must surface as a decode error rather than silently
+// defaulting and masking supervisor/runtime version-drift bugs.
+func TestDecodeStartupDetails_RequiresTryNumber(t *testing.T) {
+	t.Run("missing", func(t *testing.T) {
+		m := map[string]any{
+			"type": "StartupDetails",
+			"ti": map[string]any{
+				"id":      "550e8400-e29b-41d4-a716-446655440000",
+				"task_id": "t",
+				"dag_id":  "d",
+				"run_id":  "r",
+			},
+		}
+		_, err := decodeStartupDetails(m)
+		require.Error(t, err)
+		assert.Contains(t, err.Error(), "try_number")
+	})
+
+	t.Run("wrong type", func(t *testing.T) {
+		m := map[string]any{
+			"type": "StartupDetails",
+			"ti": map[string]any{
+				"id":         "550e8400-e29b-41d4-a716-446655440000",
+				"task_id":    "t",
+				"dag_id":     "d",
+				"run_id":     "r",
+				"try_number": "1", // a string where an integer is required
+			},
+		}
+		_, err := decodeStartupDetails(m)
+		require.Error(t, err)
+		assert.Contains(t, err.Error(), "try_number")
+	})
+}
+
+// TestDecodeStartupDetails_MissingOptionalTimestamps checks that a body
+// without start_date and without ti_context still decodes without error,
+// leaving the start date zero and the run-context timestamps nil.
+func TestDecodeStartupDetails_MissingOptionalTimestamps(t *testing.T) {
+	body := map[string]any{
+		"type": "StartupDetails",
+		"ti": map[string]any{
+			"id":         "550e8400-e29b-41d4-a716-446655440000",
+			"task_id":    "t",
+			"dag_id":     "d",
+			"run_id":     "r",
+			"try_number": int64(1),
+		},
+	}
+
+	got, err := decodeStartupDetails(body)
+	require.NoError(t, err)
+
+	assert.True(t, got.StartDate.IsZero())
+	assert.Nil(t, got.TIContext.LogicalDate)
+	assert.Nil(t, got.TIContext.DataIntervalStart)
+	assert.Nil(t, got.TIContext.DataIntervalEnd)
+}
+
+// TestDecodeConnectionResult decodes a ConnectionResult body and checks the
+// identifier, type, host, schema, credentials and port on the result.
+func TestDecodeConnectionResult(t *testing.T) {
+	body := map[string]any{
+		"type":      "ConnectionResult",
+		"conn_id":   "my_db",
+		"conn_type": "postgres",
+		"host":      "db.example.com",
+		"schema":    "mydb",
+		"login":     "user",
+		"password":  "secret",
+		"port":      int64(5432),
+		"extra":     `{"sslmode":"require"}`,
+	}
+
+	conn, err := decodeConnectionResult(body)
+	require.NoError(t, err)
+
+	assert.Equal(t, "my_db", conn.ConnID)
+	assert.Equal(t, "postgres", conn.ConnType)
+	assert.Equal(t, "db.example.com", conn.Host)
+	assert.Equal(t, "mydb", conn.Schema)
+	assert.Equal(t, "user", conn.Login)
+	assert.Equal(t, "secret", conn.Password)
+	assert.Equal(t, 5432, conn.Port)
+}
+
+// TestDecodeVariableResult decodes a VariableResult body and checks that
+// the key and value are carried over.
+func TestDecodeVariableResult(t *testing.T) {
+	body := map[string]any{
+		"type":  "VariableResult",
+		"key":   "my_var",
+		"value": "hello",
+	}
+
+	variable, err := decodeVariableResult(body)
+	require.NoError(t, err)
+
+	assert.Equal(t, "my_var", variable.Key)
+	assert.Equal(t, "hello", variable.Value)
+}
+
+// TestDecodeXComResult decodes an XComResult body whose value is a nested
+// map and checks that the key and the nested entry are preserved.
+func TestDecodeXComResult(t *testing.T) {
+	payload := map[string]any{"data": "processed"}
+	body := map[string]any{
+		"type":  "XComResult",
+		"key":   "return_value",
+		"value": payload,
+	}
+
+	xcom, err := decodeXComResult(body)
+	require.NoError(t, err)
+	assert.Equal(t, "return_value", xcom.Key)
+
+	decoded, isMap := xcom.Value.(map[string]any)
+	require.True(t, isMap)
+	assert.Equal(t, "processed", decoded["data"])
+}
+
+// TestDecodeErrorResponseNil checks that decoding a nil body yields nil.
+func TestDecodeErrorResponseNil(t *testing.T) {
+	decoded := decodeErrorResponse(nil)
+	assert.Nil(t, decoded)
+}
+
+// TestGetConnectionMsgToMap checks the "type" discriminator and the
+// "conn_id" entry of an encoded GetConnectionMsg.
+func TestGetConnectionMsgToMap(t *testing.T) {
+	encoded := GetConnectionMsg{ConnID: "my_db"}.toMap()
+
+	assert.Equal(t, "GetConnection", encoded["type"])
+	assert.Equal(t, "my_db", encoded["conn_id"])
+}
+
+// TestGetVariableMsgToMap checks the "type" discriminator and the "key"
+// entry of an encoded GetVariableMsg.
+func TestGetVariableMsgToMap(t *testing.T) {
+	encoded := GetVariableMsg{Key: "my_var"}.toMap()
+
+	assert.Equal(t, "GetVariable", encoded["type"])
+	assert.Equal(t, "my_var", encoded["key"])
+}
+
+// TestGetXComMsgToMapWithMapIndex checks that a non-nil MapIndex is emitted
+// as "map_index" and that IncludePriorDates is emitted as
+// "include_prior_dates".
+func TestGetXComMsgToMapWithMapIndex(t *testing.T) {
+	idx := 3
+	request := GetXComMsg{
+		Key:               "result",
+		DagID:             "dag1",
+		TaskID:            "task1",
+		RunID:             "run1",
+		MapIndex:          &idx,
+		IncludePriorDates: true,
+	}
+
+	encoded := request.toMap()
+
+	assert.Equal(t, "GetXCom", encoded["type"])
+	assert.Equal(t, 3, encoded["map_index"])
+	assert.Equal(t, true, encoded["include_prior_dates"])
+}
+
+// TestGetXComMsgToMapNilMapIndex checks that a nil MapIndex leaves the
+// "map_index" key out of the encoded map altogether.
+func TestGetXComMsgToMapNilMapIndex(t *testing.T) {
+	request := GetXComMsg{Key: "result", DagID: "d", TaskID: "t", RunID: "r"}
+	encoded := request.toMap()
+
+	_, present := encoded["map_index"]
+	assert.False(t, present)
+}
+
+// TestSetXComMsgToMap covers how SetXComMsg encodes its optional MapIndex:
+// omitted when nil, emitted as "map_index" when set.
+func TestSetXComMsgToMap(t *testing.T) {
+	t.Run("nil map_index is omitted", func(t *testing.T) {
+		// Unmapped tasks must omit map_index entirely, matching Python's
+		// SetXCom.map_index = None semantics; a -1 sentinel would conflate
+		// "unmapped" with "explicit index -1".
+		msg := SetXComMsg{
+			Key: "output", Value: 42,
+			DagID: "dag1", TaskID: "task1", RunID: "run1",
+		}
+		m := msg.toMap()
+		assert.Equal(t, "SetXCom", m["type"])
+		assert.Equal(t, 42, m["value"])
+		_, hasMapIndex := m["map_index"]
+		assert.False(t, hasMapIndex)
+		_, hasMappedLength := m["mapped_length"]
+		assert.False(t, hasMappedLength)
+	})
+
+	t.Run("non-nil map_index is emitted", func(t *testing.T) {
+		idx := 3
+		msg := SetXComMsg{
+			Key: "output", Value: 42,
+			DagID: "dag1", TaskID: "task1", RunID: "run1", MapIndex: &idx,
+		}
+		m := msg.toMap()
+		assert.Equal(t, 3, m["map_index"])
+	})
+}
+
+// TestSucceedTaskMsgToMap checks the encoded form of a SucceedTaskMsg with a
+// whole-second end date: the "type" discriminator, the formatted "end_date",
+// and empty (non-nil) "task_outlets" and "outlet_events" lists.
+func TestSucceedTaskMsgToMap(t *testing.T) {
+	finishedAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+	encoded := SucceedTaskMsg{EndDate: finishedAt}.toMap()
+
+	assert.Equal(t, "SucceedTask", encoded["type"])
+	assert.Equal(t, "2024-01-15T10:30:00Z", encoded["end_date"])
+	assert.Equal(t, []any{}, encoded["task_outlets"])
+	assert.Equal(t, []any{}, encoded["outlet_events"])
+}
+
+// TestSucceedTaskMsgToMap_PreservesSubsecondPrecision checks that end_date
+// keeps its nanoseconds when encoded. The field is formatted with
+// RFC3339Nano so it round-trips through asTime (which parses RFC3339Nano);
+// truncating to whole seconds would lose ordering for closely-spaced
+// terminal events.
+func TestSucceedTaskMsgToMap_PreservesSubsecondPrecision(t *testing.T) {
+	finishedAt := time.Date(2024, 1, 15, 10, 30, 0, 123456789, time.UTC)
+	encoded := SucceedTaskMsg{EndDate: finishedAt}.toMap()
+
+	assert.Equal(t, "2024-01-15T10:30:00.123456789Z", encoded["end_date"])
+}
+
+// TestTaskStateMsgToMap checks the encoded form of a TaskStateMsg with a
+// whole-second end date: the "type" discriminator, the "state" wire string,
+// and the formatted "end_date".
+func TestTaskStateMsgToMap(t *testing.T) {
+	endDate := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+	msg := TaskStateMsg{State: TaskStateFailed, EndDate: endDate}
+	m := msg.toMap()
+	assert.Equal(t, "TaskState", m["type"])
+	assert.Equal(t, "failed", m["state"])
+	// RFC3339Nano drops the fractional part when the nanoseconds are zero.
+	assert.Equal(t, "2024-01-15T10:30:00Z", m["end_date"])
+}
+
+// TestTaskStateMsgToMap_PreservesSubsecondPrecision checks that a
+// TaskStateMsg end date keeps its nanoseconds when encoded.
+func TestTaskStateMsgToMap_PreservesSubsecondPrecision(t *testing.T) {
+	finishedAt := time.Date(2024, 1, 15, 10, 30, 0, 123456789, time.UTC)
+	encoded := TaskStateMsg{State: TaskStateFailed, EndDate: finishedAt}.toMap()
+
+	assert.Equal(t, "2024-01-15T10:30:00.123456789Z", encoded["end_date"])
+}
+
+// TestTaskStateConstants_WireValues pins each TaskState constant to the
+// exact wire string Python's TaskInstanceState expects. Renaming the
+// constants is fine; changing a wire value would silently break the
+// protocol.
+func TestTaskStateConstants_WireValues(t *testing.T) {
+	wireValues := map[TaskState]string{
+		TaskStateFailed:  "failed",
+		TaskStateRemoved: "removed",
+		TaskStateSkipped: "skipped",
+	}
+	for st, want := range wireValues {
+		assert.Equal(t, want, string(st))
+		encoded := TaskStateMsg{State: st}.toMap()
+		assert.Equal(t, want, encoded["state"], "wire value for %s", st)
+	}
+}
+
+// TestDecodeIncomingBodyDispatch checks how decodeIncomingBody routes a body
+// by its "type" field: a known type yields the matching result type, a nil
+// body yields nil without error, and an unknown or absent type is an error.
+func TestDecodeIncomingBodyDispatch(t *testing.T) {
+	t.Run("ConnectionResult", func(t *testing.T) {
+		body := map[string]any{"type": "ConnectionResult", "conn_id": "x"}
+		result, err := decodeIncomingBody(body)
+		require.NoError(t, err)
+		_, ok := result.(*ConnectionResult)
+		assert.True(t, ok)
+	})
+
+	t.Run("nil", func(t *testing.T) {
+		result, err := decodeIncomingBody(nil)
+		require.NoError(t, err)
+		assert.Nil(t, result)
+	})
+
+	t.Run("unknown type", func(t *testing.T) {
+		_, err := decodeIncomingBody(map[string]any{"type": "UnknownMsg"})
+		assert.Error(t, err)
+	})
+
+	t.Run("missing type", func(t *testing.T) {
+		// A non-nil body with no "type" entry must not decode successfully.
+		_, err := decodeIncomingBody(map[string]any{"conn_id": "x"})
+		assert.Error(t, err)
+	})
+}
+
+// TestAsTime exercises asTime with each kind of input it distinguishes: an
+// RFC 3339 string, a time.Time, nil, and a value of an unsupported type.
+func TestAsTime(t *testing.T) {
+	t.Run("from string", func(t *testing.T) {
+		parsed, err := asTime("2024-01-15T10:30:00Z")
+		require.NoError(t, err)
+		assert.Equal(t, 2024, parsed.Year())
+		assert.Equal(t, time.January, parsed.Month())
+	})
+
+	t.Run("from time.Time", func(t *testing.T) {
+		input := time.Now()
+		output, err := asTime(input)
+		require.NoError(t, err)
+		assert.Equal(t, input, output)
+	})
+
+	t.Run("nil", func(t *testing.T) {
+		_, nilErr := asTime(nil)
+		assert.Error(t, nilErr)
+	})
+
+	t.Run("wrong type", func(t *testing.T) {
+		_, typeErr := asTime(42)
+		assert.Error(t, typeErr)
+	})
+}
diff --git a/go-sdk/pkg/sdkcontext/keys.go b/go-sdk/pkg/sdkcontext/keys.go
index 0dbc2c60194..ad83dfce3bd 100644
--- a/go-sdk/pkg/sdkcontext/keys.go
+++ b/go-sdk/pkg/sdkcontext/keys.go
@@ -22,6 +22,7 @@ type (
apiClientContextKey struct{}
workerContextKey struct{}
runtimeTIContextKey struct{}
+ sdkClientContextKey struct{}
)
var (
@@ -32,4 +33,11 @@ var (
RuntimeTIContextKey = runtimeTIContextKey{}
ApiClientContextKey = apiClientContextKey{}
WorkerContextKey = workerContextKey{}
+
+ // SdkClientContextKey, when present, holds an sdk.Client implementation
+ // that should be injected into task functions instead of constructing a
+ // default HTTP-backed client. The coordinator-mode runtime uses this to
+ // route task SDK calls (GetVariable, GetConnection, ...) over the
+ // supervisor comm socket rather than to the Execution API.
+ SdkClientContextKey = sdkClientContextKey{}
)
diff --git a/go-sdk/sdk/client.go b/go-sdk/sdk/client.go
index d2aa1f40351..f719fdc4787 100644
--- a/go-sdk/sdk/client.go
+++ b/go-sdk/sdk/client.go
@@ -57,6 +57,12 @@ func (*client) GetVariable(ctx context.Context, key string) (string, error) {
}
return "", err
}
+ // TODO: register secret-named variables with a SecretsMasker so the
+ // returned value is automatically redacted from subsequent task logs,
+ // matching Python's airflow.models.variable.Variable.get behaviour.
+ // Pairs with the "TODO: mask secrets here" hook in
+ // pkg/worker/runner.go's task log handler — both halves are needed
+ // before secret masking actually works end-to-end.
return *resp.Value, nil
}
@@ -73,6 +79,13 @@ func (c *client) UnmarshalJSONVariable(ctx context.Context, key string, pointer
func (*client) GetConnection(ctx context.Context, connID string) (Connection, error) {
// TODO: Lookup connection from env var (and handle JSON + URI forms)
+ // TODO: register Connection.Password and sensitive-keyed entries of
+ // Connection.Extra with a SecretsMasker so they are auto-redacted from
+ // subsequent task logs, matching Python's
+ // airflow.models.connection.Connection.get behaviour. Pairs with the
+ // "TODO: mask secrets here" hook in pkg/worker/runner.go's task log
+ // handler and the matching TODO on GetVariable above.
+
httpClient := ctx.Value(sdkcontext.ApiClientContextKey).(api.ClientInterface)
resp, err := httpClient.Connections().Get(ctx, connID)
@@ -84,7 +97,7 @@ func (*client) GetConnection(ctx context.Context, connID string) (Connection, er
return Connection{}, err
}
- return connFromAPIResponse(resp)
+ return ConnFromAPIResponse(resp)
}
func (c *client) PushXCom(
diff --git a/go-sdk/sdk/connection.go b/go-sdk/sdk/connection.go
index 35835c2d523..1d0bcf30b6f 100644
--- a/go-sdk/sdk/connection.go
+++ b/go-sdk/sdk/connection.go
@@ -110,7 +110,11 @@ func (c Connection) GetURI() *url.URL {
return uri
}
-func connFromAPIResponse(resp *api.ConnectionResponse) (Connection, error) {
+// ConnFromAPIResponse converts an Execution-API ConnectionResponse into the
+// SDK's Connection type. It is exported so other internal SDK packages (for
+// example, the coordinator-mode runtime in bundlev1server/impl/coord) can
+// reuse the same conversion.
+func ConnFromAPIResponse(resp *api.ConnectionResponse) (Connection, error) {
var err error
conn := Connection{
ID: resp.ConnId,
diff --git a/go-sdk/sdk/sdk.go b/go-sdk/sdk/sdk.go
index 679ab22d8f4..e4f5a61a331 100644
--- a/go-sdk/sdk/sdk.go
+++ b/go-sdk/sdk/sdk.go
@@ -28,6 +28,15 @@ const (
ConnectionEnvPrefix = "AIRFLOW_CONN_"
)
+// VariableClient reads Airflow Variables.
+//
+// Go has no function overloading, so the "give me the raw string" and
+// "give me a decoded struct" cases are split into two methods rather
+// than one polymorphic call: GetVariable returns the raw string,
+// UnmarshalJSONVariable decodes a JSON-encoded variable into a
+// caller-supplied pointer. This mirrors the std-lib split between
+// os.LookupEnv and json.Unmarshal — each method has one job, and the
+// caller picks based on how the variable was stored.
type VariableClient interface {
// GetVariable returns the value of an Airflow Variable.
//
@@ -43,6 +52,13 @@ type VariableClient interface {
// // Other errors here, such as http network timeouts etc.
// }
GetVariable(ctx context.Context, key string) (string, error)
+
+ // UnmarshalJSONVariable fetches a variable and unmarshals its value into
+ // pointer via json.Unmarshal. Use this when the variable was stored as a
+ // JSON object, array, or number; for plain string variables call
+ // GetVariable directly.
+ //
+ // pointer must be a non-nil pointer, as required by encoding/json.
UnmarshalJSONVariable(ctx context.Context, key string, pointer any) error
}