This is an automated email from the ASF dual-hosted git repository.
sruehl pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/plc4x.git
The following commit(s) were added to refs/heads/develop by this push:
new e56d9fccfe feat(plc4go/spi): make DefaultCodec burn less cpu cycles
e56d9fccfe is described below
commit e56d9fccfef9c9b481afd58559456d8b9be6bb4e
Author: Sebastian Rühl <[email protected]>
AuthorDate: Tue Nov 11 14:09:31 2025 +0100
feat(plc4go/spi): make DefaultCodec burn less cpu cycles
---
plc4go/spi/default/DefaultCodec.go | 121 +++++++++++--
plc4go/spi/default/DefaultCodec_test.go | 291 +++++++++++++++++++++++---------
2 files changed, 321 insertions(+), 91 deletions(-)
diff --git a/plc4go/spi/default/DefaultCodec.go
b/plc4go/spi/default/DefaultCodec.go
index 498695a483..9ea2df7c1d 100644
--- a/plc4go/spi/default/DefaultCodec.go
+++ b/plc4go/spi/default/DefaultCodec.go
@@ -87,6 +87,8 @@ type defaultCodec struct {
running atomic.Bool
stateChange sync.Mutex
activeWorker sync.WaitGroup
+ notifyExpireWorker chan struct{} `ignore:"true"`
+ notifyReceiveWorker chan struct{} `ignore:"true"`
receiveTimeout time.Duration
traceDefaultMessageCodecWorker bool
@@ -118,6 +120,8 @@ func buildDefaultCodec(defaultCodecRequirements
DefaultCodecRequirements, transp
defaultIncomingMessageChannel: make(chan spi.Message, 100),
expectations: []spi.Expectation{},
customMessageHandling: customMessageHandler,
+ notifyExpireWorker: make(chan struct{}, 1), // one slot: Expect notifies with a non-blocking send, which an unbuffered channel drops whenever the worker is not parked in its select
+ notifyReceiveWorker: make(chan struct{}, 1), // one slot, same reason as notifyExpireWorker; extra notifications coalesce into the pending one
receiveTimeout: receiveTimeout,
traceDefaultMessageCodecWorker: traceDefaultMessageCodecWorker
|| config.TraceDefaultMessageCodecWorker,
log: customLogger,
@@ -158,7 +162,7 @@ func (m *defaultCodec) ConnectWithContext(ctx
context.Context) error {
}
m.log.Debug().Msg("Message codec currently not running, starting worker
now")
- m.startWorker()
+ m.startWorkers()
m.running.Store(true)
m.log.Trace().Msg("connected")
return nil
@@ -172,6 +176,8 @@ func (m *defaultCodec) Disconnect() error {
}
m.log.Trace().Msg("Disconnecting")
m.running.Store(false)
+ close(m.notifyExpireWorker) // NOTE(review): these channels are never recreated, so after this close a later Expect would panic (a send on a closed channel panics even inside select/default) and a reconnect would make the workers' selects return immediately (busy loop); consider a separate stop channel created per Connect instead
+ close(m.notifyReceiveWorker)
m.log.Trace().Msg("Waiting for worker to shutdown")
m.activeWorker.Wait()
m.log.Trace().Msg("worker shut down")
@@ -195,6 +201,14 @@ func (m *defaultCodec) Expect(ctx context.Context,
acceptsMessage spi.AcceptsMes
expectation := newDefaultExpectation(ctx, ttl, acceptsMessage,
handleMessage, handleError)
m.expectations = append(m.expectations, expectation)
m.log.Debug().Stringer("expectation", expectation).Msg("Added
expectation")
+ select {
+ case m.notifyExpireWorker <- struct{}{}:
+ default:
+ }
+ select {
+ case m.notifyReceiveWorker <- struct{}{}:
+ default:
+ }
}
func (m *defaultCodec) SendRequest(ctx context.Context, message spi.Message,
acceptsMessage spi.AcceptsMessage, handleMessage spi.HandleMessage, handleError
spi.HandleError, ttl time.Duration) error {
@@ -206,7 +220,7 @@ func (m *defaultCodec) SendRequest(ctx context.Context,
message spi.Message, acc
return m.Send(message)
}
-func (m *defaultCodec) TimeoutExpectations(now time.Time) {
+func (m *defaultCodec) TimeoutExpectations(now time.Time) time.Duration {
m.expectationsChangeMutex.Lock() // TODO: Note: would be nice if this
is a read mutex which can be upgraded
defer m.expectationsChangeMutex.Unlock()
m.expectations = slices.DeleteFunc(m.expectations, func(expectation
spi.Expectation) bool {
@@ -234,6 +248,14 @@ func (m *defaultCodec) TimeoutExpectations(now time.Time) {
}
return false
})
+ nextExpire := 30 * time.Second
+ for _, expectation := range m.expectations {
+ expiresIn := time.Until(expectation.GetExpiration())
+ if expiresIn < nextExpire {
+ nextExpire = expiresIn
+ }
+ }
+ return nextExpire
}
func (m *defaultCodec) HandleMessages(message spi.Message) bool {
@@ -271,12 +293,79 @@ func (m *defaultCodec) HandleMessages(message
spi.Message) bool {
return messageHandled
}
-func (m *defaultCodec) startWorker() {
- m.log.Trace().Msg("starting worker")
- m.activeWorker.Go(m.Work)
+// startWorkers launches both background workers of the codec: the expire
+// worker (ExpireWork) followed by the receive worker (ReceiveWork).
+func (m *defaultCodec) startWorkers() {
+	m.log.Trace().Msg("starting workers")
+	for _, start := range []func(){m.startExpire, m.startReceive} {
+		start()
+	}
+}
+
+// startExpire runs ExpireWork in a new goroutine tracked by activeWorker.
+func (m *defaultCodec) startExpire() {
+	m.log.Trace().Msg("starting expire worker")
+	work := m.ExpireWork
+	m.activeWorker.Go(work)
+}
+
+// startReceive runs ReceiveWork in a new goroutine tracked by activeWorker.
+func (m *defaultCodec) startReceive() {
+	m.log.Trace().Msg("starting receive worker")
+	work := m.ReceiveWork
+	m.activeWorker.Go(work)
+}
+
+// ExpireWork is the body of the expire worker. While the codec is running it
+// times out due expectations via TimeoutExpectations and then sleeps until the
+// returned next-expiry duration (30s when nothing is registered), waking early
+// whenever notifyExpireWorker delivers a signal. Whenever it exits (normally or
+// by panic) while the codec is still marked running, it restarts itself via
+// startExpire.
+func (m *defaultCodec) ExpireWork() {
+	var workerLog zerolog.Logger
+	if m.traceDefaultMessageCodecWorker {
+		workerLog = m.log.With().Logger()
+	} else {
+		workerLog = zerolog.Nop()
+	}
+	workerLog.Trace().Msg("Starting expire work")
+	defer workerLog.Trace().Msg("expire work ended")
+
+	defer func() {
+		if r := recover(); r != nil {
+			m.log.Error().
+				Str("stack", string(debug.Stack())).
+				Interface("err", r).
+				Msg("panic-ed")
+		}
+		if !m.running.Load() {
+			workerLog.Info().Msg("expire worker terminated")
+			return
+		}
+		workerLog.Warn().Msg("Keep running")
+		m.startExpire()
+	}()
+
+	for m.running.Load() {
+		workerLog.Trace().Msg("expire mainloop cycle")
+		now := time.Now()
+
+		// Snapshot the number of registered expectations under the read lock.
+		m.expectationsChangeMutex.RLock()
+		pending := len(m.expectations)
+		m.expectationsChangeMutex.RUnlock()
+
+		// Pick the sleep: a fixed 30s idle period when there is nothing to
+		// expire, otherwise expire what is due and sleep until the next one.
+		var wait time.Duration
+		if pending <= 0 && m.customMessageHandling == nil {
+			workerLog.Trace().Msg("no available expectations")
+			wait = 30 * time.Second
+		} else {
+			wait = m.TimeoutExpectations(now)
+			workerLog.Debug().Dur("nextExpire", wait).Msg("waiting for next expire")
+		}
+
+		timer := time.NewTimer(wait)
+		select {
+		case <-m.notifyExpireWorker:
+			workerLog.Trace().Msg("waking up because of notification")
+		case <-timer.C:
+			workerLog.Trace().Msg("waking up for next expire")
+		}
+	}
}
-func (m *defaultCodec) Work() {
+func (m *defaultCodec) ReceiveWork() {
workerLog := m.log.With().Logger()
if !m.traceDefaultMessageCodecWorker {
workerLog = zerolog.Nop()
@@ -293,9 +382,9 @@ func (m *defaultCodec) Work() {
}
if m.running.Load() {
workerLog.Warn().Msg("Keep running")
- m.startWorker()
+ m.startReceive()
} else {
- workerLog.Info().Msg("Worker terminated")
+ workerLog.Info().Msg("receive worker terminated")
}
}()
@@ -311,11 +400,7 @@ mainLoop:
} else {
workerLog.Debug().Stringer("processingTime",
processingTime).Msg("no need to sleep") // we use stringer instead of Dur to
have it a bit more readable
}
- workerLog.Trace().Msg("Working")
- // Check for any expired expectations.
- // (Doing this outside the loop lets us expire expectations
even if no input is coming in)
- now := time.Now()
- lastLoopTime = now
+ workerLog.Trace().Msg("receive mainloop cycle")
// Guard against empty expectations
m.expectationsChangeMutex.RLock()
@@ -323,9 +408,15 @@ mainLoop:
m.expectationsChangeMutex.RUnlock()
if numberOfExpectations <= 0 && m.customMessageHandling == nil {
workerLog.Trace().Msg("no available expectations")
- continue mainLoop
+ timer := time.NewTimer(30 * time.Second)
+ select {
+ case <-m.notifyReceiveWorker:
+ workerLog.Trace().Msg("waking up because of
notification")
+ case <-timer.C:
+ workerLog.Trace().Msg("waking up for next
receive")
+ continue mainLoop
+ }
}
- m.TimeoutExpectations(now)
workerLog.Trace().Msg("Receiving message")
// Check for incoming messages.
diff --git a/plc4go/spi/default/DefaultCodec_test.go
b/plc4go/spi/default/DefaultCodec_test.go
index 074901ab9a..4e1f071444 100644
--- a/plc4go/spi/default/DefaultCodec_test.go
+++ b/plc4go/spi/default/DefaultCodec_test.go
@@ -22,17 +22,15 @@ package _default
import (
"context"
"fmt"
- "os"
- "sync"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/pkg/errors"
- "github.com/rs/zerolog/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
"github.com/apache/plc4x/plc4go/spi"
"github.com/apache/plc4x/plc4go/spi/options"
@@ -246,25 +244,25 @@ func TestNewDefaultCodec(t *testing.T) {
options []options.WithOption
}
tests := []struct {
- name string
- args args
- want DefaultCodec
+ name string
+ args args
+ wantAssert func(*testing.T, DefaultCodec) bool
}{
{
name: "create it",
- want: &defaultCodec{
- expectations: []spi.Expectation{},
- receiveTimeout: 10 * time.Second,
- log: log.Logger,
+ wantAssert: func(t *testing.T, got DefaultCodec) bool {
+ require.IsType(t, &defaultCodec{}, got)
+ d := got.(*defaultCodec)
+ assert.NotNil(t,
d.defaultIncomingMessageChannel)
+ assert.Equal(t, 10*time.Second,
d.receiveTimeout)
+ return true
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewDefaultCodec(tt.args.requirements,
tt.args.transportInstance, tt.args.options...)
- assert.NotNil(t,
got.(*defaultCodec).defaultIncomingMessageChannel)
- got.(*defaultCodec).defaultIncomingMessageChannel = nil
// Not comparable
- assert.Equalf(t, tt.want, got, "NewDefaultCodec(%v, %v,
%v)", tt.args.requirements, tt.args.transportInstance, tt.args.options)
+ assert.Truef(t, tt.wantAssert(t, got),
"NewDefaultCodec(%v, %v, %v)", tt.args.requirements, tt.args.transportInstance,
tt.args.options)
})
}
}
@@ -297,38 +295,45 @@ func Test_buildDefaultCodec(t *testing.T) {
options []options.WithOption
}
tests := []struct {
- name string
- args args
- want DefaultCodec
+ name string
+ args args
+ wantAssert func(*testing.T, DefaultCodec) bool
}{
{
name: "build it",
- want: &defaultCodec{
- expectations: []spi.Expectation{},
- receiveTimeout: 10 * time.Second,
- log: log.Logger,
+ wantAssert: func(t *testing.T, got DefaultCodec) bool {
+ require.IsType(t, &defaultCodec{}, got)
+ d := got.(*defaultCodec)
+ assert.NotNil(t,
d.defaultIncomingMessageChannel)
+ assert.Equal(t, 10*time.Second,
d.receiveTimeout)
+ return true
},
},
{
name: "build it with custom handler",
args: args{
options: []options.WithOption{
- withCustomMessageHandler{},
+ withCustomMessageHandler{
+ customMessageHandler: func(_
DefaultCodecRequirements, _ spi.Message) bool {
+ return true
+ },
+ },
},
},
- want: &defaultCodec{
- expectations: []spi.Expectation{},
- receiveTimeout: 10 * time.Second,
- log: log.Logger,
+ wantAssert: func(t *testing.T, got DefaultCodec) bool {
+ require.IsType(t, &defaultCodec{}, got)
+ d := got.(*defaultCodec)
+ assert.NotNil(t,
d.defaultIncomingMessageChannel)
+ assert.Equal(t, 10*time.Second,
d.receiveTimeout)
+ assert.NotNil(t, d.customMessageHandling)
+ return true
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got :=
buildDefaultCodec(tt.args.defaultCodecRequirements, tt.args.transportInstance,
tt.args.options...)
- assert.NotNil(t,
got.(*defaultCodec).defaultIncomingMessageChannel)
- got.(*defaultCodec).defaultIncomingMessageChannel = nil
// Not comparable
- assert.Equalf(t, tt.want, got, "buildDefaultCodec(%v,
%v, %v)", tt.args.defaultCodecRequirements, tt.args.transportInstance,
tt.args.options)
+ assert.Truef(t, tt.wantAssert(t, got),
"buildDefaultCodec(%v, %v, %v)", tt.args.defaultCodecRequirements,
tt.args.transportInstance, tt.args.options)
})
}
}
@@ -368,6 +373,8 @@ func Test_defaultCodec_Connect(t *testing.T) {
transportInstance:
tt.fields.transportInstance,
defaultIncomingMessageChannel:
tt.fields.defaultIncomingMessageChannel,
expectations:
tt.fields.expectations,
+ notifyExpireWorker: make(chan
struct{}, 100),
+ notifyReceiveWorker: make(chan
struct{}, 100),
customMessageHandling:
tt.fields.customMessageHandling,
log:
testutils.ProduceTestingLogger(t),
}
@@ -435,6 +442,8 @@ func Test_defaultCodec_ConnectWithContext(t *testing.T) {
transportInstance:
tt.fields.transportInstance,
defaultIncomingMessageChannel:
tt.fields.defaultIncomingMessageChannel,
expectations:
tt.fields.expectations,
+ notifyExpireWorker: make(chan
struct{}, 100),
+ notifyReceiveWorker: make(chan
struct{}, 100),
customMessageHandling:
tt.fields.customMessageHandling,
log:
testutils.ProduceTestingLogger(t),
}
@@ -486,6 +495,8 @@ func Test_defaultCodec_Disconnect(t *testing.T) {
transportInstance:
tt.fields.transportInstance,
defaultIncomingMessageChannel:
tt.fields.defaultIncomingMessageChannel,
expectations:
tt.fields.expectations,
+ notifyExpireWorker: make(chan
struct{}, 100),
+ notifyReceiveWorker: make(chan
struct{}, 100),
customMessageHandling:
tt.fields.customMessageHandling,
log:
testutils.ProduceTestingLogger(t),
}
@@ -536,6 +547,8 @@ func Test_defaultCodec_Expect(t *testing.T) {
transportInstance:
tt.fields.transportInstance,
defaultIncomingMessageChannel:
tt.fields.defaultIncomingMessageChannel,
expectations:
tt.fields.expectations,
+ notifyExpireWorker: make(chan
struct{}, 100),
+ notifyReceiveWorker: make(chan
struct{}, 100),
customMessageHandling:
tt.fields.customMessageHandling,
log:
testutils.ProduceTestingLogger(t),
}
@@ -569,6 +582,8 @@ func Test_defaultCodec_GetDefaultIncomingMessageChannel(t
*testing.T) {
transportInstance:
tt.fields.transportInstance,
defaultIncomingMessageChannel:
tt.fields.defaultIncomingMessageChannel,
expectations:
tt.fields.expectations,
+ notifyExpireWorker: make(chan
struct{}, 100),
+ notifyReceiveWorker: make(chan
struct{}, 100),
customMessageHandling:
tt.fields.customMessageHandling,
log:
testutils.ProduceTestingLogger(t),
}
@@ -602,6 +617,8 @@ func Test_defaultCodec_GetTransportInstance(t *testing.T) {
transportInstance:
tt.fields.transportInstance,
defaultIncomingMessageChannel:
tt.fields.defaultIncomingMessageChannel,
expectations:
tt.fields.expectations,
+ notifyExpireWorker: make(chan
struct{}, 100),
+ notifyReceiveWorker: make(chan
struct{}, 100),
customMessageHandling:
tt.fields.customMessageHandling,
log:
testutils.ProduceTestingLogger(t),
}
@@ -829,6 +846,8 @@ func Test_defaultCodec_HandleMessages(t *testing.T) {
transportInstance:
tt.fields.transportInstance,
defaultIncomingMessageChannel:
tt.fields.defaultIncomingMessageChannel,
expectations:
tt.fields.expectations,
+ notifyExpireWorker: make(chan
struct{}, 100),
+ notifyReceiveWorker: make(chan
struct{}, 100),
customMessageHandling:
tt.fields.customMessageHandling,
log:
testutils.ProduceTestingLogger(t),
}
@@ -862,6 +881,8 @@ func Test_defaultCodec_IsRunning(t *testing.T) {
transportInstance:
tt.fields.transportInstance,
defaultIncomingMessageChannel:
tt.fields.defaultIncomingMessageChannel,
expectations:
tt.fields.expectations,
+ notifyExpireWorker: make(chan
struct{}, 100),
+ notifyReceiveWorker: make(chan
struct{}, 100),
customMessageHandling:
tt.fields.customMessageHandling,
log:
testutils.ProduceTestingLogger(t),
}
@@ -938,6 +959,8 @@ func Test_defaultCodec_SendRequest(t *testing.T) {
transportInstance:
tt.fields.transportInstance,
defaultIncomingMessageChannel:
tt.fields.defaultIncomingMessageChannel,
expectations:
tt.fields.expectations,
+ notifyExpireWorker: make(chan
struct{}, 100),
+ notifyReceiveWorker: make(chan
struct{}, 100),
customMessageHandling:
tt.fields.customMessageHandling,
log:
testutils.ProduceTestingLogger(t),
}
@@ -963,28 +986,30 @@ func Test_defaultCodec_TimeoutExpectations(t *testing.T) {
fields fields
args args
setup func(t *testing.T, fields *fields, args *args)
+ want time.Duration
}{
{
name: "timeout it (no expectations)",
+ want: 30 * time.Second,
},
{
name: "timeout some",
fields: fields{
expectations: []spi.Expectation{
&defaultExpectation{ // Expired
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return nil
},
},
&defaultExpectation{ // Expired errors
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return errors.New("yep")
},
},
&defaultExpectation{ // Fine
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return errors.New("yep")
},
@@ -992,7 +1017,7 @@ func Test_defaultCodec_TimeoutExpectations(t *testing.T) {
},
&defaultExpectation{ // Context error
Context: func() context.Context
{
- ctx, cancelFunc :=
context.WithCancel(context.Background())
+ ctx, cancelFunc :=
context.WithCancel(t.Context())
cancelFunc() // Cancel
it instantly
return ctx
}(),
@@ -1004,6 +1029,7 @@ func Test_defaultCodec_TimeoutExpectations(t *testing.T) {
},
},
args: args{now: time.Time{}.Add(2 * time.Hour)},
+ want: time.Until(time.Time{}.Add(3 * time.Hour)),
},
{
name: "timeout some (ensure everyone is called)",
@@ -1024,21 +1050,21 @@ func Test_defaultCodec_TimeoutExpectations(t
*testing.T) {
})
fields.expectations = []spi.Expectation{
&defaultExpectation{ // Expired
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
handle1.Store(true)
return nil
},
},
&defaultExpectation{ // Expired errors
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
handle2.Store(true)
return errors.New("yep")
},
},
&defaultExpectation{ // Fine
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
handle3.Store(true)
return errors.New("yep")
@@ -1047,7 +1073,7 @@ func Test_defaultCodec_TimeoutExpectations(t *testing.T) {
},
&defaultExpectation{ // Context error
Context: func() context.Context
{
- ctx, cancelFunc :=
context.WithCancel(context.Background())
+ ctx, cancelFunc :=
context.WithCancel(t.Context())
cancelFunc() // Cancel
it instantly
return ctx
}(),
@@ -1058,7 +1084,7 @@ func Test_defaultCodec_TimeoutExpectations(t *testing.T) {
Expiration: time.Time{}.Add(3 *
time.Hour),
},
&defaultExpectation{ // Fine
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
handle5.Store(true)
return errors.New("yep")
@@ -1067,6 +1093,7 @@ func Test_defaultCodec_TimeoutExpectations(t *testing.T) {
},
}
},
+ want: time.Until(time.Time{}.Add(3 * time.Hour)),
},
}
for _, tt := range tests {
@@ -1079,20 +1106,18 @@ func Test_defaultCodec_TimeoutExpectations(t
*testing.T) {
transportInstance:
tt.fields.transportInstance,
defaultIncomingMessageChannel:
tt.fields.defaultIncomingMessageChannel,
expectations:
tt.fields.expectations,
+ notifyExpireWorker: make(chan
struct{}, 100),
+ notifyReceiveWorker: make(chan
struct{}, 100),
customMessageHandling:
tt.fields.customMessageHandling,
log:
testutils.ProduceTestingLogger(t),
}
- m.TimeoutExpectations(tt.args.now)
- // TODO: handle error is called async so we sleep here
a bit. Not sure if we want to sync something here at all
- time.Sleep(100 * time.Millisecond)
+ assert.Equalf(t, tt.want,
m.TimeoutExpectations(tt.args.now), "TimeoutExpectations(%v)", tt.args.now)
+ m.wg.Wait()
})
}
}
-func Test_defaultCodec_Work(t *testing.T) {
- if os.Getenv("ENABLE_RANDOMLY_FAILING_TESTS") == "" {
- t.Skip("Skipping randomly failing tests")
- }
+func Test_defaultCodec_ReceiveWork(t *testing.T) {
type fields struct {
DefaultCodecRequirements DefaultCodecRequirements
transportInstance transports.TransportInstance
@@ -1123,19 +1148,19 @@ func Test_defaultCodec_Work(t *testing.T) {
fields: fields{
expectations: []spi.Expectation{
&defaultExpectation{ // Expired
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return nil
},
},
&defaultExpectation{ // Expired errors
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return errors.New("yep")
},
},
&defaultExpectation{ // Fine
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return errors.New("yep")
},
@@ -1143,7 +1168,7 @@ func Test_defaultCodec_Work(t *testing.T) {
},
&defaultExpectation{ // Context error
Context: func() context.Context
{
- ctx, cancelFunc :=
context.WithCancel(context.Background())
+ ctx, cancelFunc :=
context.WithCancel(t.Context())
cancelFunc() // Cancel
it instantly
return ctx
}(),
@@ -1169,19 +1194,19 @@ func Test_defaultCodec_Work(t *testing.T) {
fields: fields{
expectations: []spi.Expectation{
&defaultExpectation{ // Expired
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return nil
},
},
&defaultExpectation{ // Expired errors
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return errors.New("yep")
},
},
&defaultExpectation{ // Fine
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return errors.New("yep")
},
@@ -1189,7 +1214,7 @@ func Test_defaultCodec_Work(t *testing.T) {
},
&defaultExpectation{ // Context error
Context: func() context.Context
{
- ctx, cancelFunc :=
context.WithCancel(context.Background())
+ ctx, cancelFunc :=
context.WithCancel(t.Context())
cancelFunc() // Cancel
it instantly
return ctx
}(),
@@ -1215,19 +1240,19 @@ func Test_defaultCodec_Work(t *testing.T) {
fields: fields{
expectations: []spi.Expectation{
&defaultExpectation{ // Expired
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return nil
},
},
&defaultExpectation{ // Expired errors
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return errors.New("yep")
},
},
&defaultExpectation{ // Fine
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return errors.New("yep")
},
@@ -1235,7 +1260,7 @@ func Test_defaultCodec_Work(t *testing.T) {
},
&defaultExpectation{ // Context error
Context: func() context.Context
{
- ctx, cancelFunc :=
context.WithCancel(context.Background())
+ ctx, cancelFunc :=
context.WithCancel(t.Context())
cancelFunc() // Cancel
it instantly
return ctx
}(),
@@ -1262,7 +1287,7 @@ func Test_defaultCodec_Work(t *testing.T) {
defaultIncomingMessageChannel: make(chan
spi.Message, 1),
expectations: []spi.Expectation{
&defaultExpectation{ // Fine
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return errors.New("yep")
},
@@ -1285,19 +1310,19 @@ func Test_defaultCodec_Work(t *testing.T) {
fields: fields{
expectations: []spi.Expectation{
&defaultExpectation{ // Expired
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return nil
},
},
&defaultExpectation{ // Expired errors
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return errors.New("yep")
},
},
&defaultExpectation{ // Fine
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return errors.New("yep")
},
@@ -1305,7 +1330,7 @@ func Test_defaultCodec_Work(t *testing.T) {
},
&defaultExpectation{ // Context error
Context: func() context.Context
{
- ctx, cancelFunc :=
context.WithCancel(context.Background())
+ ctx, cancelFunc :=
context.WithCancel(t.Context())
cancelFunc() // Cancel
it instantly
return ctx
}(),
@@ -1334,19 +1359,19 @@ func Test_defaultCodec_Work(t *testing.T) {
},
expectations: []spi.Expectation{
&defaultExpectation{ // Expired
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return nil
},
},
&defaultExpectation{ // Expired errors
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return errors.New("yep")
},
},
&defaultExpectation{ // Fine
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return errors.New("yep")
},
@@ -1354,7 +1379,7 @@ func Test_defaultCodec_Work(t *testing.T) {
},
&defaultExpectation{ // Context error
Context: func() context.Context
{
- ctx, cancelFunc :=
context.WithCancel(context.Background())
+ ctx, cancelFunc :=
context.WithCancel(t.Context())
cancelFunc() // Cancel
it instantly
return ctx
}(),
@@ -1383,19 +1408,19 @@ func Test_defaultCodec_Work(t *testing.T) {
},
expectations: []spi.Expectation{
&defaultExpectation{ // Expired
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return nil
},
},
&defaultExpectation{ // Expired errors
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return errors.New("yep")
},
},
&defaultExpectation{ // Fine
- Context: context.Background(),
+ Context: t.Context(),
HandleError: func(err error)
error {
return errors.New("yep")
},
@@ -1403,7 +1428,7 @@ func Test_defaultCodec_Work(t *testing.T) {
},
&defaultExpectation{ // Context error
Context: func() context.Context
{
- ctx, cancelFunc :=
context.WithCancel(context.Background())
+ ctx, cancelFunc :=
context.WithCancel(t.Context())
cancelFunc() // Cancel
it instantly
return ctx
}(),
@@ -1435,25 +1460,30 @@ func Test_defaultCodec_Work(t *testing.T) {
transportInstance:
tt.fields.transportInstance,
defaultIncomingMessageChannel:
tt.fields.defaultIncomingMessageChannel,
expectations:
tt.fields.expectations,
+ notifyExpireWorker: make(chan
struct{}, 100),
+ notifyReceiveWorker: make(chan
struct{}, 100),
customMessageHandling:
tt.fields.customMessageHandling,
log:
testutils.ProduceTestingLogger(t),
}
if tt.manipulator != nil {
tt.manipulator(t, m)
}
- var wg sync.WaitGroup
- t.Cleanup(wg.Wait)
- wg.Go(func() {
+ t.Cleanup(m.wg.Wait)
+ m.wg.Go(func() {
// Stop after 200ms
- time.Sleep(200 * time.Millisecond)
+ timer := time.NewTimer(200 * time.Millisecond)
+ select {
+ case <-timer.C:
+ case <-t.Context().Done():
+ }
m.running.Store(false)
})
- m.Work()
+ m.ReceiveWork()
})
}
}
-func Test_defaultCodec_startWorker(t *testing.T) {
+func Test_defaultCodec_startWorkers(t *testing.T) {
type fields struct {
DefaultCodecRequirements DefaultCodecRequirements
transportInstance transports.TransportInstance
@@ -1483,7 +1513,116 @@ func Test_defaultCodec_startWorker(t *testing.T) {
traceDefaultMessageCodecWorker:
tt.fields.traceDefaultMessageCodecWorker,
log:
testutils.ProduceTestingLogger(t),
}
- m.startWorker()
+ m.startWorkers()
})
}
}
+
+// Test_defaultCodec_integration drives a real codec (with mocked requirements
+// and transport) against four expectations and checks that every handler fires.
+func Test_defaultCodec_integration(t *testing.T) {
+	mockDefaultCodecRequirements := NewMockDefaultCodecRequirements(t)
+	{
+		expect := mockDefaultCodecRequirements.EXPECT()
+		message := NewMockMessage(t)
+		{
+			expect := message.EXPECT()
+			expect.String().Return("message for " + t.Name())
+		}
+		expect.Receive().RunAndReturn(func() (spi.Message, error) {
+			// Simulate a bit read delay
+			timer := time.NewTimer(100 * time.Millisecond)
+			select {
+			case <-timer.C:
+			case <-t.Context().Done():
+			}
+			if err := t.Context().Err(); err != nil {
+				return nil, err
+			}
+			return message, nil
+		})
+	}
+	mockTransportInstance := NewMockTransportInstance(t)
+	{
+		expect := mockTransportInstance.EXPECT()
+		expect.IsConnected().Return(true)
+		expect.Close().Return(nil)
+	}
+	sut := NewDefaultCodec(mockDefaultCodecRequirements, mockTransportInstance,
+		options.WithCustomLogger(testutils.ProduceTestingLogger(t)),
+		options.WithTraceDefaultMessageCodecWorker(true),
+	)
+	t.Cleanup(func() {
+		// Result ignored: on the happy path the codec was already disconnected below.
+		_ = sut.Disconnect()
+	})
+
+	// The callbacks run on the codec's worker goroutines while this goroutine
+	// reads the flags, so they must be atomic to stay race-free.
+	var firstHandled, secondHandled, thirdErrorCalled, fourthHandled atomic.Bool
+
+	// First expect
+	sut.Expect(t.Context(),
+		func(message spi.Message) bool {
+			t.Log("accepts message", message)
+			return true
+		}, func(message spi.Message) error {
+			t.Log("handle message", message)
+			firstHandled.Store(true)
+			return nil
+		}, func(err error) error {
+			t.Log("error", err)
+			return nil
+		},
+		500*time.Millisecond)
+	// Second expect
+	sut.Expect(t.Context(),
+		func(message spi.Message) bool {
+			t.Log("accepts message", message)
+			return true
+		}, func(message spi.Message) error {
+			t.Log("handle message", message)
+			secondHandled.Store(true)
+			return nil
+		}, func(err error) error {
+			t.Log("error", err)
+			return nil
+		},
+		1500*time.Millisecond)
+	// Third expect: never accepts a message, so only its error handler can fire.
+	sut.Expect(t.Context(),
+		func(message spi.Message) bool {
+			t.Log("does not accept message", message)
+			return false
+		}, func(message spi.Message) error {
+			t.Error("should not be called")
+			return nil
+		}, func(err error) error {
+			thirdErrorCalled.Store(true)
+			return nil
+		},
+		1500*time.Millisecond)
+	// Fourth expect
+	sut.Expect(t.Context(),
+		func(message spi.Message) bool {
+			t.Log("accepts message", message)
+			return true
+		}, func(message spi.Message) error {
+			t.Log("handle message", message)
+			fourthHandled.Store(true)
+			return nil
+		}, func(err error) error {
+			t.Log("error", err)
+			return nil
+		},
+		3000*time.Millisecond)
+
+	err := sut.ConnectWithContext(t.Context())
+	require.NoError(t, err)
+
+	// Poll until every callback has fired instead of always sleeping the full
+	// 10s; the deadline keeps the original upper bound.
+	allFired := func() bool {
+		return firstHandled.Load() && secondHandled.Load() && thirdErrorCalled.Load() && fourthHandled.Load()
+	}
+	deadline := time.NewTimer(10 * time.Second)
+	defer deadline.Stop()
+	ticker := time.NewTicker(10 * time.Millisecond)
+	defer ticker.Stop()
+waitLoop:
+	for !allFired() {
+		select {
+		case <-ticker.C:
+		case <-deadline.C:
+			break waitLoop
+		case <-t.Context().Done():
+			break waitLoop
+		}
+	}
+
+	assert.NoError(t, sut.Disconnect())
+	assert.True(t, firstHandled.Load())
+	assert.True(t, secondHandled.Load())
+	assert.True(t, thirdErrorCalled.Load())
+	assert.True(t, fourthHandled.Load())
+}