This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 18d51b76669 Add Logs HTTP Server to Go SDK worker. (#56101)
18d51b76669 is described below
commit 18d51b76669457bb6474a649f61c75cf667424ae
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Thu Sep 25 22:37:47 2025 +0100
Add Logs HTTP Server to Go SDK worker. (#56101)
This uses the same JWT auth mechanism as the Python side does, and when
properly configured it lets the Airflow UI show logs from in-progress tasks.
In order to test this, I added a `hostname` config option. This somewhat
mirrors the feature in Airflow's Python codebase; there it is the name
of a module+fn/callable to use, but in Go that level of dynamic behaviour is
not possible, so for Go it is a simple hard-coded string.
---
go-sdk/README.md | 1 -
go-sdk/celery/app.go | 28 ++-
go-sdk/celery/commands/run.go | 22 ++-
go-sdk/celery/config.go | 1 -
go-sdk/go.mod | 4 +
go-sdk/go.sum | 8 +
go-sdk/pkg/logging/server/server.go | 328 +++++++++++++++++++++++++++++++
go-sdk/pkg/logging/server/server_test.go | 317 +++++++++++++++++++++++++++++
go-sdk/pkg/worker/runner.go | 4 +
9 files changed, 695 insertions(+), 18 deletions(-)
diff --git a/go-sdk/README.md b/go-sdk/README.md
index 8cdad559ff3..3e94af2929b 100644
--- a/go-sdk/README.md
+++ b/go-sdk/README.md
@@ -99,7 +99,6 @@ This SDK currently will:
A non-exhaustive list of features we have yet to implement
- Support for putting tasks into state other than success or
failed/up-for-retry (deferred, failed-without-retries etc.)
-- HTTP Log server to view logs from in-progress tasks
- Remote task logs (i.e. S3/GCS etc)
- XCom reading/writing from API server
- XCom reading/writing from other XCom backends
diff --git a/go-sdk/celery/app.go b/go-sdk/celery/app.go
index 303138cc88a..fd34bace6a0 100644
--- a/go-sdk/celery/app.go
+++ b/go-sdk/celery/app.go
@@ -24,6 +24,7 @@ import (
"log/slog"
"os"
"os/signal"
+ "time"
celery "github.com/marselester/gopher-celery"
celeryredis "github.com/marselester/gopher-celery/goredis"
@@ -33,6 +34,7 @@ import (
"github.com/apache/airflow/go-sdk/bundle/bundlev1"
"github.com/apache/airflow/go-sdk/bundle/bundlev1/bundlev1client"
"github.com/apache/airflow/go-sdk/pkg/bundles/shared"
+ "github.com/apache/airflow/go-sdk/pkg/logging/server"
)
type celeryTasksRunner struct {
@@ -40,9 +42,24 @@ type celeryTasksRunner struct {
}
func Run(ctx context.Context, config Config) error {
+ if len(config.Queues) == 0 {
+ return fmt.Errorf("no queues defined")
+ }
+
ctx, stop := signal.NotifyContext(ctx, os.Interrupt)
defer stop()
+ log := slog.Default().With("logger", "celery.app")
+
+ // TODO: use a config struct in the config, rather than error prone
strings!
+
+ logServer, err := server.NewFromConfig(viper.GetViper())
+ if err != nil {
+ return err
+ }
+
+ go logServer.ListenAndServe(ctx, time.Duration(0))
+
d := shared.NewDiscovery(viper.GetString("bundles.folder"), nil)
d.DiscoverBundles(ctx)
@@ -51,21 +68,18 @@ func Run(ctx context.Context, config Config) error {
})
defer func() {
if err := c.Close(); err != nil {
- slog.ErrorContext(ctx, "failed to close Redis client",
"err", err)
+ log.ErrorContext(ctx, "failed to close Redis client",
"err", err)
}
}()
if _, err := c.Ping(ctx).Result(); err != nil {
- slog.ErrorContext(ctx, "Redis connection failed", "err", err)
+ log.ErrorContext(ctx, "Redis connection failed", "err", err)
return err
}
broker := celeryredis.NewBroker(
celeryredis.WithClient(c),
)
- if len(config.Queues) == 0 {
- return fmt.Errorf("no queues defined")
- }
broker.Observe(config.Queues)
app := celery.NewApp(
@@ -81,14 +95,14 @@ func Run(ctx context.Context, config Config) error {
func(ctx context.Context, p *celery.TaskParam) error {
err := tasks.ExecuteWorkloadTask(ctx, p)
if err != nil {
- slog.ErrorContext(ctx, "Celery Task
failed", "err", err)
+ log.ErrorContext(ctx, "Celery Task
failed", "err", err)
}
return err
},
)
}
- slog.Info("waiting for tasks", "queues", config.Queues)
+ log.Info("waiting for tasks", "queues", config.Queues)
return app.Run(ctx)
}
diff --git a/go-sdk/celery/commands/run.go b/go-sdk/celery/commands/run.go
index 3cbcadc4a38..d4b2939996a 100644
--- a/go-sdk/celery/commands/run.go
+++ b/go-sdk/celery/commands/run.go
@@ -50,16 +50,20 @@ var runCmd = &cobra.Command{
}
func init() {
- runCmd.Flags().StringP("broker-address", "b", "", "Celery Broker
host:port to connect to")
- runCmd.Flags().
- StringP("execution-api-url", "e",
"http://localhost:8080/execution/", "Execution API to connect to")
- runCmd.Flags().StringSliceP("queues", "q", []string{"default"}, "Celery
queues to listen on")
- runCmd.Flags().
- StringP("bundles-folder", "", "", "Folder containing the
compiled dag bundle executables")
+ flags := runCmd.Flags()
+ flags.StringP("broker-address", "b", "", "Celery Broker host:port to
connect to")
+ flags.StringP(
+ "execution-api-url",
+ "e",
+ "http://localhost:8080/execution/",
+ "Execution API to connect to",
+ )
+ flags.StringSliceP("queues", "q", []string{"default"}, "Celery queues
to listen on")
+ flags.StringP("bundles-folder", "", "", "Folder containing the compiled
dag bundle executables")
runCmd.MarkFlagRequired("broker-address")
runCmd.MarkFlagRequired("bundles-folder")
- runCmd.Flags().
- SetAnnotation("broker-address", "viper-mapping",
[]string{"celery.broker-address"})
- runCmd.Flags().SetAnnotation("bundles-folder", "viper-mapping",
[]string{"bundles.folder"})
+ flags.SetAnnotation("broker-address", "viper-mapping",
[]string{"broker_address"})
+ flags.SetAnnotation("queues", "viper-mapping", []string{"queues"})
+ flags.SetAnnotation("bundles-folder", "viper-mapping",
[]string{"bundles.folder"})
}
diff --git a/go-sdk/celery/config.go b/go-sdk/celery/config.go
index cf6f38cbd60..f7e715da70f 100644
--- a/go-sdk/celery/config.go
+++ b/go-sdk/celery/config.go
@@ -19,7 +19,6 @@ package celery
type Config struct {
BrokerAddr string `mapstructure:"broker_address"`
- Port int `mapstructure:"port"`
Queues []string `mapstructure:"queues"`
BundlesFolder string `mapstructure:"bundles-folder"`
}
diff --git a/go-sdk/go.mod b/go-sdk/go.mod
index f6f20a970b7..bcf328d37f9 100644
--- a/go-sdk/go.mod
+++ b/go-sdk/go.mod
@@ -4,6 +4,7 @@ go 1.24
require (
github.com/cappuccinotm/slogx v1.4.2
+ github.com/golang-jwt/jwt/v5 v5.3.0
github.com/hashicorp/go-hclog v1.6.3
github.com/hashicorp/go-plugin v1.7.0
github.com/ivanpirog/coloredcobra v1.0.1
@@ -39,6 +40,8 @@ require (
github.com/spf13/cast v1.7.1 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
+ go.opentelemetry.io/otel v1.29.0 // indirect
+ go.opentelemetry.io/otel/trace v1.29.0 // indirect
go.uber.org/multierr v1.10.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/sync v0.12.0 // indirect
@@ -57,6 +60,7 @@ require (
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/redis/go-redis/v9 v9.7.3
+ github.com/samber/slog-http v1.8.2
golang.org/x/sys v0.32.0 // indirect
golang.org/x/text v0.23.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
diff --git a/go-sdk/go.sum b/go-sdk/go.sum
index 64450712160..c8fea1d41b8 100644
--- a/go-sdk/go.sum
+++ b/go-sdk/go.sum
@@ -38,6 +38,8 @@ github.com/go-logfmt/logfmt v0.6.0
h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi
github.com/go-logfmt/logfmt v0.6.0/go.mod
h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/go-viper/mapstructure/v2 v2.4.0
h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod
h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
+github.com/golang-jwt/jwt/v5 v5.3.0
h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
+github.com/golang-jwt/jwt/v5 v5.3.0/go.mod
h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/protobuf v1.5.4
h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod
h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/gomodule/redigo v1.8.9
h1:Sl3u+2BI/kk+VEatbj0scLdrFhjPmbxOc1myhDP41ws=
@@ -94,6 +96,8 @@ github.com/rogpeppe/go-internal v1.12.0/go.mod
h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99
github.com/russross/blackfriday/v2 v2.1.0/go.mod
h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sagikazarmark/locafero v0.7.0
h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo=
github.com/sagikazarmark/locafero v0.7.0/go.mod
h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k=
+github.com/samber/slog-http v1.8.2
h1:4UJ5n+kw8BYo1pn+mu03M/DTqAZj6FFOawhLj8MYENk=
+github.com/samber/slog-http v1.8.2/go.mod
h1:PAcQQrYFo5KM7Qbk50gNNwKEAMGCyfsw6GN5dI0iv9g=
github.com/sourcegraph/conc v0.3.0
h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod
h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs=
@@ -119,6 +123,10 @@ github.com/stretchr/testify v1.10.0
h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf
github.com/stretchr/testify v1.10.0/go.mod
h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0
h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod
h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
+go.opentelemetry.io/otel v1.29.0
h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw=
+go.opentelemetry.io/otel v1.29.0/go.mod
h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8=
+go.opentelemetry.io/otel/trace v1.29.0
h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4=
+go.opentelemetry.io/otel/trace v1.29.0/go.mod
h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ=
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod
h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
diff --git a/go-sdk/pkg/logging/server/server.go
b/go-sdk/pkg/logging/server/server.go
new file mode 100644
index 00000000000..18321aa2f41
--- /dev/null
+++ b/go-sdk/pkg/logging/server/server.go
@@ -0,0 +1,328 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Package server implements an HTTP server to make in-progress task logs
available to the Airflow UI
+package server
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "io/fs"
+ "log/slog"
+ "net"
+ "net/http"
+ "slices"
+ "strings"
+ "time"
+
+ "github.com/golang-jwt/jwt/v5"
+ sloghttp "github.com/samber/slog-http"
+ "github.com/spf13/viper"
+)
+
+const (
+ DefaultAudience = "task-instance-logs"
+ DefaultAlgorithm = "HS512"
+ DefaultWorkerLogServerPort = 8793
+)
+
+// We need it to be a pointer, so it can't be a const.
+
+var DefaultClockLeeway = 30 * time.Second
+
+type Config struct {
+ BaseLogFolder string `mapstructure:"base_log_folder"`
+ Port int `mapstructure:"port"`
+
+ SecretKey string `mapstructure:"secret_key"`
+ ClockLeeway *time.Duration `mapstructure:"clock_leeway"`
+ Algorithm string `mapstructure:"algorithm"`
+ Audiences []string `mapstructure:"audiences"`
+}
+
+var DefaultConfig = Config{
+ BaseLogFolder: ".",
+ Port: DefaultWorkerLogServerPort,
+ ClockLeeway: &DefaultClockLeeway,
+ Algorithm: DefaultAlgorithm,
+ Audiences: []string{DefaultAudience},
+}
+
+type LogServer struct {
+ server *http.Server
+ logger *slog.Logger
+
+ jwtParser *jwt.Parser
+ jwtKeyFunc func(*jwt.Token) (any, error)
+ fileServer http.Handler
+ fs dotFileHidingFileSystem
+}
+
+func init() {
+ sloghttp.RequestIDHeaderKey = "Correlation-ID"
+}
+
+func NewFromConfig(v *viper.Viper) (*LogServer, error) {
+ // TODO: Unmarshal doesn't seem to work with configs from env.
Something like needs binding? first?
+ cfg := DefaultConfig
+
+ cfg.BaseLogFolder = v.GetString("logging.base_log_folder")
+ cfg.SecretKey = v.GetString("logging.secret_key")
+
+ if v.IsSet("logging.worker_log_server_port") {
+ cfg.Port = v.GetInt("logging.worker_log_server_port")
+ }
+
+ if v.IsSet("logging.clock_leeway") {
+ if v.GetString("logging.clock_leeway") == "" {
+ cfg.ClockLeeway = nil
+ } else {
+ leeway := v.GetDuration("logging.clock_leeway")
+ cfg.ClockLeeway = &leeway
+ }
+ }
+
+ if v.IsSet("logging.audiences") {
+ cfg.Audiences = v.GetStringSlice("logging.audiences")
+ }
+
+ if v.IsSet("logging.algorithm") {
+ cfg.Algorithm = v.GetString("logging.algorithm")
+ }
+
+ return NewLogServer(nil, cfg)
+}
+
// NewLogServer builds (but does not start) a LogServer for cfg.
//
// If logger is nil, slog.Default() tagged with logger=pkg.logging.server is
// used. It returns an error if cfg.SecretKey is empty or if makeJWTParser
// rejects cfg (for example an unknown cfg.Algorithm).
//
// Routes:
//   - GET /favicon.ico always answers 404 and is filtered out of request logs.
//   - GET /log/<path> has the "/log/" prefix stripped, must pass
//     validateToken, and is then served by ServeLog from cfg.BaseLogFolder
//     with dot-files hidden.
//
// Every request goes through sloghttp's Recovery and request-logging middleware.
func NewLogServer(logger *slog.Logger, cfg Config) (*LogServer, error) {
	// Bounds how long a client may take to send its request headers, so slow or
	// idle clients cannot pin connections open indefinitely. Read/WriteTimeout
	// are deliberately left unset because responses can be large log files.
	const readHeaderTimeout = 10 * time.Second

	if cfg.SecretKey == "" {
		return nil, errors.New("logging.secret_key config option must be provided")
	}

	if logger == nil {
		logger = slog.Default().With("logger", "pkg.logging.server")
	}
	parser, err := makeJWTParser(cfg)
	if err != nil {
		return nil, fmt.Errorf("unable to create log HTTP server: %w", err)
	}

	mux := http.NewServeMux()
	handler := sloghttp.Recovery(mux)
	handler = sloghttp.NewWithFilters(
		logger.WithGroup("http"),
		sloghttp.IgnorePath("/favicon.ico"),
	)(handler)

	// Named logFS rather than fs so it does not shadow the io/fs package.
	logFS := dotFileHidingFileSystem{http.Dir(cfg.BaseLogFolder)}

	srv := &LogServer{
		server: &http.Server{
			Handler:           handler,
			Addr:              fmt.Sprintf(":%d", cfg.Port),
			ReadHeaderTimeout: readHeaderTimeout,
		},
		logger: logger,

		jwtParser:  parser,
		jwtKeyFunc: func(*jwt.Token) (any, error) { return []byte(cfg.SecretKey), nil },
		fs:         logFS,
		fileServer: http.FileServer(logFS),
	}

	mux.Handle("GET /favicon.ico", http.NotFoundHandler())
	mux.Handle("GET /log/", http.StripPrefix(
		"/log/",
		srv.validateToken(http.HandlerFunc(srv.ServeLog)),
	))

	return srv, nil
}
+
// makeJWTParser builds the parser used to validate request tokens.
//
// The parser accepts only cfg.Algorithm (DefaultAlgorithm when empty), requires
// an "exp" claim, checks the audience against cfg.Audiences, and applies
// cfg.ClockLeeway as its leeway when that is non-nil.
//
// Only HMAC algorithms (HS256/HS384/HS512) are allowed. LogServer verifies
// signatures with the shared secret as a []byte key, and every other
// registered method (RS*, ES*, PS*, EdDSA, "none") rejects such a key. That
// would make every request fail with 403. Refusing them here surfaces the
// misconfiguration at start-up instead.
func makeJWTParser(cfg Config) (*jwt.Parser, error) {
	if cfg.Algorithm == "" {
		cfg.Algorithm = DefaultAlgorithm
	}

	if !slices.Contains(jwt.GetAlgorithms(), cfg.Algorithm) {
		return nil, fmt.Errorf("unknown jwt signing algorithm %q", cfg.Algorithm)
	}
	if _, isHMAC := jwt.GetSigningMethod(cfg.Algorithm).(*jwt.SigningMethodHMAC); !isHMAC {
		return nil, fmt.Errorf(
			"unsupported jwt signing algorithm %q: only HMAC (HS*) algorithms work with a shared secret key",
			cfg.Algorithm,
		)
	}

	opts := []jwt.ParserOption{
		jwt.WithAudience(cfg.Audiences...),
		jwt.WithValidMethods([]string{cfg.Algorithm}),
		jwt.WithExpirationRequired(),
	}
	if cfg.ClockLeeway != nil {
		opts = append(opts, jwt.WithLeeway(*cfg.ClockLeeway))
	}
	return jwt.NewParser(opts...), nil
}
+
// ListenAndServe binds a TCP listener on the server's configured address and
// hands it to Serve; see Serve for how ctx and shutdownTimeout are used.
// An empty address falls back to ":http", as http.Server.ListenAndServe does.
// Listening is kept separate from Serve so tests can supply their own listener.
func (l *LogServer) ListenAndServe(ctx context.Context, shutdownTimeout time.Duration) error {
	addr := ":http"
	if l.server.Addr != "" {
		addr = l.server.Addr
	}

	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	return l.Serve(ctx, shutdownTimeout, listener)
}
+
// Serve handles HTTP requests arriving on ln until ctx is done or serving fails.
//
// When ctx is done the server shuts down gracefully. Listeners are closed and
// in-flight requests get up to shutdownTimeout to finish; a zero or negative
// timeout gives them none. Serve then returns nil. Any other error from
// http.Server.Serve (for example a failing listener) is returned unchanged.
//
// Request contexts derive from ctx with cancellation removed, so cancelling
// ctx does not abort requests already being served. They are bounded by the
// shutdown timeout instead.
func (l *LogServer) Serve(
	ctx context.Context,
	shutdownTimeout time.Duration,
	ln net.Listener,
) error {
	uncancellableCtx := context.WithoutCancel(ctx)

	l.server.BaseContext = func(sock net.Listener) context.Context {
		l.logger.Info("Listening for logs", "addr", "http://"+sock.Addr().String())
		return uncancellableCtx
	}

	// serveDone is closed once l.server.Serve returns. The watcher goroutine
	// then exits even if ctx is never cancelled (e.g. Serve failed), instead of
	// leaking.
	serveDone := make(chan struct{})
	// shutdownDone is closed when the watcher goroutine has finished, including
	// any graceful shutdown it performed.
	shutdownDone := make(chan struct{})

	go func() {
		defer close(shutdownDone)
		select {
		case <-ctx.Done():
		case <-serveDone:
			return
		}

		l.logger.Info("Shutting down log server")
		shutdownCtx, cancel := context.WithTimeout(uncancellableCtx, shutdownTimeout)
		defer cancel() // always release the timer, even when Shutdown finishes early
		if err := l.server.Shutdown(shutdownCtx); err != nil {
			l.logger.Warn("Log server did not shut down cleanly", "err", err)
		}
	}()

	err := l.server.Serve(ln)
	close(serveDone)
	// Wait for any in-progress graceful shutdown before reporting back.
	<-shutdownDone

	if errors.Is(err, http.ErrServerClosed) {
		return nil
	}
	return err
}
+
// validateToken wraps next so it runs only when all of these hold:
//   - the request's Authorization header carries a JWT (with or without a
//     "Bearer " prefix);
//   - l.jwtParser accepts that JWT;
//   - the JWT's "filename" claim is exactly r.URL.Path.
//
// Otherwise it answers 403 Forbidden and next is not called.
//
// NewLogServer mounts this behind http.StripPrefix("/log/", ...). On that route
// r.URL.Path is the log path relative to the base log folder.
func (l *LogServer) validateToken(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		forbid := func(body string) {
			w.WriteHeader(http.StatusForbidden)
			// Best effort: nothing useful can be done if the client has gone away.
			_, _ = w.Write([]byte(body))
		}
		requestID := slog.Group("http", slog.String("id", sloghttp.GetRequestID(r)))

		auth := r.Header.Get("Authorization")
		if auth == "" {
			forbid("Authorization header missing")
			return
		}

		token, err := l.jwtParser.Parse(strings.TrimPrefix(auth, "Bearer "), l.jwtKeyFunc)
		if err != nil {
			l.logger.Error("Token validation failed", requestID, "err", err)
			forbid("Invalid Authorization header")
			return
		}

		// jwt.Parser.Parse decodes into jwt.MapClaims. The comma-ok form means a
		// different claims type cannot panic: claims is then nil and the lookup
		// below simply misses.
		claims, _ := token.Claims.(jwt.MapClaims)
		filename, ok := claims["filename"].(string)
		if !ok || filename != r.URL.Path {
			l.logger.Error(
				"Claim is for a different path than the URL",
				requestID,
				"filename_claim", filename,
				"r.URL.Path", r.URL.Path,
			)
			forbid("Invalid Authorization header")
			return
		}

		next.ServeHTTP(w, r)
	})
}
+
// ServeLog serves the requested file through the dot-file-hiding file server.
// It first copies the request ID that sloghttp associates with r into the
// response header named by sloghttp.RequestIDHeaderKey. On the route
// NewLogServer registers, validateToken has already vetted the request.
func (l *LogServer) ServeLog(w http.ResponseWriter, r *http.Request) {
	requestID := sloghttp.GetRequestID(r)
	w.Header().Add(sloghttp.RequestIDHeaderKey, requestID)
	l.fileServer.ServeHTTP(w, r)
}
+
// containsDotFile reports whether any slash-separated element of name begins
// with a period, e.g. ".git/config", "logs/.secret", or "." itself. name is
// expected to use forward slashes, as the http.FileSystem interface guarantees.
func containsDotFile(name string) bool {
	// An element starts with "." exactly when either the whole name starts
	// with "." (first element) or a "/" is immediately followed by ".".
	return strings.HasPrefix(name, ".") || strings.Contains(name, "/.")
}
+
+// dotFileHidingFile is the http.File use in dotFileHidingFileSystem.
+// It is used to wrap the Readdir method of http.File so that we can
+// remove files and directories that start with a period from its output.
+type dotFileHidingFile struct {
+ http.File
+}
+
+// Readdir is a wrapper around the Readdir method of the embedded File
+// that filters out all files that start with a period in their name.
+func (f dotFileHidingFile) Readdir(n int) (fis []fs.FileInfo, err error) {
+ files, err := f.File.Readdir(n)
+ for _, file := range files { // Filters out the dot files
+ if !strings.HasPrefix(file.Name(), ".") {
+ fis = append(fis, file)
+ }
+ }
+ if err == nil && n > 0 && len(fis) == 0 {
+ err = io.EOF
+ }
+ return
+}
+
+// dotFileHidingFileSystem is an http.FileSystem that hides
+// hidden "dot files" from being served.
+type dotFileHidingFileSystem struct {
+ http.FileSystem
+}
+
+// Open is a wrapper around the Open method of the embedded FileSystem
+// that serves a 403 permission error when name has a file or directory
+// with whose name starts with a period in its path.
+func (fsys dotFileHidingFileSystem) Open(name string) (http.File, error) {
+ slog.Default().Debug("Trying to open", "name", name)
+ if containsDotFile(name) { // If dot file, return 403 response
+ return nil, fs.ErrPermission
+ }
+
+ file, err := fsys.FileSystem.Open(name)
+ if err != nil {
+ return nil, err
+ }
+ return dotFileHidingFile{file}, nil
+}
diff --git a/go-sdk/pkg/logging/server/server_test.go
b/go-sdk/pkg/logging/server/server_test.go
new file mode 100644
index 00000000000..abcd8c1f92e
--- /dev/null
+++ b/go-sdk/pkg/logging/server/server_test.go
@@ -0,0 +1,317 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package server
+
+import (
+ "context"
+ "io/fs"
+ "log/slog"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "testing/fstest"
+ "time"
+
+ "github.com/golang-jwt/jwt/v5"
+ "github.com/spf13/viper"
+ "github.com/stretchr/testify/suite"
+)
+
+type ServerTestSuite struct {
+ suite.Suite
+}
+
+func (s *ServerTestSuite) TestContainsDotFile() {
+ tests := []struct {
+ input string
+ expected bool
+ }{
+ {"foo/bar/baz", false},
+ {".foo/bar", true},
+ {"foo/.bar/baz", true},
+ {"foo/bar/.baz", true},
+ {".", true},
+ {"./foo", true},
+ {"", false},
+ }
+ for _, tt := range tests {
+ s.Equal(tt.expected, containsDotFile(tt.input), "input: %s",
tt.input)
+ }
+}
+
+func (s *ServerTestSuite) TestDotFileHidingFileReaddir_FiltersDotFiles() {
+ // Use MapFS for backing files
+ fsMap := fstest.MapFS{
+ "foo.txt": {Data: []byte("foo")},
+ ".hidden.txt": {Data: []byte("hidden")},
+ "bar.log": {Data: []byte("bar")},
+ ".dotfile": {Data: []byte("dot")},
+ }
+ root, err := fs.Sub(fsMap, ".")
+ s.Require().NoError(err)
+ // Open all files in root
+ files, err := fs.ReadDir(root, ".")
+ s.Require().NoError(err)
+
+ // Convert DirEntry to FileInfo
+ fileInfos := make([]fs.FileInfo, 0, len(files))
+ for _, f := range files {
+ info, err := f.Info()
+ s.Require().NoError(err)
+ fileInfos = append(fileInfos, info)
+ }
+ df := dotFileHidingFile{mockHTTPFile{files: fileInfos}}
+ filtered, err := df.Readdir(-1)
+ s.NoError(err)
+ names := []string{}
+ for _, f := range filtered {
+ names = append(names, f.Name())
+ }
+ s.Equal([]string{"bar.log", "foo.txt"}, names)
+}
+
+type mockHTTPFile struct {
+ http.File
+ files []fs.FileInfo
+}
+
+func (m mockHTTPFile) Readdir(n int) ([]fs.FileInfo, error) {
+ return m.files, nil
+}
+
// TestDotFileHidingFileSystem_Open_RejectsDotFiles checks that a path inside a
// dot-directory is refused with fs.ErrPermission.
func (s *ServerTestSuite) TestDotFileHidingFileSystem_Open_RejectsDotFiles() {
	hiding := dotFileHidingFileSystem{http.FS(fstest.MapFS{
		".hidden/file.txt": {Data: []byte("secret")},
	})}

	_, err := hiding.Open(".hidden/file.txt")
	s.ErrorIs(err, fs.ErrPermission)
}

// TestDotFileHidingFileSystem_Open_DelegatesForNormalFile checks that ordinary
// files are opened through the wrapped FileSystem and returned wrapped in
// dotFileHidingFile.
func (s *ServerTestSuite) TestDotFileHidingFileSystem_Open_DelegatesForNormalFile() {
	hiding := dotFileHidingFileSystem{http.FS(fstest.MapFS{
		"normal.txt": {Data: []byte("normal")},
	})}

	opened, err := hiding.Open("normal.txt")
	s.NoError(err)
	s.IsType(dotFileHidingFile{}, opened)
}
+
+func (s *ServerTestSuite) TestMakeJWTParser_ValidAlgorithm() {
+ cfg := DefaultConfig
+ cfg.SecretKey = "testkey"
+ parser, err := makeJWTParser(cfg)
+ s.NoError(err)
+ s.NotNil(parser)
+}
+
+func (s *ServerTestSuite) TestMakeJWTParser_InvalidAlgorithm() {
+ cfg := DefaultConfig
+ cfg.Algorithm = "invalid-algo"
+ _, err := makeJWTParser(cfg)
+ s.Error(err)
+}
+
+func (s *ServerTestSuite) TestNewLogServer_MissingSecretKey() {
+ cfg := DefaultConfig
+ cfg.SecretKey = ""
+ _, err := NewLogServer(nil, cfg)
+ s.Error(err)
+}
+
+func (s *ServerTestSuite) TestNewFromConfig_MissingSecretKey() {
+ v := viper.New()
+ v.Set("logging.base_log_folder", ".")
+ v.Set("logging.secret_key", "")
+ _, err := NewFromConfig(v)
+ s.Error(err)
+}
+
+func (s *ServerTestSuite) TestNewFromConfig_AllFields() {
+ v := viper.New()
+ v.Set("logging.base_log_folder", ".")
+ v.Set("logging.secret_key", "mysecret")
+ v.Set("logging.worker_log_server_port", 12345)
+ v.Set("logging.clock_leeway", "10s")
+ v.Set("logging.audiences", []string{"aud1", "aud2"})
+ v.Set("logging.algorithm", "HS256")
+ srv, err := NewFromConfig(v)
+ s.NoError(err)
+ s.NotNil(srv)
+}
+
// TestLogServer_ServeLog_ForwardsToFileServer checks that ServeLog returns the
// requested file's exact contents from the server's file system.
func (s *ServerTestSuite) TestLogServer_ServeLog_ForwardsToFileServer() {
	srv := s.newMapFSLogServer(fstest.MapFS{
		"foo.log": {Data: []byte("logdata")},
	})

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/foo.log", nil)
	srv.ServeLog(rec, req)

	// The file exists, so a 404 must not count as success here.
	s.Equal(http.StatusOK, rec.Code)
	s.Equal("logdata", rec.Body.String())
}

// TestLogServer_ServeLog_HidesDotFiles checks that dot-files are refused with
// 403 even though they exist in the backing file system.
func (s *ServerTestSuite) TestLogServer_ServeLog_HidesDotFiles() {
	srv := s.newMapFSLogServer(fstest.MapFS{
		".secret": {Data: []byte("hidden")},
	})

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/.secret", nil)
	srv.ServeLog(rec, req)

	s.Equal(http.StatusForbidden, rec.Code)
	s.NotContains(rec.Body.String(), "hidden")
}

// newMapFSLogServer builds a LogServer whose file server reads from files
// instead of the real base log folder.
func (s *ServerTestSuite) newMapFSLogServer(files fstest.MapFS) *LogServer {
	cfg := DefaultConfig
	cfg.SecretKey = "testkey"
	cfg.BaseLogFolder = "."
	srv, err := NewLogServer(slog.Default(), cfg)
	s.Require().NoError(err)

	srv.fs = dotFileHidingFileSystem{http.FS(files)}
	srv.fileServer = http.FileServer(srv.fs)
	return srv
}
+
// TestValidateToken_MissingAuthHeader checks that a request without an
// Authorization header gets 403 and never reaches the wrapped handler.
func (s *ServerTestSuite) TestValidateToken_MissingAuthHeader() {
	conf := DefaultConfig
	conf.SecretKey = "testkey"
	srv, err := NewLogServer(slog.Default(), conf)
	s.NoError(err)

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodGet, "/foo.log", nil)
	mustNotRun := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
		s.Fail("should not call next")
	})
	srv.validateToken(mustNotRun).ServeHTTP(recorder, request)

	s.Equal(http.StatusForbidden, recorder.Code)
	s.Contains(recorder.Body.String(), "Authorization header missing")
}

// TestValidateToken_InvalidToken checks that a malformed bearer token gets 403
// and never reaches the wrapped handler.
func (s *ServerTestSuite) TestValidateToken_InvalidToken() {
	conf := DefaultConfig
	conf.SecretKey = "testkey"
	srv, err := NewLogServer(slog.Default(), conf)
	s.NoError(err)

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodGet, "/foo.log", nil)
	request.Header.Set("Authorization", "Bearer invalidtoken")
	mustNotRun := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
		s.Fail("should not call next")
	})
	srv.validateToken(mustNotRun).ServeHTTP(recorder, request)

	s.Equal(http.StatusForbidden, recorder.Code)
	s.Contains(recorder.Body.String(), "Invalid Authorization header")
}
+
// TestValidateToken_ValidTokenWrongFilename checks that a correctly signed
// token for a different file is refused with 403.
func (s *ServerTestSuite) TestValidateToken_ValidTokenWrongFilename() {
	cfg := DefaultConfig
	cfg.SecretKey = "testkey"
	srv, err := NewLogServer(slog.Default(), cfg)
	s.Require().NoError(err)

	signed := s.signedTestToken(cfg, jwt.MapClaims{
		"filename": "/other.log",
		"aud":      DefaultAudience,
		"exp":      time.Now().Add(1 * time.Minute).Unix(),
	})

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/foo.log", nil)
	req.Header.Set("Authorization", "Bearer "+signed)
	handler := srv.validateToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		s.Fail("should not call next")
	}))
	handler.ServeHTTP(rec, req)
	s.Equal(http.StatusForbidden, rec.Code)
	s.Contains(rec.Body.String(), "Invalid Authorization header")
}

// TestValidateToken_ValidToken_CallsNext checks that a correctly signed token
// for the requested file is passed through to the wrapped handler.
func (s *ServerTestSuite) TestValidateToken_ValidToken_CallsNext() {
	cfg := DefaultConfig
	cfg.SecretKey = "testkey"
	srv, err := NewLogServer(slog.Default(), cfg)
	s.Require().NoError(err)

	signed := s.signedTestToken(cfg, jwt.MapClaims{
		"filename": "/foo.log",
		"aud":      DefaultAudience,
		"exp":      time.Now().Add(1 * time.Minute).Unix(),
	})

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/foo.log", nil)
	req.Header.Set("Authorization", "Bearer "+signed)
	called := false
	handler := srv.validateToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		w.WriteHeader(http.StatusOK)
	}))
	handler.ServeHTTP(rec, req)
	s.True(called)
	s.Equal(http.StatusOK, rec.Code)
}

// signedTestToken signs claims with cfg's algorithm and secret and fails the
// test immediately if signing fails. Otherwise an empty token would be sent,
// and negative tests would pass for the wrong reason.
func (s *ServerTestSuite) signedTestToken(cfg Config, claims jwt.MapClaims) string {
	token := jwt.NewWithClaims(jwt.GetSigningMethod(cfg.Algorithm), claims)
	signed, err := token.SignedString([]byte(cfg.SecretKey))
	s.Require().NoError(err)
	return signed
}
+
// TestLogServer_Run_StartsAndShutsDownCleanly starts Serve on a random local
// port, checks that it answers a request, and verifies that cancelling the
// context makes Serve return nil within a bounded time.
func (s *ServerTestSuite) TestLogServer_Run_StartsAndShutsDownCleanly() {
	cfg := DefaultConfig
	cfg.SecretKey = "testkey"
	cfg.Port = 0 // 0 means choose any available port

	srv, err := NewLogServer(slog.Default(), cfg)
	s.Require().NoError(err)

	// Replace the real handler so the test only exercises start-up and shutdown.
	srv.server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	s.Require().NoError(err)
	defer ln.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Buffered so the Serve goroutine can always deliver its result and exit,
	// even if an assertion below aborts the test before it is received.
	runErrCh := make(chan error, 1)
	go func() {
		runErrCh <- srv.Serve(ctx, 2*time.Second, ln)
	}()

	client := &http.Client{Timeout: 1 * time.Second}
	url := "http://" + ln.Addr().String()

	// Retry briefly in case the server goroutine has not started serving yet.
	var resp *http.Response
	for i := 0; i < 10; i++ {
		resp, err = client.Get(url)
		if err == nil {
			break
		}
		time.Sleep(100 * time.Millisecond)
	}
	s.Require().NoError(err)
	s.Equal(http.StatusOK, resp.StatusCode)
	// Release the connection before shutting down.
	s.NoError(resp.Body.Close())

	cancel()
	select {
	case err := <-runErrCh:
		s.NoError(err)
	case <-time.After(5 * time.Second):
		s.Fail("server did not shut down in time")
	}
}
+
// TestServerTestSuite is the `go test` entry point that runs every
// ServerTestSuite test method.
func TestServerTestSuite(t *testing.T) {
	suite.Run(t, &ServerTestSuite{})
}
diff --git a/go-sdk/pkg/worker/runner.go b/go-sdk/pkg/worker/runner.go
index 165bd17ebe8..f5dd790fa39 100644
--- a/go-sdk/pkg/worker/runner.go
+++ b/go-sdk/pkg/worker/runner.go
@@ -175,6 +175,10 @@ func (h *heartbeater) Run(
}
func (w *worker) ExecuteTaskWorkload(ctx context.Context, workload
api.ExecuteTaskWorkload) error {
+ if hostname := viper.GetString("hostname"); hostname != "" {
+ Hostname = hostname
+ }
+
// Store the workload in the context so we can get at task id, etc,
variables
taskContext, cancelTaskCtx := context.WithCancelCause(
context.WithValue(ctx, sdkcontext.WorkloadContextKey, workload),