amoghrajesh commented on code in PR #56079: URL: https://github.com/apache/airflow/pull/56079#discussion_r2378646161
########## go-sdk/bundle/bundlev1/bundlev1server/plugin.go: ########## @@ -0,0 +1,185 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package bundlev1server + +import ( + "context" + "errors" + "log/slog" + "os" + "sync" + + "github.com/google/uuid" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-plugin" + "github.com/spf13/viper" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/apache/airflow/go-sdk/bundle/bundlev1" + proto "github.com/apache/airflow/go-sdk/internal/protov1" + "github.com/apache/airflow/go-sdk/pkg/api" + "github.com/apache/airflow/go-sdk/pkg/worker" +) + +// BundleGRPCPlugin is an implementation of the github.com/hashicorp/go-plugin#Plugin and +// github.com/hashicorp/go-plugin#GRPCPlugin interfaces, indicating how to +// serve [bundlev1.BundleProvider] as gRPC plugins for go-plugin. 
+type BundleGRPCPlugin struct { + plugin.NetRPCUnsupportedPlugin + Factory func() bundlev1.BundleProvider +} + +// Type assetion -- it must be a grpc plugin +var _ plugin.GRPCPlugin = (*BundleGRPCPlugin)(nil) + +// Type assetion -- it must be a rpc plugin (even if it just returns errors) Review Comment: ```suggestion // Type assertion -- it must be a rpc plugin (even if it just returns errors) ``` ########## go-sdk/bundle/bundlev1/bundlev1server/plugin.go: ########## @@ -0,0 +1,185 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package bundlev1server + +import ( + "context" + "errors" + "log/slog" + "os" + "sync" + + "github.com/google/uuid" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-plugin" + "github.com/spf13/viper" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/apache/airflow/go-sdk/bundle/bundlev1" + proto "github.com/apache/airflow/go-sdk/internal/protov1" + "github.com/apache/airflow/go-sdk/pkg/api" + "github.com/apache/airflow/go-sdk/pkg/worker" +) + +// BundleGRPCPlugin is an implementation of the github.com/hashicorp/go-plugin#Plugin and +// github.com/hashicorp/go-plugin#GRPCPlugin interfaces, indicating how to +// serve [bundlev1.BundleProvider] as gRPC plugins for go-plugin. +type BundleGRPCPlugin struct { + plugin.NetRPCUnsupportedPlugin + Factory func() bundlev1.BundleProvider +} + +// Type assetion -- it must be a grpc plugin Review Comment: ```suggestion // Type assertion -- it must be a grpc plugin ``` ########## go-sdk/pkg/bundles/shared/interface.go: ########## @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Package shared contains shared data between the host and plugins. 
Review Comment: ```suggestion // Package shared contains shared data between the worker and plugins. ``` ########## go-sdk/bundle/bundlev1/bundlev1server/plugin.go: ########## @@ -0,0 +1,185 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package bundlev1server + +import ( + "context" + "errors" + "log/slog" + "os" + "sync" + + "github.com/google/uuid" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-plugin" + "github.com/spf13/viper" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/apache/airflow/go-sdk/bundle/bundlev1" + proto "github.com/apache/airflow/go-sdk/internal/protov1" + "github.com/apache/airflow/go-sdk/pkg/api" + "github.com/apache/airflow/go-sdk/pkg/worker" +) + +// BundleGRPCPlugin is an implementation of the github.com/hashicorp/go-plugin#Plugin and +// github.com/hashicorp/go-plugin#GRPCPlugin interfaces, indicating how to +// serve [bundlev1.BundleProvider] as gRPC plugins for go-plugin. 
+type BundleGRPCPlugin struct { + plugin.NetRPCUnsupportedPlugin + Factory func() bundlev1.BundleProvider +} + +// Type assetion -- it must be a grpc plugin +var _ plugin.GRPCPlugin = (*BundleGRPCPlugin)(nil) + +// Type assetion -- it must be a rpc plugin (even if it just returns errors) +var _ plugin.Plugin = (*BundleGRPCPlugin)(nil) + +func (p *BundleGRPCPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error { + impl := p.Factory() + proto.RegisterDagBundleServer(s, &server{ + Impl: impl, + }) + return nil +} + +// GRPCClient implements plugin.GRPCPlugin. +func (p *BundleGRPCPlugin) GRPCClient( + ctx context.Context, + broker *plugin.GRPCBroker, + c *grpc.ClientConn, +) (interface{}, error) { + return nil, errors.New("bundlev1server only implements gRPC servers") +} + +type server struct { + sync.RWMutex + proto.UnimplementedDagBundleServer + Impl bundlev1.BundleProvider + + bundle bundlev1.Bundle +} + +func (g *server) GetMetadata( + ctx context.Context, + _ *proto.GetMetadata_Request, +) (*proto.GetMetadata_Response, error) { + ver := g.Impl.GetBundleVersion() + + info := proto.BundleInfo_builder{ + Name: &ver.Name, + Version: ver.Version, + }.Build() + resp := proto.GetMetadata_Response_builder{Bundle: info}.Build() + + return resp, nil +} + +func (g *server) getCachedBundle(_ context.Context) bundlev1.Bundle { + g.RWMutex.RLock() + defer g.RWMutex.RUnlock() + return g.bundle +} + +func (g *server) getBundle(ctx context.Context) (bundlev1.Bundle, error) { + if b := g.getCachedBundle(ctx); b != nil { + return b, nil + } + + g.RWMutex.Lock() + defer g.RWMutex.Unlock() + + reg := bundlev1.New() + err := g.Impl.RegisterDags(reg) + if err != nil { + return nil, err + } + g.bundle = reg + + return g.bundle, err +} + +func (g *server) Execute( + ctx context.Context, + req *proto.Execute_Request, +) (*proto.Execute_Response, error) { + if executeTask := req.GetTask(); executeTask != nil { + return nil, g.executeTask(ctx, executeTask) + } + + which := 
req.WhichWorkload().String() + return nil, status.Errorf(codes.Unimplemented, "Unimplmeneted workload %q", which) +} + +func (g *server) executeTask(ctx context.Context, executeTask *proto.ExecuteTaskWorkload) error { + ti := executeTask.GetTi() + bundle := executeTask.GetBundleInfo() + id, err := uuid.Parse(ti.GetId().GetValue()) + if err != nil { + return status.Errorf(codes.InvalidArgument, "unable to parse UUID: %s", err) + } + workload := api.ExecuteTaskWorkload{ + Token: executeTask.GetToken(), + TI: bundlev1.TaskInstance{ + DagId: ti.GetDagId(), + Id: id, + RunId: ti.GetRunId(), + TaskId: ti.GetTaskId(), + TryNumber: int(ti.GetTryNumber()), + // TODO: support otel context carrier + // ContextCarrier: (map[string]any)(ti.GetOtelContext()), + }, + BundleInfo: bundlev1.BundleInfo{ + Name: bundle.GetName(), + }, + } + + if ti.HasMapIndex() { + idx := int(ti.GetMapIndex()) + workload.TI.MapIndex = &idx + } + + if bundle.HasVersion() { + ver := bundle.GetVersion() + workload.BundleInfo.Version = &ver + } + + if executeTask.HasLogPath() { + path := executeTask.GetLogPath() + workload.LogPath = &path + } + + dagBundle, err := g.getBundle(ctx) + if err != nil { + return status.Errorf(codes.NotFound, "dag bundle not found: %#v", workload.BundleInfo) + } + + w := worker.NewWithBundle(dagBundle, slog.Default()) + + hclog.Default(). 
+ Warn("Does viper work", "url", viper.GetString("execution-api-url"), "env", os.Getenv("AIRFLOW__EXECUTION_API_URL")) + w, err = w.WithServer(viper.GetString("execution-api-url")) + if err != nil { + slog.ErrorContext(ctx, "Error setting ExecutionAPI sxerver for worker", "err", err) Review Comment: ```suggestion slog.ErrorContext(ctx, "Error setting ExecutionAPI server for worker", "err", err) ``` ########## go-sdk/celery/app.go: ########## @@ -80,27 +72,58 @@ func Run(ctx context.Context, config Config) error { celery.WithBroker(broker), ) + tasks := &celeryTasksRunner{d} + fmt.Printf("%#v\n", viper.AllKeys()) + for _, queue := range config.Queues { app.Register( "execute_workload", queue, func(ctx context.Context, p *celery.TaskParam) error { - p.NameArgs("payload") - payload := p.MustString("payload") - - var workload api.ExecuteTaskWorkload - if err := json.Unmarshal([]byte(payload), &workload); err != nil { - return err + err := tasks.ExecuteWorkloadTask(ctx, p) + if err != nil { + slog.ErrorContext(ctx, "Celery Task failed", "err", err) } - return worker.ExecuteTaskWorkload(ctx, workload) + return err }, ) } slog.Info("waiting for tasks", "queues", config.Queues) - err = app.Run(ctx) + return app.Run(ctx) +} + +func (state *celeryTasksRunner) ExecuteWorkloadTask( + ctx context.Context, + p *celery.TaskParam, +) error { + p.NameArgs("payload") + payload := p.MustString("payload") + + var workload bundlev1.ExecuteTaskWorkload + if err := json.Unmarshal([]byte(payload), &workload); err != nil { + return err + } + + client, err := state.ClientForBundle(workload.BundleInfo.Name, workload.BundleInfo.Version) + if err != nil { + // TODO: This Should write something to the log file + return err + } + // TODO: Don't kill the backend process here, but instead kill it after a bit of idleness. See if we can + // reuse the process for multiple tasks too + defer client.Kill() Review Comment: Maybe use context with timeout to kill it? 
``` ctx, cancel := context.WithTimeout(context.Background(), <some timeout>) defer cancel() go func() { <-ctx.Done() client.Kill() }() ``` ########## go-sdk/bundle/bundlev1/bundlev1server/doc.go: ########## @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one Review Comment: Doesn't have to be done, but could bundleserver live in `pkg/bundles`? Or is there a reason for a new top-level directory for the bundle server and client? ########## go-sdk/pkg/bundles/shared/interface.go: ########## @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Package shared contains shared data between the host and plugins. +package shared + +import ( + "github.com/hashicorp/go-plugin" +) + +// Handshake is a common handshake that is shared by plugin and host. Review Comment: ```suggestion // Handshake is a common handshake that is shared by plugin and worker. ``` ########## go-sdk/pkg/bundles/shared/interface.go: ########## @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one Review Comment: Again a nit, but could we call this `handshake.go`?
########## go-sdk/pkg/bundles/shared/discovery.go: ########## @@ -0,0 +1,252 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package shared + +import ( + "context" + "errors" + "fmt" + "log/slog" + "os" + "os/exec" + "path/filepath" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-plugin" + "github.com/spf13/viper" + + "github.com/apache/airflow/go-sdk/bundle/bundlev1" + "github.com/apache/airflow/go-sdk/pkg/bundles/bundleclientv1" + "github.com/apache/airflow/go-sdk/pkg/logging/shclog" +) + +// Discovery handles finding and loading DAG bundles +type Discovery struct { + logger *slog.Logger + hcLogger hclog.Logger + bundlesFolder string + bundles map[string]map[string]string // bundle_name -> version -> path to binary +} + +var BundleNotFound = errors.New("") + +// NewDiscovery creates a an object responsible for finding and looking for possible DAG bundle binaries +func NewDiscovery(bundlesFolder string, logger *slog.Logger) *Discovery { + if logger == nil { + logger = slog.Default() + } + + return &Discovery{ + logger: logger, + bundlesFolder: bundlesFolder, + bundles: make(map[string]map[string]string), + hcLogger: shclog.New(logger), + } +} + +func (d *Discovery) versionsForNamedBundle(name 
string) map[string]string { + versions, exists := d.bundles[name] + if exists { + return versions + } + + // We couldn't find the specific named bundle + + // First we see if a "default" was configured + versions, exists = d.bundles[viper.GetString("bundles.default_bundle")] + if exists { + return versions + } + + // Else we see if there is exactly one bundle name registered, and if so we return that + if len(d.bundles) == 1 { + exists = true + for key, versions := range d.bundles { + // Just pull the first value out + d.logger.Debug( + "Using sole bundle as fallback bundle", + "bundle", + name, + "fallback_bundle", + key, + ) + return versions + } + } + return nil +} + +func (d *Discovery) ClientForBundle(name string, version *string) (*plugin.Client, error) { + var key string + if version != nil { + key = *version + } + + versions := d.versionsForNamedBundle(name) + if versions == nil { + // We couldn't find the specific named bundle + return nil, fmt.Errorf( + "%wno dag bundle named %q found (and no fallback suitable)", + BundleNotFound, + name, + ) + } + + cmd, exists := versions[key] + if !exists { + // We couldn't find the specific version, but lets see if we have just a single version and use that in + // its place + if key == "" && len(versions) == 1 { + exists = true + for key, cmd = range versions { + // Just pull the first value out + d.logger.Info( + "Unable to find unversioned bundle as requested, using only version as fallback", + "bundle", + name, + "fallback_version", + key, + ) + break + } + } + } + + if !exists { + if key == "" { + key = "<unversioned>" + } + return nil, fmt.Errorf("%wno version %q found for dag bundle %q", BundleNotFound, key, name) + } + return d.makeClient(cmd), nil +} + +func (d *Discovery) DiscoverBundles(ctx context.Context) error { + // Find all files in the bundles directory + files, err := filepath.Glob(filepath.Join(d.bundlesFolder, "*")) + if err != nil { + return fmt.Errorf("failed to read bundles directory: %w", err) 
+ } + + self, err := os.Executable() + if err != nil { + self = "" + } + + for _, file := range files { + if ctx.Err() != nil { + // Check if we are done. + return ctx.Err() + } + + // Check if file is executable + if !isExecutable(file) { + continue + } + + abs, err := filepath.Abs(file) + if err != nil { + d.logger.Warn("Unable to load resolve file path", "file", file, "err", err) + continue + } + if self != "" && self == abs { + d.logger.Warn("Not trying to load ourselves as a plugin", "file", file) + continue + } + + d.logger.Debug("Found potential bundle", slog.String("path", file)) + + // TODO: Use a sync.WaitGroup to parallelize running multiple procs without blowing concurrency up and fork-bombing + // the host + bundle, err := d.getBundleVersionInfo(file) + if err != nil { + d.logger.Warn("Unable to load BundleMetadata", "file", file, "err", err) + continue Review Comment: This would silently skip the bundle if, let's say, the binary has the wrong permissions? ########## go-sdk/celery/app.go: ########## @@ -80,27 +72,58 @@ func Run(ctx context.Context, config Config) error { celery.WithBroker(broker), ) + tasks := &celeryTasksRunner{d} + fmt.Printf("%#v\n", viper.AllKeys()) Review Comment: Accidentally added? ########## go-sdk/pkg/worker/task.go: ########## @@ -0,0 +1,158 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package worker + +import ( + "context" + "fmt" + "log/slog" + "reflect" + "runtime" + + "github.com/apache/airflow/go-sdk/sdk" +) + +type ( + taskFunction struct { + fn reflect.Value + fullName string + } + Task interface { + Execute(ctx context.Context, logger *slog.Logger) error + } +) + +var _ Task = (*taskFunction)(nil) + +func NewTaskFunction(fn any) (Task, error) { + v := reflect.ValueOf(fn) + fullName := runtime.FuncForPC(v.Pointer()).Name() + f := &taskFunction{v, fullName} + return f, f.validateFn(v.Type()) +} + +func (f *taskFunction) Execute(ctx context.Context, logger *slog.Logger) error { + fnType := f.fn.Type() + + reflectArgs := make([]reflect.Value, fnType.NumIn()) + for i := range reflectArgs { Review Comment: Not right now, but worth considering. We do reflection argument validation during "execution" here; could we maybe do it during registration instead? (Execution is a hot path.) ########## go-sdk/bundle/bundlev1/bundlev1server/doc.go: ########## @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one Review Comment: pkg/bundles/bundlev1/server/ <= server implementation pkg/bundles/bundlev1/client/ <= client implementation Maybe -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
