This is an automated email from the ASF dual-hosted git repository.

sruehl pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/plc4x.git


The following commit(s) were added to refs/heads/develop by this push:
     new e56d9fccfe feat(plc4go/spi): make DefaultCodec burn less cpu cycles
e56d9fccfe is described below

commit e56d9fccfef9c9b481afd58559456d8b9be6bb4e
Author: Sebastian Rühl <[email protected]>
AuthorDate: Tue Nov 11 14:09:31 2025 +0100

    feat(plc4go/spi): make DefaultCodec burn less cpu cycles
---
 plc4go/spi/default/DefaultCodec.go      | 121 +++++++++++--
 plc4go/spi/default/DefaultCodec_test.go | 291 +++++++++++++++++++++++---------
 2 files changed, 321 insertions(+), 91 deletions(-)

diff --git a/plc4go/spi/default/DefaultCodec.go b/plc4go/spi/default/DefaultCodec.go
index 498695a483..9ea2df7c1d 100644
--- a/plc4go/spi/default/DefaultCodec.go
+++ b/plc4go/spi/default/DefaultCodec.go
@@ -87,6 +87,8 @@ type defaultCodec struct {
        running                 atomic.Bool
        stateChange             sync.Mutex
        activeWorker            sync.WaitGroup
+       notifyExpireWorker      chan struct{} `ignore:"true"`
+       notifyReceiveWorker     chan struct{} `ignore:"true"`
 
        receiveTimeout                 time.Duration
        traceDefaultMessageCodecWorker bool
@@ -118,6 +120,8 @@ func buildDefaultCodec(defaultCodecRequirements DefaultCodecRequirements, transp
                defaultIncomingMessageChannel:  make(chan spi.Message, 100),
                expectations:                   []spi.Expectation{},
                customMessageHandling:          customMessageHandler,
+               notifyExpireWorker:             make(chan struct{}),
+               notifyReceiveWorker:            make(chan struct{}),
                receiveTimeout:                 receiveTimeout,
                traceDefaultMessageCodecWorker: traceDefaultMessageCodecWorker 
|| config.TraceDefaultMessageCodecWorker,
                log:                            customLogger,
@@ -158,7 +162,7 @@ func (m *defaultCodec) ConnectWithContext(ctx context.Context) error {
        }
 
        m.log.Debug().Msg("Message codec currently not running, starting worker 
now")
-       m.startWorker()
+       m.startWorkers()
        m.running.Store(true)
        m.log.Trace().Msg("connected")
        return nil
@@ -172,6 +176,8 @@ func (m *defaultCodec) Disconnect() error {
        }
        m.log.Trace().Msg("Disconnecting")
        m.running.Store(false)
+       close(m.notifyExpireWorker)
+       close(m.notifyReceiveWorker)
        m.log.Trace().Msg("Waiting for worker to shutdown")
        m.activeWorker.Wait()
        m.log.Trace().Msg("worker shut down")
@@ -195,6 +201,14 @@ func (m *defaultCodec) Expect(ctx context.Context, acceptsMessage spi.AcceptsMes
        expectation := newDefaultExpectation(ctx, ttl, acceptsMessage, 
handleMessage, handleError)
        m.expectations = append(m.expectations, expectation)
        m.log.Debug().Stringer("expectation", expectation).Msg("Added 
expectation")
+       select {
+       case m.notifyExpireWorker <- struct{}{}:
+       default:
+       }
+       select {
+       case m.notifyReceiveWorker <- struct{}{}:
+       default:
+       }
 }
 
 func (m *defaultCodec) SendRequest(ctx context.Context, message spi.Message, 
acceptsMessage spi.AcceptsMessage, handleMessage spi.HandleMessage, handleError 
spi.HandleError, ttl time.Duration) error {
@@ -206,7 +220,7 @@ func (m *defaultCodec) SendRequest(ctx context.Context, message spi.Message, acc
        return m.Send(message)
 }
 
-func (m *defaultCodec) TimeoutExpectations(now time.Time) {
+func (m *defaultCodec) TimeoutExpectations(now time.Time) time.Duration {
        m.expectationsChangeMutex.Lock() // TODO: Note: would be nice if this 
is a read mutex which can be upgraded
        defer m.expectationsChangeMutex.Unlock()
        m.expectations = slices.DeleteFunc(m.expectations, func(expectation 
spi.Expectation) bool {
@@ -234,6 +248,14 @@ func (m *defaultCodec) TimeoutExpectations(now time.Time) {
                }
                return false
        })
+       nextExpire := 30 * time.Second
+       for _, expectation := range m.expectations {
+               expiresIn := time.Until(expectation.GetExpiration())
+               if expiresIn < nextExpire {
+                       nextExpire = expiresIn
+               }
+       }
+       return nextExpire
 }
 
 func (m *defaultCodec) HandleMessages(message spi.Message) bool {
@@ -271,12 +293,79 @@ func (m *defaultCodec) HandleMessages(message spi.Message) bool {
        return messageHandled
 }
 
-func (m *defaultCodec) startWorker() {
-       m.log.Trace().Msg("starting worker")
-       m.activeWorker.Go(m.Work)
+// startExpire logs and launches ExpireWork on a goroutine tracked by
+// activeWorker (the WaitGroup Disconnect waits on). ExpireWork also calls this
+// to restart itself.
+func (m *defaultCodec) startExpire() {
+       m.log.Trace().Msg("starting expire worker")
+       m.activeWorker.Go(m.ExpireWork)
+}
+
+// startReceive logs and launches ReceiveWork on a goroutine tracked by
+// activeWorker. ReceiveWork also calls this to restart itself.
+func (m *defaultCodec) startReceive() {
+       m.log.Trace().Msg("starting receive worker")
+       m.activeWorker.Go(m.ReceiveWork)
+}
+
+// startWorkers launches both codec workers: first the expire worker, then the
+// receive worker.
+func (m *defaultCodec) startWorkers() {
+       m.log.Trace().Msg("starting workers")
+       m.startExpire()
+       m.startReceive()
+}
+
+// ExpireWork is the body of the expire worker goroutine. While the codec is
+// running it calls TimeoutExpectations and then blocks until the duration that
+// call returned has elapsed, or until Expect signals notifyExpireWorker. When
+// there are no expectations and no customMessageHandling it skips
+// TimeoutExpectations and waits up to 30 seconds instead. A panic is recovered
+// and logged; if the codec is still running when this function returns, the
+// worker restarts itself via startExpire.
+//
+// NOTE(review): notifyExpireWorker is created unbuffered in buildDefaultCodec
+// and Expect sends on it non-blockingly, so a signal sent while this goroutine
+// is not parked in the select is dropped; a capacity of 1 would keep one
+// pending wake-up.
+// NOTE(review): Disconnect closes notifyExpireWorker and nothing in the visible
+// diff re-creates it. After a reconnect this worker would see a closed channel
+// (handled below so it cannot spin), and Expect would panic sending on it.
+// NOTE(review): ConnectWithContext calls startWorkers before running is set to
+// true, so this loop may observe running == false and return without doing any
+// work (the deferred restart only fires if running is true by then) — confirm
+// the intended ordering.
+func (m *defaultCodec) ExpireWork() {
+       workerLog := m.log.With().Logger()
+       if !m.traceDefaultMessageCodecWorker {
+               workerLog = zerolog.Nop()
+       }
+       workerLog.Trace().Msg("Starting expire work")
+       defer workerLog.Trace().Msg("expire work ended")
+
+       defer func() {
+               if err := recover(); err != nil {
+                       m.log.Error().
+                               Str("stack", string(debug.Stack())).
+                               Interface("err", err).
+                               Msg("panic-ed")
+               }
+               if m.running.Load() {
+                       workerLog.Warn().Msg("Keep running")
+                       m.startExpire()
+               } else {
+                       workerLog.Info().Msg("expire worker terminated")
+               }
+       }()
+
+       // Local copy of the notification channel: a receive on a closed channel
+       // never blocks, so once it is closed we swap in nil (never ready) to keep
+       // the select below from turning into a busy loop.
+       notify := m.notifyExpireWorker
+       // One timer reused for every wait, instead of allocating a new one per
+       // cycle and leaving it armed after a notification wake-up.
+       timer := time.NewTimer(time.Hour)
+       timer.Stop()
+       defer timer.Stop()
+       // waitFor blocks until d has elapsed or a notification arrives.
+       waitFor := func(d time.Duration) {
+               if !timer.Stop() {
+                       // Discard a tick left over from an earlier expiry (needed with
+                       // pre-Go 1.23 timer semantics, harmless otherwise); never blocks.
+                       select {
+                       case <-timer.C:
+                       default:
+                       }
+               }
+               timer.Reset(d)
+               select {
+               case _, ok := <-notify:
+                       if !ok {
+                               notify = nil
+                       }
+                       workerLog.Trace().Msg("waking up because of notification")
+               case <-timer.C:
+                       workerLog.Trace().Msg("waking up for next expire")
+               }
+       }
+
+       // Start an endless loop
+       for m.running.Load() {
+               workerLog.Trace().Msg("expire mainloop cycle")
+               now := time.Now()
+
+               // Guard against empty expectations
+               m.expectationsChangeMutex.RLock()
+               numberOfExpectations := len(m.expectations)
+               m.expectationsChangeMutex.RUnlock()
+               if numberOfExpectations <= 0 && m.customMessageHandling == nil {
+                       workerLog.Trace().Msg("no available expectations")
+                       waitFor(30 * time.Second)
+                       continue
+               }
+               nextExpire := m.TimeoutExpectations(now)
+               workerLog.Debug().Dur("nextExpire", nextExpire).Msg("waiting for next expire")
+               waitFor(nextExpire)
+       }
 }
 
-func (m *defaultCodec) Work() {
+func (m *defaultCodec) ReceiveWork() {
        workerLog := m.log.With().Logger()
        if !m.traceDefaultMessageCodecWorker {
                workerLog = zerolog.Nop()
@@ -293,9 +382,9 @@ func (m *defaultCodec) Work() {
                }
                if m.running.Load() {
                        workerLog.Warn().Msg("Keep running")
-                       m.startWorker()
+                       m.startReceive()
                } else {
-                       workerLog.Info().Msg("Worker terminated")
+                       workerLog.Info().Msg("receive worker terminated")
                }
        }()
 
@@ -311,11 +400,7 @@ mainLoop:
                } else {
                        workerLog.Debug().Stringer("processingTime", 
processingTime).Msg("no need to sleep") // we use stringer instead of Dur to 
have it a bit more readable
                }
-               workerLog.Trace().Msg("Working")
-               // Check for any expired expectations.
-               // (Doing this outside the loop lets us expire expectations 
even if no input is coming in)
-               now := time.Now()
-               lastLoopTime = now
+               workerLog.Trace().Msg("receive mainloop cycle")
 
                // Guard against empty expectations
                m.expectationsChangeMutex.RLock()
@@ -323,9 +408,15 @@ mainLoop:
                m.expectationsChangeMutex.RUnlock()
                if numberOfExpectations <= 0 && m.customMessageHandling == nil {
                        workerLog.Trace().Msg("no available expectations")
-                       continue mainLoop
+                       timer := time.NewTimer(30 * time.Second)
+                       select {
+                       case <-m.notifyReceiveWorker:
+                               workerLog.Trace().Msg("waking up because of 
notification")
+                       case <-timer.C:
+                               workerLog.Trace().Msg("waking up for next 
receive")
+                               continue mainLoop
+                       }
                }
-               m.TimeoutExpectations(now)
 
                workerLog.Trace().Msg("Receiving message")
                // Check for incoming messages.
diff --git a/plc4go/spi/default/DefaultCodec_test.go b/plc4go/spi/default/DefaultCodec_test.go
index 074901ab9a..4e1f071444 100644
--- a/plc4go/spi/default/DefaultCodec_test.go
+++ b/plc4go/spi/default/DefaultCodec_test.go
@@ -22,17 +22,15 @@ package _default
 import (
        "context"
        "fmt"
-       "os"
-       "sync"
        "sync/atomic"
        "testing"
        "time"
 
        "github.com/google/uuid"
        "github.com/pkg/errors"
-       "github.com/rs/zerolog/log"
        "github.com/stretchr/testify/assert"
        "github.com/stretchr/testify/mock"
+       "github.com/stretchr/testify/require"
 
        "github.com/apache/plc4x/plc4go/spi"
        "github.com/apache/plc4x/plc4go/spi/options"
@@ -246,25 +244,25 @@ func TestNewDefaultCodec(t *testing.T) {
                options           []options.WithOption
        }
        tests := []struct {
-               name string
-               args args
-               want DefaultCodec
+               name       string
+               args       args
+               wantAssert func(*testing.T, DefaultCodec) bool
        }{
                {
                        name: "create it",
-                       want: &defaultCodec{
-                               expectations:   []spi.Expectation{},
-                               receiveTimeout: 10 * time.Second,
-                               log:            log.Logger,
+                       wantAssert: func(t *testing.T, got DefaultCodec) bool {
+                               require.IsType(t, &defaultCodec{}, got)
+                               d := got.(*defaultCodec)
+                               assert.NotNil(t, 
d.defaultIncomingMessageChannel)
+                               assert.Equal(t, 10*time.Second, 
d.receiveTimeout)
+                               return true
                        },
                },
        }
        for _, tt := range tests {
                t.Run(tt.name, func(t *testing.T) {
                        got := NewDefaultCodec(tt.args.requirements, 
tt.args.transportInstance, tt.args.options...)
-                       assert.NotNil(t, 
got.(*defaultCodec).defaultIncomingMessageChannel)
-                       got.(*defaultCodec).defaultIncomingMessageChannel = nil 
// Not comparable
-                       assert.Equalf(t, tt.want, got, "NewDefaultCodec(%v, %v, 
%v)", tt.args.requirements, tt.args.transportInstance, tt.args.options)
+                       assert.Truef(t, tt.wantAssert(t, got), 
"NewDefaultCodec(%v, %v, %v)", tt.args.requirements, tt.args.transportInstance, 
tt.args.options)
                })
        }
 }
@@ -297,38 +295,45 @@ func Test_buildDefaultCodec(t *testing.T) {
                options                  []options.WithOption
        }
        tests := []struct {
-               name string
-               args args
-               want DefaultCodec
+               name       string
+               args       args
+               wantAssert func(*testing.T, DefaultCodec) bool
        }{
                {
                        name: "build it",
-                       want: &defaultCodec{
-                               expectations:   []spi.Expectation{},
-                               receiveTimeout: 10 * time.Second,
-                               log:            log.Logger,
+                       wantAssert: func(t *testing.T, got DefaultCodec) bool {
+                               require.IsType(t, &defaultCodec{}, got)
+                               d := got.(*defaultCodec)
+                               assert.NotNil(t, 
d.defaultIncomingMessageChannel)
+                               assert.Equal(t, 10*time.Second, 
d.receiveTimeout)
+                               return true
                        },
                },
                {
                        name: "build it with custom handler",
                        args: args{
                                options: []options.WithOption{
-                                       withCustomMessageHandler{},
+                                       withCustomMessageHandler{
+                                               customMessageHandler: func(_ 
DefaultCodecRequirements, _ spi.Message) bool {
+                                                       return true
+                                               },
+                                       },
                                },
                        },
-                       want: &defaultCodec{
-                               expectations:   []spi.Expectation{},
-                               receiveTimeout: 10 * time.Second,
-                               log:            log.Logger,
+                       wantAssert: func(t *testing.T, got DefaultCodec) bool {
+                               require.IsType(t, &defaultCodec{}, got)
+                               d := got.(*defaultCodec)
+                               assert.NotNil(t, 
d.defaultIncomingMessageChannel)
+                               assert.Equal(t, 10*time.Second, 
d.receiveTimeout)
+                               assert.NotNil(t, d.customMessageHandling)
+                               return true
                        },
                },
        }
        for _, tt := range tests {
                t.Run(tt.name, func(t *testing.T) {
                        got := 
buildDefaultCodec(tt.args.defaultCodecRequirements, tt.args.transportInstance, 
tt.args.options...)
-                       assert.NotNil(t, 
got.(*defaultCodec).defaultIncomingMessageChannel)
-                       got.(*defaultCodec).defaultIncomingMessageChannel = nil 
// Not comparable
-                       assert.Equalf(t, tt.want, got, "buildDefaultCodec(%v, 
%v, %v)", tt.args.defaultCodecRequirements, tt.args.transportInstance, 
tt.args.options)
+                       assert.Truef(t, tt.wantAssert(t, got), 
"buildDefaultCodec(%v, %v, %v)", tt.args.defaultCodecRequirements, 
tt.args.transportInstance, tt.args.options)
                })
        }
 }
@@ -368,6 +373,8 @@ func Test_defaultCodec_Connect(t *testing.T) {
                                transportInstance:             
tt.fields.transportInstance,
                                defaultIncomingMessageChannel: 
tt.fields.defaultIncomingMessageChannel,
                                expectations:                  
tt.fields.expectations,
+                               notifyExpireWorker:            make(chan 
struct{}, 100),
+                               notifyReceiveWorker:           make(chan 
struct{}, 100),
                                customMessageHandling:         
tt.fields.customMessageHandling,
                                log:                           
testutils.ProduceTestingLogger(t),
                        }
@@ -435,6 +442,8 @@ func Test_defaultCodec_ConnectWithContext(t *testing.T) {
                                transportInstance:             
tt.fields.transportInstance,
                                defaultIncomingMessageChannel: 
tt.fields.defaultIncomingMessageChannel,
                                expectations:                  
tt.fields.expectations,
+                               notifyExpireWorker:            make(chan 
struct{}, 100),
+                               notifyReceiveWorker:           make(chan 
struct{}, 100),
                                customMessageHandling:         
tt.fields.customMessageHandling,
                                log:                           
testutils.ProduceTestingLogger(t),
                        }
@@ -486,6 +495,8 @@ func Test_defaultCodec_Disconnect(t *testing.T) {
                                transportInstance:             
tt.fields.transportInstance,
                                defaultIncomingMessageChannel: 
tt.fields.defaultIncomingMessageChannel,
                                expectations:                  
tt.fields.expectations,
+                               notifyExpireWorker:            make(chan 
struct{}, 100),
+                               notifyReceiveWorker:           make(chan 
struct{}, 100),
                                customMessageHandling:         
tt.fields.customMessageHandling,
                                log:                           
testutils.ProduceTestingLogger(t),
                        }
@@ -536,6 +547,8 @@ func Test_defaultCodec_Expect(t *testing.T) {
                                transportInstance:             
tt.fields.transportInstance,
                                defaultIncomingMessageChannel: 
tt.fields.defaultIncomingMessageChannel,
                                expectations:                  
tt.fields.expectations,
+                               notifyExpireWorker:            make(chan 
struct{}, 100),
+                               notifyReceiveWorker:           make(chan 
struct{}, 100),
                                customMessageHandling:         
tt.fields.customMessageHandling,
                                log:                           
testutils.ProduceTestingLogger(t),
                        }
@@ -569,6 +582,8 @@ func Test_defaultCodec_GetDefaultIncomingMessageChannel(t *testing.T) {
                                transportInstance:             
tt.fields.transportInstance,
                                defaultIncomingMessageChannel: 
tt.fields.defaultIncomingMessageChannel,
                                expectations:                  
tt.fields.expectations,
+                               notifyExpireWorker:            make(chan 
struct{}, 100),
+                               notifyReceiveWorker:           make(chan 
struct{}, 100),
                                customMessageHandling:         
tt.fields.customMessageHandling,
                                log:                           
testutils.ProduceTestingLogger(t),
                        }
@@ -602,6 +617,8 @@ func Test_defaultCodec_GetTransportInstance(t *testing.T) {
                                transportInstance:             
tt.fields.transportInstance,
                                defaultIncomingMessageChannel: 
tt.fields.defaultIncomingMessageChannel,
                                expectations:                  
tt.fields.expectations,
+                               notifyExpireWorker:            make(chan 
struct{}, 100),
+                               notifyReceiveWorker:           make(chan 
struct{}, 100),
                                customMessageHandling:         
tt.fields.customMessageHandling,
                                log:                           
testutils.ProduceTestingLogger(t),
                        }
@@ -829,6 +846,8 @@ func Test_defaultCodec_HandleMessages(t *testing.T) {
                                transportInstance:             
tt.fields.transportInstance,
                                defaultIncomingMessageChannel: 
tt.fields.defaultIncomingMessageChannel,
                                expectations:                  
tt.fields.expectations,
+                               notifyExpireWorker:            make(chan 
struct{}, 100),
+                               notifyReceiveWorker:           make(chan 
struct{}, 100),
                                customMessageHandling:         
tt.fields.customMessageHandling,
                                log:                           
testutils.ProduceTestingLogger(t),
                        }
@@ -862,6 +881,8 @@ func Test_defaultCodec_IsRunning(t *testing.T) {
                                transportInstance:             
tt.fields.transportInstance,
                                defaultIncomingMessageChannel: 
tt.fields.defaultIncomingMessageChannel,
                                expectations:                  
tt.fields.expectations,
+                               notifyExpireWorker:            make(chan 
struct{}, 100),
+                               notifyReceiveWorker:           make(chan 
struct{}, 100),
                                customMessageHandling:         
tt.fields.customMessageHandling,
                                log:                           
testutils.ProduceTestingLogger(t),
                        }
@@ -938,6 +959,8 @@ func Test_defaultCodec_SendRequest(t *testing.T) {
                                transportInstance:             
tt.fields.transportInstance,
                                defaultIncomingMessageChannel: 
tt.fields.defaultIncomingMessageChannel,
                                expectations:                  
tt.fields.expectations,
+                               notifyExpireWorker:            make(chan 
struct{}, 100),
+                               notifyReceiveWorker:           make(chan 
struct{}, 100),
                                customMessageHandling:         
tt.fields.customMessageHandling,
                                log:                           
testutils.ProduceTestingLogger(t),
                        }
@@ -963,28 +986,30 @@ func Test_defaultCodec_TimeoutExpectations(t *testing.T) {
                fields fields
                args   args
                setup  func(t *testing.T, fields *fields, args *args)
+               want   time.Duration
        }{
                {
                        name: "timeout it (no expectations)",
+                       want: 30 * time.Second,
                },
                {
                        name: "timeout some",
                        fields: fields{
                                expectations: []spi.Expectation{
                                        &defaultExpectation{ // Expired
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) 
error {
                                                        return nil
                                                },
                                        },
                                        &defaultExpectation{ // Expired errors
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) 
error {
                                                        return errors.New("yep")
                                                },
                                        },
                                        &defaultExpectation{ // Fine
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) 
error {
                                                        return errors.New("yep")
                                                },
@@ -992,7 +1017,7 @@ func Test_defaultCodec_TimeoutExpectations(t *testing.T) {
                                        },
                                        &defaultExpectation{ // Context error
                                                Context: func() context.Context 
{
-                                                       ctx, cancelFunc := 
context.WithCancel(context.Background())
+                                                       ctx, cancelFunc := 
context.WithCancel(t.Context())
                                                        cancelFunc() // Cancel 
it instantly
                                                        return ctx
                                                }(),
@@ -1004,6 +1029,7 @@ func Test_defaultCodec_TimeoutExpectations(t *testing.T) {
                                },
                        },
                        args: args{now: time.Time{}.Add(2 * time.Hour)},
+                       want: time.Until(time.Time{}.Add(3 * time.Hour)),
                },
                {
                        name: "timeout some (ensure everyone is called)",
@@ -1024,21 +1050,21 @@ func Test_defaultCodec_TimeoutExpectations(t *testing.T) {
                                })
                                fields.expectations = []spi.Expectation{
                                        &defaultExpectation{ // Expired
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) 
error {
                                                        handle1.Store(true)
                                                        return nil
                                                },
                                        },
                                        &defaultExpectation{ // Expired errors
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) 
error {
                                                        handle2.Store(true)
                                                        return errors.New("yep")
                                                },
                                        },
                                        &defaultExpectation{ // Fine
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) 
error {
                                                        handle3.Store(true)
                                                        return errors.New("yep")
@@ -1047,7 +1073,7 @@ func Test_defaultCodec_TimeoutExpectations(t *testing.T) {
                                        },
                                        &defaultExpectation{ // Context error
                                                Context: func() context.Context 
{
-                                                       ctx, cancelFunc := 
context.WithCancel(context.Background())
+                                                       ctx, cancelFunc := 
context.WithCancel(t.Context())
                                                        cancelFunc() // Cancel 
it instantly
                                                        return ctx
                                                }(),
@@ -1058,7 +1084,7 @@ func Test_defaultCodec_TimeoutExpectations(t *testing.T) {
                                                Expiration: time.Time{}.Add(3 * 
time.Hour),
                                        },
                                        &defaultExpectation{ // Fine
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) 
error {
                                                        handle5.Store(true)
                                                        return errors.New("yep")
@@ -1067,6 +1093,7 @@ func Test_defaultCodec_TimeoutExpectations(t *testing.T) {
                                        },
                                }
                        },
+                       want: time.Until(time.Time{}.Add(3 * time.Hour)),
                },
        }
        for _, tt := range tests {
@@ -1079,20 +1106,18 @@ func Test_defaultCodec_TimeoutExpectations(t *testing.T) {
                                transportInstance:             
tt.fields.transportInstance,
                                defaultIncomingMessageChannel: 
tt.fields.defaultIncomingMessageChannel,
                                expectations:                  
tt.fields.expectations,
+                               notifyExpireWorker:            make(chan 
struct{}, 100),
+                               notifyReceiveWorker:           make(chan 
struct{}, 100),
                                customMessageHandling:         
tt.fields.customMessageHandling,
                                log:                           
testutils.ProduceTestingLogger(t),
                        }
-                       m.TimeoutExpectations(tt.args.now)
-                       // TODO: handle error is called async so we sleep here 
a bit. Not sure if we want to sync something here at all
-                       time.Sleep(100 * time.Millisecond)
+                       assert.Equalf(t, tt.want, 
m.TimeoutExpectations(tt.args.now), "TimeoutExpectations(%v)", tt.args.now)
+                       m.wg.Wait()
                })
        }
 }
 
-func Test_defaultCodec_Work(t *testing.T) {
-       if os.Getenv("ENABLE_RANDOMLY_FAILING_TESTS") == "" {
-               t.Skip("Skipping randomly failing tests")
-       }
+func Test_defaultCodec_ReceiveWork(t *testing.T) {
        type fields struct {
                DefaultCodecRequirements      DefaultCodecRequirements
                transportInstance             transports.TransportInstance
@@ -1123,19 +1148,19 @@ func Test_defaultCodec_Work(t *testing.T) {
                        fields: fields{
                                expectations: []spi.Expectation{
                                        &defaultExpectation{ // Expired
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) 
error {
                                                        return nil
                                                },
                                        },
                                        &defaultExpectation{ // Expired errors
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) 
error {
                                                        return errors.New("yep")
                                                },
                                        },
                                        &defaultExpectation{ // Fine
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) 
error {
                                                        return errors.New("yep")
                                                },
@@ -1143,7 +1168,7 @@ func Test_defaultCodec_Work(t *testing.T) {
                                        },
                                        &defaultExpectation{ // Context error
                                                Context: func() context.Context {
-                                                       ctx, cancelFunc := context.WithCancel(context.Background())
+                                                       ctx, cancelFunc := context.WithCancel(t.Context())
                                                        cancelFunc() // Cancel it instantly
                                                        return ctx
                                                }(),
@@ -1169,19 +1194,19 @@ func Test_defaultCodec_Work(t *testing.T) {
                        fields: fields{
                                expectations: []spi.Expectation{
                                        &defaultExpectation{ // Expired
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) error {
                                                        return nil
                                                },
                                        },
                                        &defaultExpectation{ // Expired errors
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) error {
                                                        return errors.New("yep")
                                                },
                                        },
                                        &defaultExpectation{ // Fine
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) error {
                                                        return errors.New("yep")
                                                },
@@ -1189,7 +1214,7 @@ func Test_defaultCodec_Work(t *testing.T) {
                                        },
                                        &defaultExpectation{ // Context error
                                                Context: func() context.Context {
-                                                       ctx, cancelFunc := context.WithCancel(context.Background())
+                                                       ctx, cancelFunc := context.WithCancel(t.Context())
                                                        cancelFunc() // Cancel it instantly
                                                        return ctx
                                                }(),
@@ -1215,19 +1240,19 @@ func Test_defaultCodec_Work(t *testing.T) {
                        fields: fields{
                                expectations: []spi.Expectation{
                                        &defaultExpectation{ // Expired
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) error {
                                                        return nil
                                                },
                                        },
                                        &defaultExpectation{ // Expired errors
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) error {
                                                        return errors.New("yep")
                                                },
                                        },
                                        &defaultExpectation{ // Fine
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) error {
                                                        return errors.New("yep")
                                                },
@@ -1235,7 +1260,7 @@ func Test_defaultCodec_Work(t *testing.T) {
                                        },
                                        &defaultExpectation{ // Context error
                                                Context: func() context.Context {
-                                                       ctx, cancelFunc := context.WithCancel(context.Background())
+                                                       ctx, cancelFunc := context.WithCancel(t.Context())
                                                        cancelFunc() // Cancel it instantly
                                                        return ctx
                                                }(),
@@ -1262,7 +1287,7 @@ func Test_defaultCodec_Work(t *testing.T) {
                                defaultIncomingMessageChannel: make(chan spi.Message, 1),
                                expectations: []spi.Expectation{
                                        &defaultExpectation{ // Fine
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) error {
                                                        return errors.New("yep")
                                                },
@@ -1285,19 +1310,19 @@ func Test_defaultCodec_Work(t *testing.T) {
                        fields: fields{
                                expectations: []spi.Expectation{
                                        &defaultExpectation{ // Expired
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) error {
                                                        return nil
                                                },
                                        },
                                        &defaultExpectation{ // Expired errors
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) error {
                                                        return errors.New("yep")
                                                },
                                        },
                                        &defaultExpectation{ // Fine
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) error {
                                                        return errors.New("yep")
                                                },
@@ -1305,7 +1330,7 @@ func Test_defaultCodec_Work(t *testing.T) {
                                        },
                                        &defaultExpectation{ // Context error
                                                Context: func() context.Context {
-                                                       ctx, cancelFunc := context.WithCancel(context.Background())
+                                                       ctx, cancelFunc := context.WithCancel(t.Context())
                                                        cancelFunc() // Cancel it instantly
                                                        return ctx
                                                }(),
@@ -1334,19 +1359,19 @@ func Test_defaultCodec_Work(t *testing.T) {
                                },
                                expectations: []spi.Expectation{
                                        &defaultExpectation{ // Expired
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) error {
                                                        return nil
                                                },
                                        },
                                        &defaultExpectation{ // Expired errors
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) error {
                                                        return errors.New("yep")
                                                },
                                        },
                                        &defaultExpectation{ // Fine
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) error {
                                                        return errors.New("yep")
                                                },
@@ -1354,7 +1379,7 @@ func Test_defaultCodec_Work(t *testing.T) {
                                        },
                                        &defaultExpectation{ // Context error
                                                Context: func() context.Context {
-                                                       ctx, cancelFunc := context.WithCancel(context.Background())
+                                                       ctx, cancelFunc := context.WithCancel(t.Context())
                                                        cancelFunc() // Cancel it instantly
                                                        return ctx
                                                }(),
@@ -1383,19 +1408,19 @@ func Test_defaultCodec_Work(t *testing.T) {
                                },
                                expectations: []spi.Expectation{
                                        &defaultExpectation{ // Expired
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) error {
                                                        return nil
                                                },
                                        },
                                        &defaultExpectation{ // Expired errors
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) error {
                                                        return errors.New("yep")
                                                },
                                        },
                                        &defaultExpectation{ // Fine
-                                               Context: context.Background(),
+                                               Context: t.Context(),
                                                HandleError: func(err error) error {
                                                        return errors.New("yep")
                                                },
@@ -1403,7 +1428,7 @@ func Test_defaultCodec_Work(t *testing.T) {
                                        },
                                        &defaultExpectation{ // Context error
                                                Context: func() context.Context {
-                                                       ctx, cancelFunc := context.WithCancel(context.Background())
+                                                       ctx, cancelFunc := context.WithCancel(t.Context())
                                                        cancelFunc() // Cancel it instantly
                                                        return ctx
                                                }(),
@@ -1435,25 +1460,30 @@ func Test_defaultCodec_Work(t *testing.T) {
                                transportInstance:             tt.fields.transportInstance,
                                defaultIncomingMessageChannel: tt.fields.defaultIncomingMessageChannel,
                                expectations:                  tt.fields.expectations,
+                               notifyExpireWorker:            make(chan struct{}, 100),
+                               notifyReceiveWorker:           make(chan struct{}, 100),
                                customMessageHandling:         tt.fields.customMessageHandling,
                                log:                           testutils.ProduceTestingLogger(t),
                        }
                        if tt.manipulator != nil {
                                tt.manipulator(t, m)
                        }
-                       var wg sync.WaitGroup
-                       t.Cleanup(wg.Wait)
-                       wg.Go(func() {
+                       t.Cleanup(m.wg.Wait)
+                       m.wg.Go(func() {
                                // Stop after 200ms
-                               time.Sleep(200 * time.Millisecond)
+                               timer := time.NewTimer(200 * time.Millisecond)
+                               select {
+                               case <-timer.C:
+                               case <-t.Context().Done():
+                               }
                                m.running.Store(false)
                        })
-                       m.Work()
+                       m.ReceiveWork()
                })
        }
 }
 
-func Test_defaultCodec_startWorker(t *testing.T) {
+func Test_defaultCodec_startWorkers(t *testing.T) {
        type fields struct {
                DefaultCodecRequirements       DefaultCodecRequirements
                transportInstance              transports.TransportInstance
@@ -1483,7 +1513,116 @@ func Test_defaultCodec_startWorker(t *testing.T) {
                                traceDefaultMessageCodecWorker: tt.fields.traceDefaultMessageCodecWorker,
                                log:                            testutils.ProduceTestingLogger(t),
                        }
-                       m.startWorker()
+                       m.startWorkers()
                })
        }
 }
+
+func Test_defaultCodec_integration(t *testing.T) {
+       mockDefaultCodecRequirements := NewMockDefaultCodecRequirements(t)
+       {
+               expect := mockDefaultCodecRequirements.EXPECT()
+               message := NewMockMessage(t)
+               {
+                       expect := message.EXPECT()
+                       expect.String().Return("message for " + t.Name())
+               }
+               expect.Receive().RunAndReturn(func() (spi.Message, error) {
+                       // Simulate a bit read delay
+                       timer := time.NewTimer(100 * time.Millisecond)
+                       select {
+                       case <-timer.C:
+                       case <-t.Context().Done():
+                       }
+                       if err := t.Context().Err(); err != nil {
+                               return nil, err
+                       }
+                       return message, nil
+               })
+       }
+       mockTransportInstance := NewMockTransportInstance(t)
+       {
+               expect := mockTransportInstance.EXPECT()
+               expect.IsConnected().Return(true)
+               expect.Close().Return(nil)
+       }
+       sut := NewDefaultCodec(mockDefaultCodecRequirements, mockTransportInstance,
+               options.WithCustomLogger(testutils.ProduceTestingLogger(t)),
+               options.WithTraceDefaultMessageCodecWorker(true),
+       )
+       t.Cleanup(func() {
+               _ = sut.Disconnect()
+       })
+       // First expect
+       var firstHandled bool
+       sut.Expect(t.Context(),
+               func(message spi.Message) bool {
+                       t.Log("accepts message", message)
+                       return true
+               }, func(message spi.Message) error {
+                       t.Log("handle message", message)
+                       firstHandled = true
+                       return nil
+               }, func(err error) error {
+                       t.Log("error", err)
+                       return nil
+               },
+               500*time.Millisecond)
+       // Second expect
+       var secondHandled bool
+       sut.Expect(t.Context(),
+               func(message spi.Message) bool {
+                       t.Log("accepts message", message)
+                       return true
+               }, func(message spi.Message) error {
+                       t.Log("handle message", message)
+                       secondHandled = true
+                       return nil
+               }, func(err error) error {
+                       t.Log("error", err)
+                       return nil
+               },
+               1500*time.Millisecond)
+       // Third expect
+       var thridErrorCalled bool
+       sut.Expect(t.Context(),
+               func(message spi.Message) bool {
+                       t.Log("does not accept message", message)
+                       return false
+               }, func(message spi.Message) error {
+                       t.Error("should not be called")
+                       return nil
+               }, func(err error) error {
+                       thridErrorCalled = true
+                       return nil
+               },
+               1500*time.Millisecond)
+       // Fourth expect
+       var fourthHandled bool
+       sut.Expect(t.Context(),
+               func(message spi.Message) bool {
+                       t.Log("accepts message", message)
+                       return true
+               }, func(message spi.Message) error {
+                       t.Log("handle message", message)
+                       fourthHandled = true
+                       return nil
+               }, func(err error) error {
+                       t.Log("error", err)
+                       return nil
+               },
+               3000*time.Millisecond)
+
+       err := sut.ConnectWithContext(t.Context())
+       assert.NoError(t, err)
+       timer := time.NewTimer(10 * time.Second)
+       select {
+       case <-timer.C:
+       case <-t.Context().Done():
+       }
+       assert.NoError(t, sut.Disconnect())
+       assert.True(t, firstHandled)
+       assert.True(t, secondHandled)
+       assert.True(t, thridErrorCalled) // because of our disconnect
+       assert.True(t, fourthHandled)
+}

Reply via email to