This is an automated email from the ASF dual-hosted git repository.

liuhan pushed a commit to branch reduce-handle-connect-time
in repository https://gitbox.apache.org/repos/asf/skywalking-rover.git


The following commit(s) were added to refs/heads/reduce-handle-connect-time by this push:
     new a536ed4  update conntrack
a536ed4 is described below

commit a536ed49ed07410d9d8c142ddea994a4e4254d6d
Author: mrproliu <[email protected]>
AuthorDate: Sat Dec 28 15:21:51 2024 +0800

    update conntrack
---
 pkg/accesslog/collector/connection.go | 160 ++++++++++++++++++++---------
 pkg/tools/ip/conntrack.go             | 188 +++++++++++++++++++++++++---------
 2 files changed, 255 insertions(+), 93 deletions(-)

diff --git a/pkg/accesslog/collector/connection.go 
b/pkg/accesslog/collector/connection.go
index e3b52e9..f56ed05 100644
--- a/pkg/accesslog/collector/connection.go
+++ b/pkg/accesslog/collector/connection.go
@@ -18,11 +18,14 @@
 package collector
 
 import (
+       "container/list"
        "context"
        "encoding/binary"
        "fmt"
        "net"
        "os"
+       "sync"
+       "time"
 
        "github.com/docker/go-units"
 
@@ -50,6 +53,8 @@ var connectionLogger = logger.GetLogger("access_log", 
"collector", "connection")
 
 var connectionCollectInstance = NewConnectionCollector()
 
+var connectionAnalyzeRetryTime = time.Second * 2
+
 type ConnectCollector struct {
        eventQueue *btf.EventQueue
 }
@@ -81,7 +86,7 @@ func (c *ConnectCollector) Start(m *module.Manager, ctx 
*common.AccessLogContext
        }
        c.eventQueue = btf.NewEventQueue("connection resolver", 
ctx.Config.ConnectionAnalyze.AnalyzeParallels,
                ctx.Config.ConnectionAnalyze.QueueSize, func(num int) 
btf.PartitionContext {
-                       return newConnectionPartitionContext(ctx, track, 
m.FindModule(process.ModuleName).(process.K8sOperator))
+                       return NewConnectionPartitionContext(ctx, track, 
m.FindModule(process.ModuleName).(process.K8sOperator))
                })
        c.eventQueue.RegisterReceiver(ctx.BPF.SocketConnectionEventQueue, 
int(perCPUBufferSize),
                ctx.Config.ConnectionAnalyze.ParseParallels, func() interface{} 
{
@@ -131,22 +136,69 @@ func (c *ConnectCollector) Stop() {
 }
 
 type ConnectionPartitionContext struct {
-       context     *common.AccessLogContext
-       connTracker *ip.ConnTrack
-       k8sOperator process.K8sOperator
+       context         *common.AccessLogContext
+       connTracker     *ip.ConnTrack
+       k8sOperator     process.K8sOperator
+       retryableEvents *list.List
+       retryableMutex  sync.Mutex
 }
 
-func newConnectionPartitionContext(ctx *common.AccessLogContext, connTracker 
*ip.ConnTrack,
+func NewConnectionPartitionContext(ctx *common.AccessLogContext, connTracker 
*ip.ConnTrack,
        k8sOperator process.K8sOperator) *ConnectionPartitionContext {
        return &ConnectionPartitionContext{
-               context:     ctx,
-               connTracker: connTracker,
-               k8sOperator: k8sOperator,
+               context:         ctx,
+               connTracker:     connTracker,
+               k8sOperator:     k8sOperator,
+               retryableEvents: list.New(),
        }
 }
 
 func (c *ConnectionPartitionContext) Start(ctx context.Context) {
+       go func() {
+               ticker := time.NewTicker(connectionAnalyzeRetryTime)
+               for {
+                       select {
+                       case <-ctx.Done():
+                               return
+                       case <-ticker.C:
+                               c.analyzeFailureEvent()
+                       }
+               }
+       }()
+}
+
+func (c *ConnectionPartitionContext) analyzeFailureEvent() {
+       for {
+               event := c.PopEventToRetry()
+               if event == nil {
+                       break
+               }
+               socketPair, err := c.BuildSocketFromConnectEvent(event)
+               if err != nil {
+                       connectionLogger.Debugf("retry to analyze the connect 
event failure, connection ID: %d, randomID: %d, error: %v",
+                               event.ConID, event.RandomID, err)
+                       continue
+               }
+               c.OnConnectionSocketFinished(event, socketPair)
+       }
+}
 
+func (c *ConnectionPartitionContext) AddEventToRetry(e 
*events.SocketConnectEvent) {
+       c.retryableMutex.Lock()
+       defer c.retryableMutex.Unlock()
+       c.retryableEvents.PushBack(e)
+}
+
+func (c *ConnectionPartitionContext) PopEventToRetry() 
*events.SocketConnectEvent {
+       c.retryableMutex.Lock()
+       defer c.retryableMutex.Unlock()
+       if c.retryableEvents.Len() == 0 {
+               return nil
+       }
+       element := c.retryableEvents.Front()
+       event := element.Value.(*events.SocketConnectEvent)
+       c.retryableEvents.Remove(element)
+       return event
 }
 
 func (c *ConnectionPartitionContext) Consume(data interface{}) {
@@ -156,16 +208,12 @@ func (c *ConnectionPartitionContext) Consume(data 
interface{}) {
                        "pid: %d, fd: %d, role: %s: func: %s, family: %d, 
success: %d, conntrack exist: %t",
                        event.ConID, event.RandomID, event.PID, event.SocketFD, 
enums.ConnectionRole(event.Role), enums.SocketFunctionName(event.FuncName),
                        event.SocketFamily, event.ConnectSuccess, 
event.ConnTrackUpstreamPort != 0)
-               socketPair := c.buildSocketFromConnectEvent(event)
-               if socketPair == nil {
-                       connectionLogger.Debugf("cannot found the socket paire 
from connect event, connection ID: %d, randomID: %d",
-                               event.ConID, event.RandomID)
+               socketPair, err := c.BuildSocketFromConnectEvent(event)
+               if err != nil {
+                       c.AddEventToRetry(event)
                        return
                }
-               connectionLogger.Debugf("build socket pair success, connection 
ID: %d, randomID: %d, role: %s, local: %s:%d, remote: %s:%d",
-                       event.ConID, event.RandomID, socketPair.Role, 
socketPair.SrcIP, socketPair.SrcPort, socketPair.DestIP, socketPair.DestPort)
-               c.context.ConnectionMgr.OnConnectEvent(event, socketPair)
-               forwarder.SendConnectEvent(c.context, event, socketPair)
+               c.OnConnectionSocketFinished(event, socketPair)
        case *events.SocketCloseEvent:
                connectionLogger.Debugf("receive close event, connection ID: 
%d, randomID: %d, pid: %d, fd: %d",
                        event.ConnectionID, event.RandomID, event.PID, 
event.SocketFD)
@@ -174,7 +222,7 @@ func (c *ConnectionPartitionContext) Consume(data 
interface{}) {
        }
 }
 
-func (c *ConnectionPartitionContext) fixSocketFamilyIfNeed(event 
*events.SocketConnectEvent, result *ip.SocketPair) {
+func (c *ConnectionPartitionContext) FixSocketFamilyIfNeed(event 
*events.SocketConnectEvent, result *ip.SocketPair) {
        if result == nil {
                return
        }
@@ -194,39 +242,55 @@ func (c *ConnectionPartitionContext) 
fixSocketFamilyIfNeed(event *events.SocketC
        }
 }
 
-func (c *ConnectionPartitionContext) buildSocketFromConnectEvent(event 
*events.SocketConnectEvent) *ip.SocketPair {
+func (c *ConnectionPartitionContext) OnConnectionSocketFinished(event 
*events.SocketConnectEvent, socketPair *ip.SocketPair) {
+       if socketPair == nil {
+               connectionLogger.Debugf("cannot found the socket paire from 
connect event, connection ID: %d, randomID: %d",
+                       event.ConID, event.RandomID)
+               return
+       }
+       connectionLogger.Debugf("build socket pair success, connection ID: %d, 
randomID: %d, role: %s, local: %s:%d, remote: %s:%d",
+               event.ConID, event.RandomID, socketPair.Role, socketPair.SrcIP, 
socketPair.SrcPort, socketPair.DestIP, socketPair.DestPort)
+       c.context.ConnectionMgr.OnConnectEvent(event, socketPair)
+       forwarder.SendConnectEvent(c.context, event, socketPair)
+}
+
+func (c *ConnectionPartitionContext) BuildSocketFromConnectEvent(event 
*events.SocketConnectEvent) (*ip.SocketPair, error) {
        if event.SocketFamily != unix.AF_INET && event.SocketFamily != 
unix.AF_INET6 && event.SocketFamily != enums.SocketFamilyUnknown {
                // if not ipv4, ipv6 or unknown, ignore
-               return nil
+               return nil, nil
        }
-       socketPair := c.buildSocketPair(event)
-       if socketPair != nil && socketPair.IsValid() {
+       pair, err := c.BuildSocketPair(event)
+       if err != nil {
+               return nil, err
+       }
+       if pair != nil && pair.IsValid() {
                connectionLogger.Debugf("found the connection from the connect 
event is valid, connection ID: %d, randomID: %d",
                        event.ConID, event.RandomID)
-               return socketPair
+               return pair, nil
        }
        // if only the local port not success, maybe the upstream port is not 
open, so it could be continued
-       if c.isOnlyLocalPortEmpty(socketPair) {
+       if c.IsOnlyLocalPortEmpty(pair) {
                event.ConnectSuccess = 0
                connectionLogger.Debugf("the connection from the connect event 
is only the local port is empty, connection ID: %d, randomID: %d",
                        event.ConID, event.RandomID)
-               return socketPair
+               return pair, nil
        }
 
-       pair, err := ip.ParseSocket(event.PID, event.SocketFD)
+       pair, err = ip.ParseSocket(event.PID, event.SocketFD)
        if err != nil {
-               connectionLogger.Debugf("cannot found the socket, pid: %d, 
socket FD: %d", event.PID, event.SocketFD)
-               return nil
+               log.Debugf("cannot found the socket, pid: %d, socket FD: %d, 
error: %v", event.PID, event.SocketFD, err)
+               // because if the socket is not found, the connection is not 
valid, should not return the error for retry
+               return nil, nil
        }
        connectionLogger.Debugf("found the connection from the socket, 
connection ID: %d, randomID: %d",
                event.ConID, event.RandomID)
        pair.Role = enums.ConnectionRole(event.Role)
-       c.fixSocketFamilyIfNeed(event, pair)
-       c.tryToUpdateSocketFromConntrack(event, pair)
-       return pair
+       c.FixSocketFamilyIfNeed(event, pair)
+       c.TryToUpdateSocketFromConntrack(event, pair)
+       return pair, nil
 }
 
-func (c *ConnectionPartitionContext) isOnlyLocalPortEmpty(socketPair 
*ip.SocketPair) bool {
+func (c *ConnectionPartitionContext) IsOnlyLocalPortEmpty(socketPair 
*ip.SocketPair) bool {
        if socketPair == nil {
                return false
        }
@@ -238,7 +302,7 @@ func (c *ConnectionPartitionContext) 
isOnlyLocalPortEmpty(socketPair *ip.SocketP
        return socketPair.IsValid()
 }
 
-func (c *ConnectionPartitionContext) buildSocketPair(event 
*events.SocketConnectEvent) *ip.SocketPair {
+func (c *ConnectionPartitionContext) BuildSocketPair(event 
*events.SocketConnectEvent) (*ip.SocketPair, error) {
        var result *ip.SocketPair
        haveConnTrack := false
        if event.SocketFamily == unix.AF_INET {
@@ -290,18 +354,19 @@ func (c *ConnectionPartitionContext) 
buildSocketPair(event *events.SocketConnect
        }
 
        if haveConnTrack {
-               return result
+               return result, nil
        }
 
-       c.fixSocketFamilyIfNeed(event, result)
-       c.tryToUpdateSocketFromConntrack(event, result)
-       return result
+       c.FixSocketFamilyIfNeed(event, result)
+       // support retry to update the socket from conntrack
+       err := c.TryToUpdateSocketFromConntrack(event, result)
+       return result, err
 }
 
-func (c *ConnectionPartitionContext) tryToUpdateSocketFromConntrack(event 
*events.SocketConnectEvent, socket *ip.SocketPair) {
-       if socket == nil || !socket.IsValid() || 
tools.IsLocalHostAddress(socket.DestIP) &&
+func (c *ConnectionPartitionContext) TryToUpdateSocketFromConntrack(event 
*events.SocketConnectEvent, socket *ip.SocketPair) error {
+       if socket == nil || !socket.IsValid() || 
tools.IsLocalHostAddress(socket.DestIP) ||
                event.FuncName == enums.SocketFunctionNameAccept { // accept 
event don't need to update the remote address
-               return
+               return nil
        }
        if c.context.ConnectionMgr.ProcessIsDetectBy(event.PID, api.Kubernetes) 
{
                isPodIP, err := c.k8sOperator.IsPodIP(socket.DestIP)
@@ -312,20 +377,23 @@ func (c *ConnectionPartitionContext) 
tryToUpdateSocketFromConntrack(event *event
                if isPodIP {
                        connectionLogger.Debugf("detect the remote IP is pod 
IP, connection ID: %d, randomID: %d, remote: %s",
                                event.ConID, event.RandomID, socket.DestIP)
-                       return
+                       return nil
                }
        }
-       connectionLogger.Infof("try to update the remote address from 
conntrack, connection ID: %d, randomID: %d, func: %s, remote: %s:%d",
-               event.ConID, event.RandomID, 
enums.SocketFunctionName(event.FuncName), socket.DestIP, socket.DestPort)
+       connectionLogger.Infof("try to update the remote address from 
conntrack, connection ID: %d, randomID: %d, func: %s, local: %s:%d, remote: 
%s:%d",
+               event.ConID, event.RandomID, 
enums.SocketFunctionName(event.FuncName), socket.SrcIP, socket.SrcPort, 
socket.DestIP, socket.DestPort)
        if c.connTracker != nil {
                // if no contract and socket data is valid, then trying to get 
the remote address from the socket
                // to encase the remote address is not the real remote address
                originalIP := socket.DestIP
                originalPort := socket.DestPort
-               if c.connTracker.UpdateRealPeerAddress(socket) {
-                       connectionLogger.Debugf("update the socket address from 
conntrack success, "+
-                               "connection ID: %d, randomID: %d, original 
remote: %s:%d, new remote: %s:%d",
-                               event.ConID, event.RandomID, originalIP, 
originalPort, socket.DestIP, socket.DestPort)
+               err := c.connTracker.UpdateRealPeerAddress(socket)
+               if err != nil {
+                       return fmt.Errorf("update the socket address from 
conntrack failure, error: %v", err)
                }
+               connectionLogger.Infof("update the socket address from 
conntrack success, "+
+                       "connection ID: %d, randomID: %d, original remote: 
%s:%d, new remote: %s:%d",
+                       event.ConID, event.RandomID, originalIP, originalPort, 
socket.DestIP, socket.DestPort)
        }
+       return nil
 }
diff --git a/pkg/tools/ip/conntrack.go b/pkg/tools/ip/conntrack.go
index 20cc29f..978569d 100644
--- a/pkg/tools/ip/conntrack.go
+++ b/pkg/tools/ip/conntrack.go
@@ -18,83 +18,169 @@
 package ip
 
 import (
-       "github.com/mdlayher/netlink"
-       "github.com/ti-mo/netfilter"
+       "context"
+       "fmt"
+
        "net/netip"
        "syscall"
+       "time"
 
-       "github.com/apache/skywalking-rover/pkg/logger"
+       "github.com/mdlayher/netlink"
 
        "github.com/ti-mo/conntrack"
+       "github.com/ti-mo/netfilter"
+
+       "k8s.io/apimachinery/pkg/util/cache"
+
+       "github.com/apache/skywalking-rover/pkg/logger"
 )
 
-var log = logger.GetLogger("tools", "ip")
+var (
+       log = logger.GetLogger("tools", "ip")
+
+       // monitorExpireTime is the time to expire the conntrack session
+       monitorExpireTime = time.Second * 20
+)
 
 type ConnTrack struct {
-       client *conntrack.Conn
+       queryClient   *conntrack.Conn
+       monitorClient *conntrack.Conn
+       monitorExpire *cache.Expiring
+       eventChain    chan conntrack.Event
+
+       ctx    context.Context
+       cancel context.CancelFunc
 }
 
 func NewConnTrack() (*ConnTrack, error) {
-       client, err := conntrack.Dial(nil)
+       queryClient, err := conntrack.Dial(nil)
        if err != nil {
                return nil, err
        }
 
-       go func() {
-               evCh := make(chan conntrack.Event, 2048)
+       return &ConnTrack{
+               queryClient:   queryClient,
+               eventChain:    make(chan conntrack.Event, 2048),
+               monitorExpire: cache.NewExpiring(),
+       }, nil
+}
+
+func (c *ConnTrack) StartMonitoring(ctx context.Context) error {
+       c.ctx, c.cancel = context.WithCancel(ctx)
+       errors := make(chan error, 2)
+       errChain, err := c.monitor0(ctx)
+       if err != nil {
+               return err
+       }
 
-               errCh, err := client.Listen(evCh, 4, []netfilter.NetlinkGroup{
-                       netfilter.GroupCTNew, // watching for new conntrack 
events
-               })
-               if err != nil {
-                       log.Error(err)
+       go func() {
+               e := <-errChain
+               errors <- e
+
+               for {
+                       select {
+                       case <-ctx.Done():
+                               return
+                       case monitoringError := <-errors:
+                               log.Warnf("monitoring conntrack failure, will 
re-try monitoring after 5 second. error: %v", monitoringError)
+                               time.Sleep(time.Second * 5)
+
+                               errChain, monitoringError = c.monitor0(ctx)
+                               if monitoringError != nil {
+                                       errors <- monitoringError
+                                       continue
+                               }
+                               mError := <-errChain
+                               errors <- mError
+                       }
                }
+       }()
+       return nil
+}
 
-               client.SetReadBuffer(26214400) // 25MB
-               // Listen to Conntrack events from all network namespaces on 
the system.
-               err = client.SetOption(netlink.ListenAllNSID, true)
-               if err != nil {
-                       log.Error(err)
+func (c *ConnTrack) monitor0(ctx context.Context) (chan error, error) {
+       isFirstMonitoring := true
+       if c.monitorClient != nil {
+               isFirstMonitoring = false
+               if e := c.monitorClient.Close(); e != nil {
+                       log.Warnf("close the conntack monitor client error: 
%v", e)
                }
+       }
 
-               // Start a goroutine to print all incoming messages on the 
event channel.
-               go func() {
-                       for {
-                               e := <-evCh
+       cl, err := conntrack.Dial(nil)
+       if err != nil {
+               return nil, err
+       }
+       c.monitorClient = cl
+       errCh, err := c.monitorClient.Listen(c.eventChain, 4, 
[]netfilter.NetlinkGroup{
+               netfilter.GroupCTNew, // watching for new conntrack events
+       })
+       if err != nil {
+               return nil, err
+       }
+
+       // Set the read buffer to 25MB
+       // encase: no buffer space available
+       c.monitorClient.SetReadBuffer(26214400)
+       // Listen to Conntrack events from all network namespaces on the system.
+       err = c.monitorClient.SetOption(netlink.ListenAllNSID, true)
+       if err != nil {
+               return nil, err
+       }
+
+       // is not the first monitoring, then return(no need re monitoring the 
channel)
+       if !isFirstMonitoring {
+               return errCh, nil
+       }
+       go func() {
+               for {
+                       select {
+                       case <-ctx.Done():
+                               return
+                       case e := <-c.eventChain:
                                if e.Flow.TupleOrig.Proto.DestinationPort == 53 
{
                                        continue
                                }
-                               log.Infof("conntrack: type: %s, origin: 
%s:%d->%s:%d, reply: %s:%d->%s:%d", e.Type,
-                                       e.Flow.TupleOrig.IP.SourceAddress, 
e.Flow.TupleOrig.Proto.SourcePort,
-                                       e.Flow.TupleOrig.IP.DestinationAddress, 
e.Flow.TupleOrig.Proto.DestinationPort,
-                                       e.Flow.TupleReply.IP.SourceAddress, 
e.Flow.TupleReply.Proto.SourcePort,
-                                       
e.Flow.TupleReply.IP.DestinationAddress, 
e.Flow.TupleReply.Proto.DestinationPort)
+                               c.monitorExpire.Set(conntrackExpireKey{
+                                       sourceIP:   
e.Flow.TupleOrig.IP.SourceAddress.String(),
+                                       destIP:     
e.Flow.TupleOrig.IP.DestinationAddress.String(),
+                                       sourcePort: 
e.Flow.TupleOrig.Proto.SourcePort,
+                                       destPort:   
e.Flow.TupleOrig.Proto.DestinationPort,
+                               }, conntrackExpireValue{
+                                       realIP:   
e.Flow.TupleReply.IP.SourceAddress.String(),
+                                       realPort: 
e.Flow.TupleReply.Proto.SourcePort,
+                               }, monitorExpireTime)
                        }
-               }()
-
-               // Stop the program as soon as an error is caught in a decoder 
goroutine.
-               if err := <-errCh; err != nil {
-                       log.Errorf("conntrack error: %v", err)
                }
        }()
 
-       query, err := conntrack.Dial(nil)
-       if err != nil {
-               return nil, err
-       }
-       return &ConnTrack{client: query}, nil
+       return errCh, nil
 }
 
-func (c *ConnTrack) UpdateRealPeerAddress(addr *SocketPair) bool {
+func (c *ConnTrack) UpdateRealPeerAddress(addr *SocketPair) error {
+       key := conntrackExpireKey{
+               sourceIP:   addr.SrcIP,
+               destIP:     addr.DestIP,
+               sourcePort: addr.SrcPort,
+               destPort:   addr.DestPort,
+       }
+       val, exist := c.monitorExpire.Get(key)
+       if exist {
+               v := val.(conntrackExpireValue)
+               addr.DestIP = v.realIP
+               addr.DestPort = v.realPort
+               log.Debugf("update real peer address: %s:%d", addr.DestIP, 
addr.DestPort)
+               c.monitorExpire.Delete(key)
+               return nil
+       }
+
        srcIP, err := netip.ParseAddr(addr.SrcIP)
        if err != nil {
-               log.Errorf("cannot parse the address: %s, error: %v", 
addr.SrcIP, err)
-               return false
+               return fmt.Errorf("parsing src IP address failure: %s, error: 
%v", addr.SrcIP, err)
        }
        destIP, err := netip.ParseAddr(addr.DestIP)
        if err != nil {
-               log.Errorf("cannot parse the address: %s, error: %v", 
addr.DestIP, err)
-               return false
+               return fmt.Errorf("parsing dest IP address failure: %s, error: 
%v", addr.DestIP, err)
        }
        flow := conntrack.Flow{
                TupleOrig: conntrack.Tuple{
@@ -110,14 +196,22 @@ func (c *ConnTrack) UpdateRealPeerAddress(addr 
*SocketPair) bool {
                },
        }
 
-       get, err := c.client.Get(flow)
+       get, err := c.queryClient.Get(flow)
        if err != nil {
-               log.Errorf("cannot get the conntrack session, error: %v", err)
-               return false
+               return err
        }
 
        addr.DestIP = get.TupleReply.IP.SourceAddress.String()
        addr.DestPort = get.TupleReply.Proto.SourcePort
-       log.Infof("update real peer address: %s:%d", addr.DestIP, addr.DestPort)
-       return true
+       return nil
+}
+
+type conntrackExpireKey struct {
+       sourceIP, destIP     string
+       sourcePort, destPort uint16
+}
+
+type conntrackExpireValue struct {
+       realIP   string
+       realPort uint16
 }

Reply via email to