amoghrajesh commented on code in PR #56079:
URL: https://github.com/apache/airflow/pull/56079#discussion_r2378646161


##########
go-sdk/bundle/bundlev1/bundlev1server/plugin.go:
##########
@@ -0,0 +1,185 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package bundlev1server
+
+import (
+       "context"
+       "errors"
+       "log/slog"
+       "os"
+       "sync"
+
+       "github.com/google/uuid"
+       "github.com/hashicorp/go-hclog"
+       "github.com/hashicorp/go-plugin"
+       "github.com/spf13/viper"
+       "google.golang.org/grpc"
+       "google.golang.org/grpc/codes"
+       "google.golang.org/grpc/status"
+
+       "github.com/apache/airflow/go-sdk/bundle/bundlev1"
+       proto "github.com/apache/airflow/go-sdk/internal/protov1"
+       "github.com/apache/airflow/go-sdk/pkg/api"
+       "github.com/apache/airflow/go-sdk/pkg/worker"
+)
+
+// BundleGRPCPlugin is an implementation of the 
github.com/hashicorp/go-plugin#Plugin and
+// github.com/hashicorp/go-plugin#GRPCPlugin interfaces, indicating how to
+// serve [bundlev1.BundleProvider] as gRPC plugins for go-plugin.
+type BundleGRPCPlugin struct {
+       plugin.NetRPCUnsupportedPlugin
+       Factory func() bundlev1.BundleProvider
+}
+
+// Type assetion -- it must be a grpc plugin
+var _ plugin.GRPCPlugin = (*BundleGRPCPlugin)(nil)
+
+// Type assetion -- it must be a rpc plugin (even if it just returns errors)

Review Comment:
   ```suggestion
   // Type assertion -- it must be a rpc plugin (even if it just returns errors)
   ```



##########
go-sdk/bundle/bundlev1/bundlev1server/plugin.go:
##########
@@ -0,0 +1,185 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package bundlev1server
+
+import (
+       "context"
+       "errors"
+       "log/slog"
+       "os"
+       "sync"
+
+       "github.com/google/uuid"
+       "github.com/hashicorp/go-hclog"
+       "github.com/hashicorp/go-plugin"
+       "github.com/spf13/viper"
+       "google.golang.org/grpc"
+       "google.golang.org/grpc/codes"
+       "google.golang.org/grpc/status"
+
+       "github.com/apache/airflow/go-sdk/bundle/bundlev1"
+       proto "github.com/apache/airflow/go-sdk/internal/protov1"
+       "github.com/apache/airflow/go-sdk/pkg/api"
+       "github.com/apache/airflow/go-sdk/pkg/worker"
+)
+
+// BundleGRPCPlugin is an implementation of the 
github.com/hashicorp/go-plugin#Plugin and
+// github.com/hashicorp/go-plugin#GRPCPlugin interfaces, indicating how to
+// serve [bundlev1.BundleProvider] as gRPC plugins for go-plugin.
+type BundleGRPCPlugin struct {
+       plugin.NetRPCUnsupportedPlugin
+       Factory func() bundlev1.BundleProvider
+}
+
+// Type assetion -- it must be a grpc plugin

Review Comment:
   ```suggestion
   // Type assertion -- it must be a grpc plugin
   ```



##########
go-sdk/pkg/bundles/shared/interface.go:
##########
@@ -0,0 +1,31 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Package shared contains shared data between the host and plugins.

Review Comment:
   ```suggestion
   // Package shared contains shared data between the worker and plugins.
   ```



##########
go-sdk/bundle/bundlev1/bundlev1server/plugin.go:
##########
@@ -0,0 +1,185 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package bundlev1server
+
+import (
+       "context"
+       "errors"
+       "log/slog"
+       "os"
+       "sync"
+
+       "github.com/google/uuid"
+       "github.com/hashicorp/go-hclog"
+       "github.com/hashicorp/go-plugin"
+       "github.com/spf13/viper"
+       "google.golang.org/grpc"
+       "google.golang.org/grpc/codes"
+       "google.golang.org/grpc/status"
+
+       "github.com/apache/airflow/go-sdk/bundle/bundlev1"
+       proto "github.com/apache/airflow/go-sdk/internal/protov1"
+       "github.com/apache/airflow/go-sdk/pkg/api"
+       "github.com/apache/airflow/go-sdk/pkg/worker"
+)
+
+// BundleGRPCPlugin is an implementation of the 
github.com/hashicorp/go-plugin#Plugin and
+// github.com/hashicorp/go-plugin#GRPCPlugin interfaces, indicating how to
+// serve [bundlev1.BundleProvider] as gRPC plugins for go-plugin.
+type BundleGRPCPlugin struct {
+       plugin.NetRPCUnsupportedPlugin
+       Factory func() bundlev1.BundleProvider
+}
+
+// Type assetion -- it must be a grpc plugin
+var _ plugin.GRPCPlugin = (*BundleGRPCPlugin)(nil)
+
+// Type assetion -- it must be a rpc plugin (even if it just returns errors)
+var _ plugin.Plugin = (*BundleGRPCPlugin)(nil)
+
+func (p *BundleGRPCPlugin) GRPCServer(broker *plugin.GRPCBroker, s 
*grpc.Server) error {
+       impl := p.Factory()
+       proto.RegisterDagBundleServer(s, &server{
+               Impl: impl,
+       })
+       return nil
+}
+
+// GRPCClient implements plugin.GRPCPlugin.
+func (p *BundleGRPCPlugin) GRPCClient(
+       ctx context.Context,
+       broker *plugin.GRPCBroker,
+       c *grpc.ClientConn,
+) (interface{}, error) {
+       return nil, errors.New("bundlev1server only implements gRPC servers")
+}
+
+type server struct {
+       sync.RWMutex
+       proto.UnimplementedDagBundleServer
+       Impl bundlev1.BundleProvider
+
+       bundle bundlev1.Bundle
+}
+
+func (g *server) GetMetadata(
+       ctx context.Context,
+       _ *proto.GetMetadata_Request,
+) (*proto.GetMetadata_Response, error) {
+       ver := g.Impl.GetBundleVersion()
+
+       info := proto.BundleInfo_builder{
+               Name:    &ver.Name,
+               Version: ver.Version,
+       }.Build()
+       resp := proto.GetMetadata_Response_builder{Bundle: info}.Build()
+
+       return resp, nil
+}
+
+func (g *server) getCachedBundle(_ context.Context) bundlev1.Bundle {
+       g.RWMutex.RLock()
+       defer g.RWMutex.RUnlock()
+       return g.bundle
+}
+
+func (g *server) getBundle(ctx context.Context) (bundlev1.Bundle, error) {
+       if b := g.getCachedBundle(ctx); b != nil {
+               return b, nil
+       }
+
+       g.RWMutex.Lock()
+       defer g.RWMutex.Unlock()
+
+       reg := bundlev1.New()
+       err := g.Impl.RegisterDags(reg)
+       if err != nil {
+               return nil, err
+       }
+       g.bundle = reg
+
+       return g.bundle, err
+}
+
+func (g *server) Execute(
+       ctx context.Context,
+       req *proto.Execute_Request,
+) (*proto.Execute_Response, error) {
+       if executeTask := req.GetTask(); executeTask != nil {
+               return nil, g.executeTask(ctx, executeTask)
+       }
+
+       which := req.WhichWorkload().String()
+       return nil, status.Errorf(codes.Unimplemented, "Unimplmeneted workload 
%q", which)
+}
+
+func (g *server) executeTask(ctx context.Context, executeTask 
*proto.ExecuteTaskWorkload) error {
+       ti := executeTask.GetTi()
+       bundle := executeTask.GetBundleInfo()
+       id, err := uuid.Parse(ti.GetId().GetValue())
+       if err != nil {
+               return status.Errorf(codes.InvalidArgument, "unable to parse 
UUID: %s", err)
+       }
+       workload := api.ExecuteTaskWorkload{
+               Token: executeTask.GetToken(),
+               TI: bundlev1.TaskInstance{
+                       DagId:     ti.GetDagId(),
+                       Id:        id,
+                       RunId:     ti.GetRunId(),
+                       TaskId:    ti.GetTaskId(),
+                       TryNumber: int(ti.GetTryNumber()),
+                       // TODO: support otel context carrier
+                       // ContextCarrier: 
(map[string]any)(ti.GetOtelContext()),
+               },
+               BundleInfo: bundlev1.BundleInfo{
+                       Name: bundle.GetName(),
+               },
+       }
+
+       if ti.HasMapIndex() {
+               idx := int(ti.GetMapIndex())
+               workload.TI.MapIndex = &idx
+       }
+
+       if bundle.HasVersion() {
+               ver := bundle.GetVersion()
+               workload.BundleInfo.Version = &ver
+       }
+
+       if executeTask.HasLogPath() {
+               path := executeTask.GetLogPath()
+               workload.LogPath = &path
+       }
+
+       dagBundle, err := g.getBundle(ctx)
+       if err != nil {
+               return status.Errorf(codes.NotFound, "dag bundle not found: 
%#v", workload.BundleInfo)
+       }
+
+       w := worker.NewWithBundle(dagBundle, slog.Default())
+
+       hclog.Default().
+               Warn("Does viper work", "url", 
viper.GetString("execution-api-url"), "env", 
os.Getenv("AIRFLOW__EXECUTION_API_URL"))
+       w, err = w.WithServer(viper.GetString("execution-api-url"))
+       if err != nil {
+               slog.ErrorContext(ctx, "Error setting ExecutionAPI sxerver for 
worker", "err", err)

Review Comment:
   ```suggestion
                slog.ErrorContext(ctx, "Error setting ExecutionAPI server for 
worker", "err", err)
   ```



##########
go-sdk/celery/app.go:
##########
@@ -80,27 +72,58 @@ func Run(ctx context.Context, config Config) error {
                celery.WithBroker(broker),
        )
 
+       tasks := &celeryTasksRunner{d}
+       fmt.Printf("%#v\n", viper.AllKeys())
+
        for _, queue := range config.Queues {
                app.Register(
                        "execute_workload",
                        queue,
                        func(ctx context.Context, p *celery.TaskParam) error {
-                               p.NameArgs("payload")
-                               payload := p.MustString("payload")
-
-                               var workload api.ExecuteTaskWorkload
-                               if err := json.Unmarshal([]byte(payload), 
&workload); err != nil {
-                                       return err
+                               err := tasks.ExecuteWorkloadTask(ctx, p)
+                               if err != nil {
+                                       slog.ErrorContext(ctx, "Celery Task 
failed", "err", err)
                                }
-                               return worker.ExecuteTaskWorkload(ctx, workload)
+                               return err
                        },
                )
        }
 
        slog.Info("waiting for tasks", "queues", config.Queues)
-       err = app.Run(ctx)
+       return app.Run(ctx)
+}
+
+func (state *celeryTasksRunner) ExecuteWorkloadTask(
+       ctx context.Context,
+       p *celery.TaskParam,
+) error {
+       p.NameArgs("payload")
+       payload := p.MustString("payload")
+
+       var workload bundlev1.ExecuteTaskWorkload
+       if err := json.Unmarshal([]byte(payload), &workload); err != nil {
+               return err
+       }
+
+       client, err := state.ClientForBundle(workload.BundleInfo.Name, 
workload.BundleInfo.Version)
+       if err != nil {
+               // TODO: This Should write something to the log file
+               return err
+       }
+       // TODO: Don't kill the backend process here, but instead kill it after 
a bit of idleness. See if we can
+       // reuse the process for multiple tasks too
+       defer client.Kill()

Review Comment:
   Maybe use context with timeout to kill it?
   
   
   ```
   ctx, cancel := context.WithTimeout(context.Background(), <some timeout>)
   defer cancel()
   go func() {
       <-ctx.Done()
       client.Kill()
   }()
   ```



##########
go-sdk/bundle/bundlev1/bundlev1server/doc.go:
##########
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one

Review Comment:
   Doesn't have to be done, but could bundleserver live in `pkg/bundles`? Or is 
there a reason for a new top-level directory for the server and client for bundles?



##########
go-sdk/pkg/bundles/shared/interface.go:
##########
@@ -0,0 +1,31 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Package shared contains shared data between the host and plugins.
+package shared
+
+import (
+       "github.com/hashicorp/go-plugin"
+)
+
+// Handshake is a common handshake that is shared by plugin and host.

Review Comment:
   ```suggestion
   // Handshake is a common handshake that is shared by plugin and worker.
   ```



##########
go-sdk/pkg/bundles/shared/interface.go:
##########
@@ -0,0 +1,31 @@
+// Licensed to the Apache Software Foundation (ASF) under one

Review Comment:
   Again a nit, but could we call this `handshake.go`?



##########
go-sdk/pkg/bundles/shared/discovery.go:
##########
@@ -0,0 +1,252 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package shared
+
+import (
+       "context"
+       "errors"
+       "fmt"
+       "log/slog"
+       "os"
+       "os/exec"
+       "path/filepath"
+
+       "github.com/hashicorp/go-hclog"
+       "github.com/hashicorp/go-plugin"
+       "github.com/spf13/viper"
+
+       "github.com/apache/airflow/go-sdk/bundle/bundlev1"
+       "github.com/apache/airflow/go-sdk/pkg/bundles/bundleclientv1"
+       "github.com/apache/airflow/go-sdk/pkg/logging/shclog"
+)
+
+// Discovery handles finding and loading DAG bundles
+type Discovery struct {
+       logger        *slog.Logger
+       hcLogger      hclog.Logger
+       bundlesFolder string
+       bundles       map[string]map[string]string // bundle_name -> version -> 
path to binary
+}
+
+var BundleNotFound = errors.New("")
+
+// NewDiscovery creates a an object responsible for finding and looking for 
possible DAG bundle binaries
+func NewDiscovery(bundlesFolder string, logger *slog.Logger) *Discovery {
+       if logger == nil {
+               logger = slog.Default()
+       }
+
+       return &Discovery{
+               logger:        logger,
+               bundlesFolder: bundlesFolder,
+               bundles:       make(map[string]map[string]string),
+               hcLogger:      shclog.New(logger),
+       }
+}
+
+func (d *Discovery) versionsForNamedBundle(name string) map[string]string {
+       versions, exists := d.bundles[name]
+       if exists {
+               return versions
+       }
+
+       // We couldn't find the specific named bundle
+
+       // First we see if a "default" was configured
+       versions, exists = d.bundles[viper.GetString("bundles.default_bundle")]
+       if exists {
+               return versions
+       }
+
+       // Else we see if there is exactly one bundle name registered, and if 
so we return that
+       if len(d.bundles) == 1 {
+               exists = true
+               for key, versions := range d.bundles {
+                       // Just pull the first value out
+                       d.logger.Debug(
+                               "Using sole bundle as fallback bundle",
+                               "bundle",
+                               name,
+                               "fallback_bundle",
+                               key,
+                       )
+                       return versions
+               }
+       }
+       return nil
+}
+
+func (d *Discovery) ClientForBundle(name string, version *string) 
(*plugin.Client, error) {
+       var key string
+       if version != nil {
+               key = *version
+       }
+
+       versions := d.versionsForNamedBundle(name)
+       if versions == nil {
+               // We couldn't find the specific named bundle
+               return nil, fmt.Errorf(
+                       "%wno dag bundle named %q found (and no fallback 
suitable)",
+                       BundleNotFound,
+                       name,
+               )
+       }
+
+       cmd, exists := versions[key]
+       if !exists {
+               // We couldn't find the specific version, but lets see if we 
have just a single version and use that in
+               // its place
+               if key == "" && len(versions) == 1 {
+                       exists = true
+                       for key, cmd = range versions {
+                               // Just pull the first value out
+                               d.logger.Info(
+                                       "Unable to find unversioned bundle as 
requested, using only version as fallback",
+                                       "bundle",
+                                       name,
+                                       "fallback_version",
+                                       key,
+                               )
+                               break
+                       }
+               }
+       }
+
+       if !exists {
+               if key == "" {
+                       key = "<unversioned>"
+               }
+               return nil, fmt.Errorf("%wno version %q found for dag bundle 
%q", BundleNotFound, key, name)
+       }
+       return d.makeClient(cmd), nil
+}
+
+func (d *Discovery) DiscoverBundles(ctx context.Context) error {
+       // Find all files in the bundles directory
+       files, err := filepath.Glob(filepath.Join(d.bundlesFolder, "*"))
+       if err != nil {
+               return fmt.Errorf("failed to read bundles directory: %w", err)
+       }
+
+       self, err := os.Executable()
+       if err != nil {
+               self = ""
+       }
+
+       for _, file := range files {
+               if ctx.Err() != nil {
+                       // Check if we are done.
+                       return ctx.Err()
+               }
+
+               // Check if file is executable
+               if !isExecutable(file) {
+                       continue
+               }
+
+               abs, err := filepath.Abs(file)
+               if err != nil {
+                       d.logger.Warn("Unable to load resolve file path", 
"file", file, "err", err)
+                       continue
+               }
+               if self != "" && self == abs {
+                       d.logger.Warn("Not trying to load ourselves as a 
plugin", "file", file)
+                       continue
+               }
+
+               d.logger.Debug("Found potential bundle", slog.String("path", 
file))
+
+               // TODO: Use a sync.WaitGroup to parallelize running multiple 
procs without blowing concurrency up and fork-bombing
+               // the host
+               bundle, err := d.getBundleVersionInfo(file)
+               if err != nil {
+                       d.logger.Warn("Unable to load BundleMetadata", "file", 
file, "err", err)
+                       continue

Review Comment:
   This would silently skip the bundle if, let's say, the binary has the wrong 
permissions?



##########
go-sdk/celery/app.go:
##########
@@ -80,27 +72,58 @@ func Run(ctx context.Context, config Config) error {
                celery.WithBroker(broker),
        )
 
+       tasks := &celeryTasksRunner{d}
+       fmt.Printf("%#v\n", viper.AllKeys())

Review Comment:
   Accidentally added?



##########
go-sdk/pkg/worker/task.go:
##########
@@ -0,0 +1,158 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package worker
+
+import (
+       "context"
+       "fmt"
+       "log/slog"
+       "reflect"
+       "runtime"
+
+       "github.com/apache/airflow/go-sdk/sdk"
+)
+
+type (
+       taskFunction struct {
+               fn       reflect.Value
+               fullName string
+       }
+       Task interface {
+               Execute(ctx context.Context, logger *slog.Logger) error
+       }
+)
+
+var _ Task = (*taskFunction)(nil)
+
+func NewTaskFunction(fn any) (Task, error) {
+       v := reflect.ValueOf(fn)
+       fullName := runtime.FuncForPC(v.Pointer()).Name()
+       f := &taskFunction{v, fullName}
+       return f, f.validateFn(v.Type())
+}
+
+func (f *taskFunction) Execute(ctx context.Context, logger *slog.Logger) error 
{
+       fnType := f.fn.Type()
+
+       reflectArgs := make([]reflect.Value, fnType.NumIn())
+       for i := range reflectArgs {

Review Comment:
   Not right now, but worth considering: we do reflection args validation 
during "execution" here. Could we maybe do it during registration instead? 
(Execution is a hot path.)



##########
go-sdk/bundle/bundlev1/bundlev1server/doc.go:
##########
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one

Review Comment:
   pkg/bundles/bundlev1/server/     <= server implementation  
   pkg/bundles/bundlev1/client/     <= client implementation
   
   
   Maybe



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to