This is an automated email from the ASF dual-hosted git repository.
mrproliu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/skywalking-go.git
The following commit(s) were added to refs/heads/main by this push:
new dceeb56 Fix data race in plugins (#248)
dceeb56 is described below
commit dceeb5605644b51198c847007e7addbc148fa630
Author: mrproliu <[email protected]>
AuthorDate: Sat Jun 6 21:26:13 2026 +0800
Fix data race in plugins (#248)
---
CHANGES.md | 7 +
plugins/amqp/general_consumer.go | 37 ++++-
plugins/amqp/general_consumer_test.go | 76 +++++++++
plugins/core/operator/tracing.go | 3 +
plugins/core/reporter/grpc/grpc.go | 63 +++++---
plugins/core/reporter/kafka/kafka.go | 107 +++++++------
plugins/core/span_default.go | 17 +-
plugins/core/tracing.go | 32 ++++
plugins/core/tracing/api.go | 16 ++
plugins/core/tracing_extract_test.go | 124 +++++++++++++++
plugins/gorm/entry/callback.go | 20 ++-
plugins/gorm/entry/callback_test.go | 177 +++++++++++++++++++++
plugins/grpc/client_finish_interceptor.go | 20 ++-
plugins/grpc/client_finish_interceptor_test.go | 91 +++++++++++
plugins/grpc/client_recvmsg_interceptor.go | 2 +-
plugins/grpc/client_streaming_interceptor.go | 8 +-
plugins/microv4/server/structure.go | 9 +-
plugins/microv4/util/socket/accept_interceptor.go | 7 +
plugins/microv4/util/socket/close_interceptor.go | 13 +-
.../microv4/util/socket/close_interceptor_test.go | 107 +++++++++++++
plugins/mongo/mongo/interceptor.go | 9 +-
plugins/mongo/mongo/interceptor_test.go | 144 +++++++++++++++++
plugins/mux/serve_interceptor.go | 2 +-
plugins/mux/serve_interceptor_test.go | 102 ++++++++++++
plugins/pulsar/pulsar/send_async_producer.go | 55 ++++---
plugins/pulsar/pulsar/send_async_producer_test.go | 126 +++++++++++++++
plugins/rocketmq/consumer/consumer.go | 65 +++++---
plugins/rocketmq/consumer/consumer_test.go | 138 ++++++++++++++++
plugins/rocketmq/producer/async_producer.go | 68 +++++---
plugins/rocketmq/producer/async_producer_test.go | 168 +++++++++++++++++++
30 files changed, 1649 insertions(+), 164 deletions(-)
diff --git a/CHANGES.md b/CHANGES.md
index 0fcda9d..6995609 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -24,6 +24,13 @@ Release Notes.
* Fix data race when sending trace data to reporter.
* Fix multiple data races in span lifecycle, correlation context and segment
collection.
* Add recover protection for the metrics, profile and segment-transform
goroutines.
+* Fix the RocketMQ batch consumer span: report once with one segment reference
per message (new `ExtractContext` API).
+* Fix a nil dereference and wrong span ownership in the RocketMQ/Pulsar async
producer callbacks.
+* Fix racy finish flags in the gRPC streaming client and the go-micro
socket close interceptor.
+* Fix the MongoDB command span to complete through the async API (events may
fire on different goroutines).
+* Store the gorm span per statement, and fix the mux response writer
wrapping a nil writer.
+* Add recover protection for the kafka instance-check and gRPC profile-fetch
goroutines.
+* Fix unsynchronized consumer-tag map access in the AMQP plugin (fatal
concurrent map read/write).
#### Issues and PR
- All issues are
[here](https://github.com/apache/skywalking/milestone/238?closed=1)
diff --git a/plugins/amqp/general_consumer.go b/plugins/amqp/general_consumer.go
index 89c4cc8..5c3bfc9 100644
--- a/plugins/amqp/general_consumer.go
+++ b/plugins/amqp/general_consumer.go
@@ -21,6 +21,7 @@ import (
"fmt"
"os"
"strconv"
+ "sync"
"sync/atomic"
"github.com/rabbitmq/amqp091-go"
@@ -41,7 +42,32 @@ const (
)
var consumerSeq uint64
-var queueConsumerTagMapping = make(map[string]string)
+
+// queueConsumerTagMapping is touched from three goroutines (Consume writes,
+// the SDK delivery dispatch reads, Close deletes) - unsynchronized access is
+// a fatal concurrent map read/write, so it only goes through the accessors.
+var (
+ queueConsumerTagMapping = make(map[string]string)
+ queueConsumerTagLock sync.RWMutex
+)
+
+func registerConsumerQueue(consumerTag, queue string) {
+ queueConsumerTagLock.Lock()
+ defer queueConsumerTagLock.Unlock()
+ queueConsumerTagMapping[consumerTag] = queue
+}
+
+func consumerQueue(consumerTag string) string {
+ queueConsumerTagLock.RLock()
+ defer queueConsumerTagLock.RUnlock()
+ return queueConsumerTagMapping[consumerTag]
+}
+
+func removeConsumerQueue(consumerTag string) {
+ queueConsumerTagLock.Lock()
+ defer queueConsumerTagLock.Unlock()
+ delete(queueConsumerTagMapping, consumerTag)
+}
func GeneralConsumersSendAfterInvoke(invocation operator.Invocation, results
...interface{}) error {
if foundConsumer := results[0].(bool); !foundConsumer {
@@ -49,7 +75,8 @@ func GeneralConsumersSendAfterInvoke(invocation
operator.Invocation, results ...
}
consumerTag, _ := invocation.Args()[0].(string)
delivery, _ := invocation.Args()[1].(*Delivery)
- operationName := amqpConsumerPrefix +
queueConsumerTagMapping[consumerTag] + "/" + consumerTag + amqpConsumerSuffix
+ queue := consumerQueue(consumerTag)
+ operationName := amqpConsumerPrefix + queue + "/" + consumerTag +
amqpConsumerSuffix
channel, _ := delivery.Acknowledger.(*nativeChannel)
peer := getPeerInfo(channel.connection)
@@ -59,7 +86,7 @@ func GeneralConsumersSendAfterInvoke(invocation
operator.Invocation, results ...
}, tracing.WithLayer(tracing.SpanLayerMQ),
tracing.WithComponent(ConsumerComponentID),
tracing.WithTag(tracing.TagMQBroker, peer),
- tracing.WithTag(tracing.TagMQQueue,
queueConsumerTagMapping[consumerTag]),
+ tracing.WithTag(tracing.TagMQQueue, queue),
tracing.WithTag(tracing.TagMQMsgID, delivery.MessageId),
tracing.WithTag(tagMQConsumerTag, consumerTag),
tracing.WithTag(tagMQCorrelationID, delivery.CorrelationId),
@@ -80,7 +107,7 @@ func GeneralConsumerBeforeInvoke(invocation
operator.Invocation, args amqp091.Ta
if consumerTag == "" {
consumerTag = uniqueConsumerTag()
}
- queueConsumerTagMapping[consumerTag] = queue
+ registerConsumerQueue(consumerTag, queue)
return nil
}
@@ -89,7 +116,7 @@ func GeneralConsumerCloseBeforeInvoke(invocation
operator.Invocation) error {
consumers.Lock()
defer consumers.Unlock()
for consumerTag := range consumers.chans {
- delete(queueConsumerTagMapping, consumerTag)
+ removeConsumerQueue(consumerTag)
}
return nil
}
diff --git a/plugins/amqp/general_consumer_test.go
b/plugins/amqp/general_consumer_test.go
new file mode 100644
index 0000000..2f005db
--- /dev/null
+++ b/plugins/amqp/general_consumer_test.go
@@ -0,0 +1,76 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package amqp
+
+import (
+ "fmt"
+ "sync"
+ "testing"
+)
+
+// TestConsumerQueueMappingConcurrency hammers the consumer-tag mapping from
+// the three goroutine roles that touch it in production: Consume registers on
+// the user goroutine, the delivery dispatch reads on the SDK goroutine and
+// Close deletes. The mapping used to be a plain map, which is a fatal
+// (unrecoverable) concurrent map read/write.
+func TestConsumerQueueMappingConcurrency(t *testing.T) {
+ const workers = 8
+ const iterations = 500
+
+ var wg sync.WaitGroup
+ for w := 0; w < workers; w++ {
+ tag := fmt.Sprintf("ctag-%d", w)
+ wg.Add(3)
+ go func() { // Consume path
+ defer wg.Done()
+ for i := 0; i < iterations; i++ {
+ registerConsumerQueue(tag, "queue-A")
+ }
+ }()
+ go func() { // delivery dispatch path
+ defer wg.Done()
+ for i := 0; i < iterations; i++ {
+ if q := consumerQueue(tag); q != "" && q !=
"queue-A" {
+ t.Errorf("unexpected queue %q", q)
+ }
+ }
+ }()
+ go func() { // Close path
+ defer wg.Done()
+ for i := 0; i < iterations; i++ {
+ removeConsumerQueue(tag)
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+// TestConsumerQueueMappingBasics pins the accessor semantics.
+func TestConsumerQueueMappingBasics(t *testing.T) {
+ registerConsumerQueue("tag-1", "orders")
+ if q := consumerQueue("tag-1"); q != "orders" {
+ t.Fatalf("expected orders, got %q", q)
+ }
+ if q := consumerQueue("missing"); q != "" {
+ t.Fatalf("missing tag must yield empty queue, got %q", q)
+ }
+ removeConsumerQueue("tag-1")
+ if q := consumerQueue("tag-1"); q != "" {
+ t.Fatalf("removed tag must yield empty queue, got %q", q)
+ }
+}
diff --git a/plugins/core/operator/tracing.go b/plugins/core/operator/tracing.go
index cb11d57..d1c90c1 100644
--- a/plugins/core/operator/tracing.go
+++ b/plugins/core/operator/tracing.go
@@ -21,6 +21,9 @@ type TracingOperator interface {
CreateEntrySpan(operationName string, extractor interface{}, opts
...interface{}) (s interface{}, err error)
CreateLocalSpan(operationName string, opts ...interface{}) (s
interface{}, err error)
CreateExitSpan(operationName, peer string, injector interface{}, opts
...interface{}) (s interface{}, err error)
+ // ExtractContext attaches the context carried by extractor to the
current
+ // active entry span as one more segment reference (batch consumers).
+ ExtractContext(extractor interface{}) error
ActiveSpan() interface{} // to Span
GetRuntimeContextValue(key string) interface{}
diff --git a/plugins/core/reporter/grpc/grpc.go
b/plugins/core/reporter/grpc/grpc.go
index 32dab07..6042c23 100644
--- a/plugins/core/reporter/grpc/grpc.go
+++ b/plugins/core/reporter/grpc/grpc.go
@@ -475,34 +475,49 @@ func (r *gRPCReporter) fetchProfileTasks() {
}
go func() {
for {
- // Construct the request
- req := &profilev3.ProfileTaskCommandQuery{
- Service: r.entity.ServiceName,
- ServiceInstance: r.entity.ServiceInstanceName,
- LastCommandTime: r.lastProfileCommandTime,
- }
+ // The recover wraps a single iteration: this
long-lived goroutine
+ // has no other protection and a panic while handling
the profile
+ // commands would otherwise kill the whole process.
+ func() {
+ defer func() {
+ if rec := recover(); rec != nil {
+ r.logger.Errorf("gRPCReporter
recovered from panic while fetching profile tasks: %v", rec)
+ }
+ }()
+ r.fetchProfileTasksOnce()
+ }()
+ time.Sleep(r.profileFetchInterval)
+ }
+ }()
+}
- // Pull tasks
- resp, err :=
r.profileTaskClient.GetProfileTaskCommands(context.Background(), req)
- if err != nil {
- r.logger.Errorf("fetch profile task error: %v",
err)
- time.Sleep(r.profileFetchInterval)
- continue
- }
+// fetchProfileTasksOnce pulls and handles the pending profile task commands of
+// one polling round.
+func (r *gRPCReporter) fetchProfileTasksOnce() {
+ // Construct the request
+ req := &profilev3.ProfileTaskCommandQuery{
+ Service: r.entity.ServiceName,
+ ServiceInstance: r.entity.ServiceInstanceName,
+ LastCommandTime: r.lastProfileCommandTime,
+ }
- // Handle all returned commands
- for _, cmd := range resp.Commands {
- nt := r.handleProfileTask(cmd,
r.lastProfileCommandTime)
- if nt > r.lastProfileCommandTime {
- r.lastProfileCommandTime = nt
- }
- }
+ // Pull tasks
+ resp, err :=
r.profileTaskClient.GetProfileTaskCommands(context.Background(), req)
+ if err != nil {
+ r.logger.Errorf("fetch profile task error: %v", err)
+ return
+ }
- // Remove completed tasks
- r.profileTaskManager.RemoveProfileTask()
- time.Sleep(r.profileFetchInterval)
+ // Handle all returned commands
+ for _, cmd := range resp.Commands {
+ nt := r.handleProfileTask(cmd, r.lastProfileCommandTime)
+ if nt > r.lastProfileCommandTime {
+ r.lastProfileCommandTime = nt
}
- }()
+ }
+
+ // Remove completed tasks
+ r.profileTaskManager.RemoveProfileTask()
}
func (r *gRPCReporter) AddProfileTaskManager(p reporter.ProfileTaskManager) {
diff --git a/plugins/core/reporter/kafka/kafka.go
b/plugins/core/reporter/kafka/kafka.go
index 558f544..c0d4c82 100644
--- a/plugins/core/reporter/kafka/kafka.go
+++ b/plugins/core/reporter/kafka/kafka.go
@@ -308,56 +308,73 @@ func (r *kafkaReporter) check() {
time.Sleep(r.checkInterval)
instancePropertiesSubmitted := false
for {
- if !instancePropertiesSubmitted {
- instanceProperties :=
&managementv3.InstanceProperties{
- Service: r.entity.ServiceName,
- ServiceInstance:
r.entity.ServiceInstanceName,
- Properties: r.entity.Props,
- }
- payload, err :=
proto.Marshal(instanceProperties)
- if err != nil {
- r.logger.Errorf("marshal instance
properties error %v", err)
- time.Sleep(r.checkInterval)
- continue
- }
- ctx := context.WithValue(context.Background(),
internalReporterContextKey, true)
- err = r.writer.WriteMessages(ctx, kafka.Message{
- Topic: r.topicManagement,
- Key: []byte(topicKeyRegister +
r.entity.ServiceInstanceName),
- Value: payload,
- })
- if err != nil {
- r.logger.Errorf("send instance
properties to kafka error %v", err)
- time.Sleep(r.checkInterval)
- continue
- }
- instancePropertiesSubmitted = true
- }
-
- ping := &managementv3.InstancePingPkg{
- Service: r.entity.ServiceName,
- ServiceInstance: r.entity.ServiceInstanceName,
- }
- payload, err := proto.Marshal(ping)
- if err != nil {
- r.logger.Errorf("marshal instance ping error
%v", err)
- time.Sleep(r.checkInterval)
- continue
- }
- ctx := context.WithValue(context.Background(),
internalReporterContextKey, true)
- err = r.writer.WriteMessages(ctx, kafka.Message{
- Topic: r.topicManagement,
- Key: []byte(r.entity.ServiceInstanceName),
- Value: payload,
- })
- if err != nil {
- r.logger.Errorf("send instance ping to kafka
error %v", err)
- }
+ // The recover wraps a single iteration: this
long-lived goroutine
+ // has no other protection and a panic in the kafka
writer would
+ // otherwise kill the whole process.
+ func() {
+ defer func() {
+ if rec := recover(); rec != nil {
+ r.logger.Errorf("kafkaReporter
recovered from panic while checking the instance: %v", rec)
+ }
+ }()
+ instancePropertiesSubmitted =
r.checkOnce(instancePropertiesSubmitted)
+ }()
time.Sleep(r.checkInterval)
}
}()
}
+// checkOnce submits the instance properties (until that succeeded once) and
+// the keep-alive ping of one round; it returns whether the instance
+// properties have been submitted.
+func (r *kafkaReporter) checkOnce(instancePropertiesSubmitted bool) bool {
+ if !instancePropertiesSubmitted {
+ instanceProperties := &managementv3.InstanceProperties{
+ Service: r.entity.ServiceName,
+ ServiceInstance: r.entity.ServiceInstanceName,
+ Properties: r.entity.Props,
+ }
+ payload, err := proto.Marshal(instanceProperties)
+ if err != nil {
+ r.logger.Errorf("marshal instance properties error %v",
err)
+ return false
+ }
+ ctx := context.WithValue(context.Background(),
internalReporterContextKey, true)
+ err = r.writer.WriteMessages(ctx, kafka.Message{
+ Topic: r.topicManagement,
+ Key: []byte(topicKeyRegister +
r.entity.ServiceInstanceName),
+ Value: payload,
+ })
+ if err != nil {
+ r.logger.Errorf("send instance properties to kafka
error %v", err)
+ return false
+ }
+ }
+
+ // this point is only reachable once the instance properties have been
+ // submitted (this round or a previous one) - every failure path above
+ // returns false first - so the returns below all report true
+ ping := &managementv3.InstancePingPkg{
+ Service: r.entity.ServiceName,
+ ServiceInstance: r.entity.ServiceInstanceName,
+ }
+ payload, err := proto.Marshal(ping)
+ if err != nil {
+ r.logger.Errorf("marshal instance ping error %v", err)
+ return true
+ }
+ ctx := context.WithValue(context.Background(),
internalReporterContextKey, true)
+ err = r.writer.WriteMessages(ctx, kafka.Message{
+ Topic: r.topicManagement,
+ Key: []byte(r.entity.ServiceInstanceName),
+ Value: payload,
+ })
+ if err != nil {
+ r.logger.Errorf("send instance ping to kafka error %v", err)
+ }
+ return true
+}
+
func (r *kafkaReporter) ConnectionStatus() reporter.ConnectionStatus {
return r.connectionStatus
}
diff --git a/plugins/core/span_default.go b/plugins/core/span_default.go
index a7c1525..b5f7c8d 100644
--- a/plugins/core/span_default.go
+++ b/plugins/core/span_default.go
@@ -47,10 +47,11 @@ type DefaultSpan struct {
AsyncModeFinished bool
// opLock guards the mutable fields above (OperationName, Peer, Layer,
- // ComponentID, Tags, Logs, IsError, EndTime, the async flags) together
with
- // the ended flag. SpanType, Parent, Refs, StartTime and tracer are
+ // ComponentID, Tags, Logs, Refs, IsError, EndTime, the async flags)
+ // together with the ended flag. SpanType, Parent, StartTime and tracer
are
// write-once during construction - before the span is ever shared -
and are
// therefore read without the lock
(IsEntry/IsExit/ParentSpan/StartTime).
+ // Refs is also appended by ExtractContext; reporting reads it after
the freeze.
// It must stay a pointer: DefaultSpan is copied by value when it is
embedded
// into SegmentSpanImpl/SnapshotSpan, and an embedded sync.Mutex value
would
// trip the go vet copylocks check.
@@ -259,6 +260,18 @@ func (ds *DefaultSpan) endAndFreeze() bool {
return true
}
+// appendRef attaches one more segment reference to this span (see
+// Tracer.ExtractContext); late appends after the freeze are dropped.
+func (ds *DefaultSpan) appendRef(ref reporter.SpanContext) {
+ ds.opLock.Lock()
+ defer ds.opLock.Unlock()
+ if ds.ended {
+ ds.logDroppedWrite("ref", "")
+ return
+ }
+ ds.Refs = append(ds.Refs, ref)
+}
+
// enterReuse registers one more owner of this span. It is called from the span
// reuse branches of CreateEntrySpan/CreateExitSpan when the active span is
// handed to a nested plugin; that owner's End then only decrements the counter
diff --git a/plugins/core/tracing.go b/plugins/core/tracing.go
index 7c96fa9..7308fb0 100644
--- a/plugins/core/tracing.go
+++ b/plugins/core/tracing.go
@@ -169,6 +169,38 @@ func (t *Tracer) CreateExitSpan(operationName, peer
string, injector interface{}
return span, nil
}
+// ExtractContext decodes the propagated context carried by extractor and
+// attaches it to the current active entry span as one more segment reference,
+// merging the carried correlation values - the equivalent of the Java agent's
+// ContextManager.extract, used by batch consumers to link every upstream
+// message to the single entry span.
+func (t *Tracer) ExtractContext(extractor interface{}) error {
+ ctx := getTracingContext()
+ if ctx == nil || ctx.ActiveSpan() == nil {
+ return nil
+ }
+ segmentSpan, ok := ctx.ActiveSpan().(SegmentSpan)
+ if !ok || !segmentSpan.GetDefaultSpan().IsEntry() {
+ // only an entry span carries upstream references, mirroring
the Java agent
+ return nil
+ }
+ ref := &SpanContext{}
+ if err := ref.Decode(extractor.(tracing.ExtractorWrapper).Fun()); err
!= nil {
+ return err
+ }
+ if !ref.Valid {
+ return nil
+ }
+ segmentSpan.GetDefaultSpan().appendRef(ref)
+ // merge the carried correlation into the segment, last write wins
+ // (mirroring the Java agent's extractCorrelationTo)
+ correlation := segmentSpan.GetSegmentContext().CorrelationContext
+ for k, v := range ref.CorrelationContext {
+ correlation.Set(k, v)
+ }
+ return nil
+}
+
func (t *Tracer) ActiveSpan() interface{} {
ctx := getTracingContext()
if ctx == nil || ctx.ActiveSpan() == nil {
diff --git a/plugins/core/tracing/api.go b/plugins/core/tracing/api.go
index 4c8c3fd..9bea821 100644
--- a/plugins/core/tracing/api.go
+++ b/plugins/core/tracing/api.go
@@ -86,6 +86,22 @@ func CreateExitSpan(operationName, peer string, injector
Injector, opts ...SpanO
return newSpanAdapter(span.(AdaptSpan)), nil
}
+// ExtractContext attaches the context carried by extractor to the current
+// active entry span as one more segment reference, so a batch consumer can
+// link every upstream message to its single entry span. It returns an error
+// only for a nil extractor or a carrier that fails to decode; no active
+// entry span or an empty carrier is a silent no-op.
+func ExtractContext(extractor Extractor) error {
+ if extractor == nil {
+ return errParameter
+ }
+ op := operator.GetOperator()
+ if op == nil {
+ return nil
+ }
+ return
op.Tracing().(operator.TracingOperator).ExtractContext(extractorWrapper(extractor))
+}
+
// ActiveSpan returns the current active span, it can be got the current span
in the current goroutine.
// If the current goroutine is not in the context of the span, it will return
nil.
// If get the span from other goroutine, it can only get information but
cannot be operated.
diff --git a/plugins/core/tracing_extract_test.go
b/plugins/core/tracing_extract_test.go
new file mode 100644
index 0000000..d8b9357
--- /dev/null
+++ b/plugins/core/tracing_extract_test.go
@@ -0,0 +1,124 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package core
+
+import (
+ "testing"
+
+ "github.com/apache/skywalking-go/plugins/core/tracing"
+)
+
+// secondUpstreamExtractor serves the sw8/sw8-correlation headers of a second
+// upstream segment, distinct from the package-level `header` of tracing_test.go.
+func secondUpstreamExtractor(correlation map[string]string) func(string)
(string, error) {
+ scx := SpanContext{
+ Sample: 1,
+ TraceID: "2f2d4bf47bf711eab794acde48001122",
+ ParentSegmentID: "2e7c204a7bf711eab858acde48001122",
+ ParentSpanID: 1,
+ ParentService: "service-2",
+ ParentServiceInstance: "instance-2",
+ ParentEndpoint: "/producer/second",
+ AddressUsedAtClient: "mq.svc:9876",
+ CorrelationContext: correlation,
+ }
+ sw8 := scx.EncodeSW8()
+ sw8Correlation := scx.EncodeSW8Correlation()
+ return func(headerKey string) (string, error) {
+ switch headerKey {
+ case Header:
+ return sw8, nil
+ case HeaderCorrelation:
+ return sw8Correlation, nil
+ }
+ return "", nil
+ }
+}
+
+// TestExtractContextAddsRefs covers the batch-consumer flow: the entry span is
+// created from the first message and every further message is attached as one
+// more segment reference, with its correlation merged.
+func TestExtractContextAddsRefs(t *testing.T) {
+ ResetTracingContext()
+ defer ResetTracingContext()
+
+ entry, err := tracing.CreateEntrySpan("MQ/batch/Consumer",
func(headerKey string) (string, error) {
+ if headerKey == Header {
+ return header, nil
+ }
+ return "", nil
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if err :=
tracing.ExtractContext(secondUpstreamExtractor(map[string]string{"upstream":
"second"})); err != nil {
+ t.Fatalf("extract second upstream failed: %v", err)
+ }
+ // an empty/invalid carrier must be a silent no-op
+ if err := tracing.ExtractContext(func(string) (string, error) { return
"", nil }); err != nil {
+ t.Fatalf("invalid carrier must not error: %v", err)
+ }
+ // correlation carried by the second message is visible on the segment
+ if got := Tracing.GetCorrelationContextValue("upstream"); got !=
"second" {
+ t.Fatalf("correlation was not merged, got %q", got)
+ }
+
+ entry.End()
+
+ spans := waitReportedSpans(t, 1)
+ if len(spans) != 1 {
+ t.Fatalf("expected exactly one reported span, got %d",
len(spans))
+ }
+ refs := spans[0].Refs()
+ if len(refs) != 2 {
+ t.Fatalf("expected 2 segment references (first message +
extracted), got %d", len(refs))
+ }
+ if refs[0].GetTraceID() != traceID {
+ t.Fatalf("first ref must keep the creation carrier, got %s",
refs[0].GetTraceID())
+ }
+ if refs[1].GetTraceID() != "2f2d4bf47bf711eab794acde48001122" {
+ t.Fatalf("second ref must carry the extracted upstream, got
%s", refs[1].GetTraceID())
+ }
+}
+
+// TestExtractContextRequiresEntrySpan pins the no-op behavior outside an entry
+// span (mirroring the Java agent, only the EntrySpan carries extra refs).
+func TestExtractContextRequiresEntrySpan(t *testing.T) {
+ ResetTracingContext()
+ defer ResetTracingContext()
+
+ // no active span at all
+ if err := tracing.ExtractContext(secondUpstreamExtractor(nil)); err !=
nil {
+ t.Fatalf("no active span must be a no-op: %v", err)
+ }
+
+ local, err := tracing.CreateLocalSpan("local/op")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := tracing.ExtractContext(secondUpstreamExtractor(nil)); err !=
nil {
+ t.Fatalf("non-entry active span must be a no-op: %v", err)
+ }
+ local.End()
+
+ spans := waitReportedSpans(t, 1)
+ if got := len(spans[0].Refs()); got != 0 {
+ t.Fatalf("local span must not gain refs, got %d", got)
+ }
+}
diff --git a/plugins/gorm/entry/callback.go b/plugins/gorm/entry/callback.go
index 51b6686..ed0de13 100644
--- a/plugins/gorm/entry/callback.go
+++ b/plugins/gorm/entry/callback.go
@@ -31,6 +31,16 @@ func beforeCallback(dbInfo DatabaseInfo, op string) func(db
*gorm.DB) {
return func(db *gorm.DB) {
tableName := db.Statement.Table
operation := fmt.Sprintf("%s/%s", tableName, op)
+ // a leftover span on this very Statement means a chained
*gorm.DB is
+ // shared across goroutines (unsupported by gorm): the previous
span is
+ // about to be overwritten and lost, so make the misuse visible
+ if leftover, ok := db.InstanceGet(spanKey); ok {
+ if _, isSpan := leftover.(tracing.Span); isSpan {
+ db.Logger.Warn(db.Statement.Context,
+ "gorm:skywalking found an unfinished
span on the statement, "+
+ "the *gorm.DB is probably
shared across goroutines; its trace data will be lost")
+ }
+ }
s, err := tracing.CreateExitSpan(operation, dbInfo.Peer(),
func(k, v string) error {
return nil
}, tracing.WithComponent(dbInfo.ComponentID()),
@@ -42,18 +52,24 @@ func beforeCallback(dbInfo DatabaseInfo, op string) func(db
*gorm.DB) {
return
}
- db.Set(spanKey, s)
+ // InstanceSet keys by the Statement pointer: gorm's
Statement.clone
+ // copies plain db.Set Settings into every Session/Transaction
clone,
+ // which let a derived operation pick up - and end - the OUTER
span
+ db.InstanceSet(spanKey, s)
}
}
func afterCallback(dbInfo DatabaseInfo) func(db *gorm.DB) {
return func(db *gorm.DB) {
// get span from db instance's context
- spanInterface, _ := db.Get(spanKey)
+ spanInterface, _ := db.InstanceGet(spanKey)
span, ok := spanInterface.(tracing.Span)
if !ok {
return
}
+ // the span is consumed: a later operation on the same
statement must
+ // not see it as a leftover
+ db.InstanceSet(spanKey, nil)
defer span.End()
diff --git a/plugins/gorm/entry/callback_test.go
b/plugins/gorm/entry/callback_test.go
new file mode 100644
index 0000000..5cda6ac
--- /dev/null
+++ b/plugins/gorm/entry/callback_test.go
@@ -0,0 +1,177 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package entry
+
+import (
+ "context"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "gorm.io/gorm"
+ "gorm.io/gorm/logger"
+
+ "github.com/apache/skywalking-go/plugins/core"
+)
+
+type fakeDBInfo struct{}
+
+func (fakeDBInfo) Type() string { return "mysql" }
+func (fakeDBInfo) ComponentID() int32 { return 5012 }
+func (fakeDBInfo) Peer() string { return "localhost:3306" }
+
+type capturingLogger struct {
+ mu sync.Mutex
+ warns []string
+}
+
+func (l *capturingLogger) LogMode(logger.LogLevel) logger.Interface { return l
}
+func (l *capturingLogger) Info(context.Context, string, ...interface{}) {
+}
+func (l *capturingLogger) Warn(_ context.Context, msg string, _
...interface{}) {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+ l.warns = append(l.warns, msg)
+}
+func (l *capturingLogger) Error(context.Context, string, ...interface{}) {
+}
+func (l *capturingLogger) Trace(context.Context, time.Time, func() (string,
int64), error) {
+}
+
+func (l *capturingLogger) warnCount() int {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+ return len(l.warns)
+}
+
+func newCallbackDB(log logger.Interface) *gorm.DB {
+ db := &gorm.DB{Config: &gorm.Config{Logger: log}}
+ // Statement embeds *DB (promoted fields like Statement.Error resolve
+ // through it), the back-reference is mandatory like in gorm.Open
+ db.Statement = &gorm.Statement{DB: db, Table: "users"}
+ return db
+}
+
+// waitSpanCount polls until the reported span count reaches want (the segment
+// collection is asynchronous); raw sleeps flake under -race on slow runners.
+func waitSpanCount(t *testing.T, want int) {
+ t.Helper()
+ deadline := time.Now().Add(2 * time.Second)
+ for len(core.GetReportedSpans()) < want {
+ if time.Now().After(deadline) {
+ t.Fatalf("expected %d reported spans, got %d", want,
len(core.GetReportedSpans()))
+ }
+ time.Sleep(20 * time.Millisecond)
+ }
+}
+
+// TestAfterConsumesSpanAndClearsStatement covers the normal pair: the span is
+// stored per-statement, reported by the after callback, and a following
+// operation on the same statement must not see it as a leftover.
+func TestAfterConsumesSpanAndClearsStatement(t *testing.T) {
+ core.ResetTracingContext()
+ defer core.ResetTracingContext()
+
+ log := &capturingLogger{}
+ db := newCallbackDB(log)
+ before := beforeCallback(fakeDBInfo{}, "create")
+ after := afterCallback(fakeDBInfo{})
+
+ before(db)
+ after(db)
+
+ waitSpanCount(t, 1)
+ spans := core.GetReportedSpans()
+ if len(spans) != 1 {
+ t.Fatalf("expected one reported span, got %d", len(spans))
+ }
+ if spans[0].OperationName() != "users/create" {
+ t.Fatalf("unexpected operation name %s",
spans[0].OperationName())
+ }
+
+ // the same statement is reused for the next operation: no leftover
warning
+ before(db)
+ if log.warnCount() != 0 {
+ t.Fatalf("consumed span must not be reported as leftover: %v",
log.warns)
+ }
+ after(db)
+}
+
+// TestLeftoverSpanIsReported makes the cross-goroutine *gorm.DB sharing
+// misuse visible: two before callbacks on the same statement without an after
+// in between mean the first span is overwritten and lost.
+func TestLeftoverSpanIsReported(t *testing.T) {
+ core.ResetTracingContext()
+ defer core.ResetTracingContext()
+
+ log := &capturingLogger{}
+ db := newCallbackDB(log)
+ before := beforeCallback(fakeDBInfo{}, "create")
+
+ before(db)
+ before(db) // overwrites the unfinished span of the first operation
+
+ if log.warnCount() != 1 {
+ t.Fatalf("expected exactly one leftover warning, got %d (%v)",
log.warnCount(), log.warns)
+ }
+ if !strings.Contains(log.warns[0], "shared across goroutines") {
+ t.Fatalf("warning must explain the misuse: %q", log.warns[0])
+ }
+ // drain the second span so the next test starts clean
+ afterCallback(fakeDBInfo{})(db)
+}
+
+// TestSpanNotInheritedByClonedStatement pins the InstanceSet keying: gorm's
+// Statement.clone copies the plain Settings into every Session/Transaction
+// clone, which previously let a derived statement pick up - and end - the
+// OUTER operation's span. The instance key contains the Statement pointer, so
+// the clone must miss it.
+func TestSpanNotInheritedByClonedStatement(t *testing.T) {
+ core.ResetTracingContext()
+ defer core.ResetTracingContext()
+
+ log := &capturingLogger{}
+ db := newCallbackDB(log)
+ before := beforeCallback(fakeDBInfo{}, "create")
+ after := afterCallback(fakeDBInfo{})
+
+ before(db)
+
+ // simulate gorm's Statement.clone: a NEW statement with the Settings
+ // entries copied over (statement.go copies them one by one)
+ cloned := newCallbackDB(log)
+ db.Statement.Settings.Range(func(k, v interface{}) bool {
+ cloned.Statement.Settings.Store(k, v)
+ return true
+ })
+
+ after(cloned) // must NOT find - and end - the outer operation's span
+
+ // negative check keeps a fixed window: nothing may arrive at all
+ time.Sleep(100 * time.Millisecond)
+ if got := len(core.GetReportedSpans()); got != 0 {
+ t.Fatalf("the cloned statement must not end the outer span, got
%d reported", got)
+ }
+
+ after(db) // the real owner ends it
+ waitSpanCount(t, 1)
+ if got := len(core.GetReportedSpans()); got != 1 {
+ t.Fatalf("expected the outer span to be reported once, got %d",
got)
+ }
+}
diff --git a/plugins/grpc/client_finish_interceptor.go
b/plugins/grpc/client_finish_interceptor.go
index 3b5d900..1b06dd0 100644
--- a/plugins/grpc/client_finish_interceptor.go
+++ b/plugins/grpc/client_finish_interceptor.go
@@ -31,13 +31,9 @@ func (h *ClientFinishInterceptor) BeforeInvoke(invocation operator.Invocation) e
return nil
}
contextdata := csEnhanced.GetSkyWalkingDynamicField().(*contextData)
- if !contextdata.interceptFinish {
+ if !finishStreamSpan(contextdata) {
return nil
}
- contextdata.interceptFinish = false
- if contextdata.asyncSpan != nil {
- contextdata.asyncSpan.AsyncFinish()
- }
cs := invocation.CallerInstance().(*nativeclientStream)
method := cs.callHdr.Method
activeSpan := tracing.ActiveSpan()
@@ -51,3 +47,17 @@ func (h *ClientFinishInterceptor) BeforeInvoke(invocation operator.Invocation) e
func (h *ClientFinishInterceptor) AfterInvoke(invocation operator.Invocation, result ...interface{}) error {
return nil
}
+
+// finishStreamSpan consumes the one-shot finish flag armed by RecvMsg and, on
+// success, finishes the async stream span (if any) and returns true. It returns
+// false when the flag was never armed or another Finish already consumed it, so
+// racing Finish calls can never both run AsyncFinish for the same arming.
+func finishStreamSpan(contextdata *contextData) bool {
+ if !contextdata.interceptFinish.CompareAndSwap(true, false) {
+ return false
+ }
+ if contextdata.asyncSpan != nil {
+ contextdata.asyncSpan.AsyncFinish()
+ }
+ return true
+}
diff --git a/plugins/grpc/client_finish_interceptor_test.go b/plugins/grpc/client_finish_interceptor_test.go
new file mode 100644
index 0000000..fcb28d2
--- /dev/null
+++ b/plugins/grpc/client_finish_interceptor_test.go
@@ -0,0 +1,91 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package grpc
+
+import (
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/apache/skywalking-go/plugins/core"
+ "github.com/apache/skywalking-go/plugins/core/tracing"
+)
+
+// TestConcurrentFinishRunsAsyncFinishOnce replicates the streaming client
+// lifecycle (PrepareAsync + End in the streaming interceptor, finish armed by
+// RecvMsg) and lets several Finish calls race: exactly one may consume the
+// flag and run AsyncFinish, and the span must be reported exactly once.
+// The flag used to be a plain bool, which was both a data race (RecvMsg
+// goroutine vs gRPC-internal Finish goroutine) and a double-AsyncFinish risk.
+func TestConcurrentFinishRunsAsyncFinishOnce(t *testing.T) {
+ core.ResetTracingContext()
+ defer core.ResetTracingContext()
+
+ span, err := tracing.CreateExitSpan("/grpc.TestService/Streaming", "localhost:9000",
+ func(k, v string) error { return nil })
+ if err != nil {
+ t.Fatal(err)
+ }
+ span.PrepareAsync()
+ span.End()
+
+ data := &contextData{asyncSpan: span}
+ data.interceptFinish.Store(true) // RecvMsg armed the finish
+
+ const finishers = 8
+ var wg sync.WaitGroup
+ winners := make(chan bool, finishers)
+ for i := 0; i < finishers; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ winners <- finishStreamSpan(data)
+ }()
+ }
+ wg.Wait()
+ close(winners)
+
+ winnerCount := 0
+ for won := range winners {
+ if won {
+ winnerCount++
+ }
+ }
+ if winnerCount != 1 {
+ t.Fatalf("exactly one Finish may run AsyncFinish, got %d winners", winnerCount)
+ }
+
+ deadline := time.Now().Add(2 * time.Second)
+ for len(core.GetReportedSpans()) < 1 {
+ if time.Now().After(deadline) {
+ t.Fatal("async stream span was never reported")
+ }
+ time.Sleep(20 * time.Millisecond)
+ }
+ if got := len(core.GetReportedSpans()); got != 1 {
+ t.Fatalf("span must be reported exactly once, got %d", got)
+ }
+}
+
+// TestFinishWithoutRecvMsgIsNoop pins the unarmed case.
+func TestFinishWithoutRecvMsgIsNoop(t *testing.T) {
+ data := &contextData{}
+ if finishStreamSpan(data) {
+ t.Fatal("finish without a RecvMsg must not consume anything")
+ }
+}
diff --git a/plugins/grpc/client_recvmsg_interceptor.go b/plugins/grpc/client_recvmsg_interceptor.go
index 9be0c13..fff7de1 100644
--- a/plugins/grpc/client_recvmsg_interceptor.go
+++ b/plugins/grpc/client_recvmsg_interceptor.go
@@ -38,7 +38,7 @@ func (h *ClientRecvMsgInterceptor) BeforeInvoke(invocation operator.Invocation)
if ok && csEnhanced.GetSkyWalkingDynamicField() != nil {
contextdata :=
csEnhanced.GetSkyWalkingDynamicField().(*contextData)
tracing.ContinueContext(contextdata.continueSnapShot)
- contextdata.interceptFinish = true
+ contextdata.interceptFinish.Store(true)
}
s, err := tracing.CreateLocalSpan(formatOperationName(method,
"/Client/Response/RecvMsg"),
tracing.WithLayer(tracing.SpanLayerRPCFramework),
diff --git a/plugins/grpc/client_streaming_interceptor.go b/plugins/grpc/client_streaming_interceptor.go
index 10e86fd..c6d3a18 100644
--- a/plugins/grpc/client_streaming_interceptor.go
+++ b/plugins/grpc/client_streaming_interceptor.go
@@ -20,6 +20,7 @@ package grpc
import (
"context"
"strings"
+ "sync/atomic"
"google.golang.org/grpc/metadata"
@@ -39,8 +40,10 @@ type contextData struct {
// endSnapShot is the snapshot that the span has ended
// When the service is completely finished, it should be continued
endSnapShot tracing.ContextSnapshot
- // interceptFinish is whether to intercept finish()
- interceptFinish bool
+ // interceptFinish is whether to intercept finish(). RecvMsg writes it
on
+ // the user goroutine while Finish may consume it from a gRPC-internal
+ // goroutine, so it is atomic; the CAS consumer also makes finish
one-shot.
+ interceptFinish atomic.Bool
}
func (h *ClientStreamingInterceptor) BeforeInvoke(invocation
operator.Invocation) error {
@@ -87,7 +90,6 @@ func (h *ClientStreamingInterceptor) AfterInvoke(invocation operator.Invocation,
asyncSpan: span,
continueSnapShot: continueSnapShot,
endSnapShot: tracing.CaptureContext(),
- interceptFinish: false,
})
return nil
}
diff --git a/plugins/microv4/server/structure.go b/plugins/microv4/server/structure.go
index a4c89cc..4f37d0f 100644
--- a/plugins/microv4/server/structure.go
+++ b/plugins/microv4/server/structure.go
@@ -17,10 +17,17 @@
package server
-import "github.com/apache/skywalking-go/plugins/core/tracing"
+import (
+ "sync"
+
+ "github.com/apache/skywalking-go/plugins/core/tracing"
+)
//skywalking:ref_generate go-micro.dev/v4/util/socket InjectData
type InjectData struct {
Span tracing.Span
Snapshot tracing.ContextSnapshot
+ // finished mirrors the real definition in util/socket/accept_interceptor.go
+ // (the two declarations must stay structurally identical for ref_generate)
+ finished sync.Once
}
diff --git a/plugins/microv4/util/socket/accept_interceptor.go b/plugins/microv4/util/socket/accept_interceptor.go
index 844f76e..f4f936c 100644
--- a/plugins/microv4/util/socket/accept_interceptor.go
+++ b/plugins/microv4/util/socket/accept_interceptor.go
@@ -18,6 +18,8 @@
package socket
import (
+ "sync"
+
"github.com/apache/skywalking-go/plugins/core/operator"
"github.com/apache/skywalking-go/plugins/core/tracing"
)
@@ -26,6 +28,11 @@ import (
type InjectData struct {
Span tracing.Span
Snapshot tracing.ContextSnapshot
+ // finished makes the AsyncFinish of the connection span one-shot under
+ // concurrent Close calls (see close_interceptor.go). sync.Once on purpose:
+ // injected files may only import what go-micro's util/socket package
+ // already imports - "sync" via its pool.go, "sync/atomic" is unavailable.
+ finished sync.Once
}
type AcceptInterceptor struct {
diff --git a/plugins/microv4/util/socket/close_interceptor.go b/plugins/microv4/util/socket/close_interceptor.go
index 419de23..fcec53e 100644
--- a/plugins/microv4/util/socket/close_interceptor.go
+++ b/plugins/microv4/util/socket/close_interceptor.go
@@ -30,11 +30,16 @@ func (n *CloseInterceptor) BeforeInvoke(invocation operator.Invocation) error {
func (n *CloseInterceptor) AfterInvoke(invocation operator.Invocation, results
...interface{}) error {
instance := invocation.CallerInstance().(operator.EnhancedInstance)
- span := instance.GetSkyWalkingDynamicField()
- if span == nil {
+ data, ok := instance.GetSkyWalkingDynamicField().(*InjectData)
+ if !ok || data == nil {
return nil
}
- span.(*InjectData).Span.AsyncFinish()
- instance.SetSkyWalkingDynamicField(nil)
+ // one-shot under concurrent Close calls; the winner also clears the
+ // dynamic field so a socket reused for a new connection gets a fresh
+ // span instead of being blocked by the stale InjectData
+ data.finished.Do(func() {
+ data.Span.AsyncFinish()
+ instance.SetSkyWalkingDynamicField(nil)
+ })
return nil
}
diff --git a/plugins/microv4/util/socket/close_interceptor_test.go b/plugins/microv4/util/socket/close_interceptor_test.go
new file mode 100644
index 0000000..5e4f1df
--- /dev/null
+++ b/plugins/microv4/util/socket/close_interceptor_test.go
@@ -0,0 +1,107 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package socket
+
+import (
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/apache/skywalking-go/plugins/core"
+ "github.com/apache/skywalking-go/plugins/core/operator"
+ "github.com/apache/skywalking-go/plugins/core/tracing"
+)
+
+// fakeEnhancedInstance stands in for the toolchain-enhanced socket instance.
+// The fake synchronizes the field so the test isolates the one-shot guard
+// itself; the raciness of the real generated accessor is the known
+// toolchain-level dynamic-field issue, not what this test verifies.
+type fakeEnhancedInstance struct {
+ mu sync.Mutex
+ field interface{}
+}
+
+func (f *fakeEnhancedInstance) GetSkyWalkingDynamicField() interface{} {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ return f.field
+}
+
+func (f *fakeEnhancedInstance) SetSkyWalkingDynamicField(val interface{}) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.field = val
+}
+
+// TestConcurrentCloseFinishesOnce replicates two racing Close calls on the
+// same connection: both read a non-nil dynamic field, but the one-shot guard
+// lets only one of them run AsyncFinish (a double AsyncFinish used to panic
+// before the core made it drop-and-log; now it cannot happen at all).
+func TestConcurrentCloseFinishesOnce(t *testing.T) {
+ core.ResetTracingContext()
+ defer core.ResetTracingContext()
+
+ span, err := tracing.CreateEntrySpan("micro/connection", func(string) (string, error) { return "", nil })
+ if err != nil {
+ t.Fatal(err)
+ }
+ span.PrepareAsync()
+ snapshot := tracing.CaptureContext()
+ span.End()
+
+ instance := &fakeEnhancedInstance{field: &InjectData{Span: span, Snapshot: snapshot}}
+ interceptor := &CloseInterceptor{}
+
+ const closers = 8
+ var wg sync.WaitGroup
+ for i := 0; i < closers; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if err := interceptor.AfterInvoke(operator.NewInvocation(instance)); err != nil {
+ t.Errorf("close interceptor error: %v", err)
+ }
+ }()
+ }
+ wg.Wait()
+
+ deadline := time.Now().Add(2 * time.Second)
+ for len(core.GetReportedSpans()) < 1 {
+ if time.Now().After(deadline) {
+ t.Fatal("connection span was never reported")
+ }
+ time.Sleep(20 * time.Millisecond)
+ }
+ if got := len(core.GetReportedSpans()); got != 1 {
+ t.Fatalf("connection span must be reported exactly once, got %d", got)
+ }
+ if instance.GetSkyWalkingDynamicField() != nil {
+ t.Fatal("the winner must clear the dynamic field so a reused socket gets a fresh span")
+ }
+}
+
+// TestCloseWithoutInjectDataIsNoop keeps the nil/foreign dynamic-field guard.
+func TestCloseWithoutInjectDataIsNoop(t *testing.T) {
+ interceptor := &CloseInterceptor{}
+ if err := interceptor.AfterInvoke(operator.NewInvocation(&fakeEnhancedInstance{})); err != nil {
+ t.Fatal(err)
+ }
+ if err := interceptor.AfterInvoke(operator.NewInvocation(&fakeEnhancedInstance{field: "not-inject-data"})); err != nil {
+ t.Fatal(err)
+ }
+}
diff --git a/plugins/mongo/mongo/interceptor.go b/plugins/mongo/mongo/interceptor.go
index 94fd944..d6b115c 100644
--- a/plugins/mongo/mongo/interceptor.go
+++ b/plugins/mongo/mongo/interceptor.go
@@ -79,6 +79,11 @@ func (m *NewClientInterceptor) BeforeInvoke(invocation operator.Invocation) erro
span.Tag(tracing.TagDBStatement,
m.gettingStatements(startedEvent))
}
+ // Succeeded/Failed may fire on a DIFFERENT
goroutine, so the
+ // completion goes through the async machinery;
End() also pops
+ // the span off this goroutine's active stack
immediately.
+ span.PrepareAsync()
+ span.End()
syncMap.Put(fmt.Sprintf("%d",
startedEvent.RequestID), span)
},
Succeeded: func(ctx context.Context, succeededEvent
*event.CommandSucceededEvent) {
@@ -86,7 +91,7 @@ func (m *NewClientInterceptor) BeforeInvoke(invocation operator.Invocation) erro
configuredMonitor.Succeeded(ctx,
succeededEvent)
}
if span, ok := syncMap.Remove(fmt.Sprintf("%d",
succeededEvent.RequestID)); ok && span != nil {
- span.(tracing.Span).End()
+ span.(tracing.Span).AsyncFinish()
}
},
Failed: func(ctx context.Context, failedEvent
*event.CommandFailedEvent) {
@@ -95,7 +100,7 @@ func (m *NewClientInterceptor) BeforeInvoke(invocation operator.Invocation) erro
}
if span, ok := syncMap.Remove(fmt.Sprintf("%d",
failedEvent.RequestID)); ok && span != nil {
span.(tracing.Span).Error(failedEvent.Failure)
- span.(tracing.Span).End()
+ span.(tracing.Span).AsyncFinish()
}
},
}
diff --git a/plugins/mongo/mongo/interceptor_test.go b/plugins/mongo/mongo/interceptor_test.go
new file mode 100644
index 0000000..fca514c
--- /dev/null
+++ b/plugins/mongo/mongo/interceptor_test.go
@@ -0,0 +1,144 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package mongo
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "go.mongodb.org/mongo-driver/event"
+ "go.mongodb.org/mongo-driver/mongo/options"
+
+ "github.com/apache/skywalking-go/plugins/core"
+ "github.com/apache/skywalking-go/plugins/core/operator"
+ "github.com/apache/skywalking-go/plugins/core/reporter"
+ "github.com/apache/skywalking-go/plugins/core/tracing"
+)
+
+func installMonitor(t *testing.T) *event.CommandMonitor {
+ t.Helper()
+ opts := []*options.ClientOptions{{Hosts: []string{"127.0.0.1:27017"}}}
+ interceptor := &NewClientInterceptor{}
+ if err := interceptor.BeforeInvoke(operator.NewInvocation(nil, opts)); err != nil {
+ t.Fatal(err)
+ }
+ if opts[0].Monitor == nil {
+ t.Fatal("command monitor was not installed")
+ }
+ return opts[0].Monitor
+}
+
+func waitOneSpan(t *testing.T) reporter.ReportedSpan {
+ t.Helper()
+ deadline := time.Now().Add(2 * time.Second)
+ for {
+ if spans := core.GetReportedSpans(); len(spans) >= 1 {
+ return spans[0]
+ }
+ if time.Now().After(deadline) {
+ t.Fatal("span was never reported")
+ }
+ time.Sleep(20 * time.Millisecond)
+ }
+}
+
+// TestCommandFinishedOnAnotherGoroutine replicates the SDAM/monitor topology:
+// Started fires on goroutine A, Failed on goroutine B. The span completion
+// must go through the async machinery (a plain End from B used to leave A's
+// active stack pointing at a span another goroutine handed to the reporter).
+func TestCommandFinishedOnAnotherGoroutine(t *testing.T) {
+ core.ResetTracingContext()
+ defer core.ResetTracingContext()
+
+ monitor := installMonitor(t)
+
+ started := make(chan struct{})
+ go func() { // goroutine A: the operation goroutine
+ defer close(started)
+ monitor.Started(context.Background(),
&event.CommandStartedEvent{
+ CommandName: "find",
+ RequestID: 42,
+ ConnectionID: "127.0.0.1:27017[-4]",
+ })
+ // the handed-over span must NOT stay active on this goroutine:
+ // whatever the application starts next must not chain onto it
+ if tracing.ActiveSpan() != nil {
+ t.Error("the mongo span must be popped off the active
stack after Started")
+ }
+ }()
+ <-started
+
+ failed := make(chan struct{})
+ go func() { // goroutine B: a different SDK goroutine completes the
command
+ defer close(failed)
+ monitor.Failed(context.Background(), &event.CommandFailedEvent{
+ CommandFinishedEvent:
event.CommandFinishedEvent{RequestID: 42},
+ Failure: "network error",
+ })
+ }()
+ <-failed
+
+ span := waitOneSpan(t)
+ if span.OperationName() != "MongoDB/find" {
+ t.Fatalf("unexpected operation name %s", span.OperationName())
+ }
+ if !span.IsError() {
+ t.Fatal("a failed command must be an error span")
+ }
+ if span.EndTime() <= 0 {
+ t.Fatal("span must carry the completion time")
+ }
+}
+
+// TestCommandSucceededReportsOnce covers the happy path plus the
+// unknown-request guard.
+func TestCommandSucceededReportsOnce(t *testing.T) {
+ core.ResetTracingContext()
+ defer core.ResetTracingContext()
+
+ monitor := installMonitor(t)
+
+ monitor.Started(context.Background(), &event.CommandStartedEvent{
+ CommandName: "insert",
+ RequestID: 7,
+ })
+ // completion of a request the agent never saw must be a no-op
+ monitor.Succeeded(context.Background(), &event.CommandSucceededEvent{
+ CommandFinishedEvent: event.CommandFinishedEvent{RequestID:
999},
+ })
+ monitor.Succeeded(context.Background(), &event.CommandSucceededEvent{
+ CommandFinishedEvent: event.CommandFinishedEvent{RequestID: 7},
+ })
+ // a duplicated completion must be a no-op as well (the map entry is
gone)
+ monitor.Succeeded(context.Background(), &event.CommandSucceededEvent{
+ CommandFinishedEvent: event.CommandFinishedEvent{RequestID: 7},
+ })
+
+ span := waitOneSpan(t)
+ if span.OperationName() != "MongoDB/insert" {
+ t.Fatalf("unexpected operation name %s", span.OperationName())
+ }
+ if span.IsError() {
+ t.Fatal("a succeeded command must not be an error span")
+ }
+ time.Sleep(100 * time.Millisecond)
+ if got := len(core.GetReportedSpans()); got != 1 {
+ t.Fatalf("the command span must be reported exactly once, got
%d", got)
+ }
+}
diff --git a/plugins/mux/serve_interceptor.go b/plugins/mux/serve_interceptor.go
index 01b2dee..6535b0c 100644
--- a/plugins/mux/serve_interceptor.go
+++ b/plugins/mux/serve_interceptor.go
@@ -68,7 +68,7 @@ func newResponseWriter(val interface{}) http.ResponseWriter {
case http.Hijacker:
rw = newWriterWrapperWithHijacker(sourceWriter,
sourceWriter.(http.Hijacker))
default:
- rw = newWriterWrapper(rw)
+ rw = newWriterWrapper(sourceWriter)
}
return rw
}
diff --git a/plugins/mux/serve_interceptor_test.go b/plugins/mux/serve_interceptor_test.go
new file mode 100644
index 0000000..3eee29c
--- /dev/null
+++ b/plugins/mux/serve_interceptor_test.go
@@ -0,0 +1,102 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package mux
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/apache/skywalking-go/plugins/core"
+ "github.com/apache/skywalking-go/plugins/core/operator"
+ "github.com/apache/skywalking-go/plugins/core/tracing"
+)
+
+// TestNonHijackerWriterIsUsable pins the nil-writer bug: an
http.ResponseWriter
+// that does not implement http.Hijacker (e.g. HTTP/2) took the default branch,
+// which wrapped a nil writer - the first Write of the user handler then
+// crashed with a nil dereference.
+func TestNonHijackerWriterIsUsable(t *testing.T) {
+ recorder := httptest.NewRecorder() // does not implement http.Hijacker
+
+ rw := newResponseWriter(recorder)
+ rw.WriteHeader(http.StatusNotFound)
+ if _, err := rw.Write([]byte("not found")); err != nil {
+ t.Fatal(err)
+ }
+
+ wrapped, ok := rw.(*writerWrapper)
+ if !ok {
+ t.Fatalf("non-hijacker writer must use the plain wrapper, got
%T", rw)
+ }
+ if wrapped.statusCode != http.StatusNotFound {
+ t.Fatalf("status code was not captured, got %d",
wrapped.statusCode)
+ }
+ if recorder.Body.String() != "not found" {
+ t.Fatalf("response body was not written through, got %q",
recorder.Body.String())
+ }
+}
+
+// TestServeInterceptorWithNonHijackerWriter runs the full interceptor pair on
+// a non-hijacker writer and checks the reported span carries the status code.
+func TestServeInterceptorWithNonHijackerWriter(t *testing.T) {
+ core.ResetTracingContext()
+ defer core.ResetTracingContext()
+
+ request, err := http.NewRequest(http.MethodGet, "http://localhost/api/users", http.NoBody)
+ if err != nil {
+ t.Fatal(err)
+ }
+ recorder := httptest.NewRecorder()
+
+ interceptor := &ServeHTTPInterceptor{}
+ invocation := operator.NewInvocation(nil, recorder, request)
+ if err := interceptor.BeforeInvoke(invocation); err != nil {
+ t.Fatal(err)
+ }
+
+ // the user handler writes through the (previously nil) wrapped writer
+ rw := invocation.Args()[0].(http.ResponseWriter)
+ rw.WriteHeader(http.StatusCreated)
+
+ if err := interceptor.AfterInvoke(invocation); err != nil {
+ t.Fatal(err)
+ }
+
+ deadline := time.Now().Add(2 * time.Second)
+ for len(core.GetReportedSpans()) < 1 {
+ if time.Now().After(deadline) {
+ t.Fatal("span was never reported")
+ }
+ time.Sleep(20 * time.Millisecond)
+ }
+ spans := core.GetReportedSpans()
+ if len(spans) != 1 {
+ t.Fatalf("expected one reported span, got %d", len(spans))
+ }
+ statusCode := ""
+ for _, tag := range spans[0].Tags() {
+ if tag.Key == tracing.TagStatusCode {
+ statusCode = tag.Value
+ }
+ }
+ if statusCode != "201" {
+ t.Fatalf("status code tag mismatch: %q", statusCode)
+ }
+}
diff --git a/plugins/pulsar/pulsar/send_async_producer.go b/plugins/pulsar/pulsar/send_async_producer.go
index cd91091..1ec3a7c 100644
--- a/plugins/pulsar/pulsar/send_async_producer.go
+++ b/plugins/pulsar/pulsar/send_async_producer.go
@@ -65,29 +65,11 @@ func (s *SendAsyncInterceptor) BeforeInvoke(invocation operator.Invocation) erro
continueSnapShot := tracing.CaptureContext()
zuper := invocation.Args()[2].(func(id MessageID, message
*ProducerMessage, err error))
+ // enhance async callback method: the agent part is fully isolated
inside
+ // traceAsyncSendCallback (see its doc), the user callback runs after it
callbackFunc := func(id MessageID, message *ProducerMessage, err error)
{
- defer tracing.CleanContext()
- tracing.ContinueContext(continueSnapShot)
- operationName = pulsarAsyncPrefix + topic + pulsarCallbackSuffix
-
- localSpan, localErr := tracing.CreateLocalSpan(operationName,
- tracing.WithComponent(pulsarAsyncComponentID),
- tracing.WithLayer(tracing.SpanLayerMQ),
- tracing.WithTag(tracing.TagMQTopic,
nativeProducer.topic),
- )
- if localErr != nil {
- zuper(id, message, err)
- return
- }
- if err != nil {
- span.Error(err.Error())
- }
- localSpan.Tag(tracing.TagMQBroker, lookup.PhysicalAddr.String())
- localSpan.Tag(tracing.TagMQMsgID, id.String())
-
+ traceAsyncSendCallback(continueSnapShot, topic,
nativeProducer.topic, peer, lookup.PhysicalAddr.String(), id, err)
zuper(id, message, err)
- localSpan.SetPeer(peer)
- localSpan.End()
}
span.SetPeer(peer)
@@ -96,6 +78,37 @@ func (s *SendAsyncInterceptor) BeforeInvoke(invocation operator.Invocation) erro
return nil
}
+// traceAsyncSendCallback records the async send result on a NEW local span -
+// never on the exit span, already ended by AfterInvoke. It runs on an SDK
+// goroutine without framework recover, so the agent logic is fully wrapped in
+// its own recover; the user callback runs outside, never swallowed.
+func traceAsyncSendCallback(snapshot tracing.ContextSnapshot, opTopic, tagTopic, peer, broker string, id MessageID, sendErr error) {
+ defer tracing.CleanContext()
+ defer func() {
+ // no logging channel exists on this goroutine, drop on purpose
+ _ = recover()
+ }()
+ tracing.ContinueContext(snapshot)
+
+ localSpan, err := tracing.CreateLocalSpan(pulsarAsyncPrefix+opTopic+pulsarCallbackSuffix,
+ tracing.WithComponent(pulsarAsyncComponentID),
+ tracing.WithLayer(tracing.SpanLayerMQ),
+ tracing.WithTag(tracing.TagMQTopic, tagTopic),
+ )
+ if err != nil {
+ return
+ }
+ if sendErr != nil {
+ localSpan.Error(sendErr.Error())
+ }
+ localSpan.Tag(tracing.TagMQBroker, broker)
+ if id != nil { // nil when the send failed
+ localSpan.Tag(tracing.TagMQMsgID, id.String())
+ }
+ localSpan.SetPeer(peer)
+ localSpan.End()
+}
+
func (s *SendAsyncInterceptor) AfterInvoke(invocation operator.Invocation,
result ...interface{}) error {
if invocation.GetContext() == nil {
return nil
diff --git a/plugins/pulsar/pulsar/send_async_producer_test.go b/plugins/pulsar/pulsar/send_async_producer_test.go
new file mode 100644
index 0000000..8123e09
--- /dev/null
+++ b/plugins/pulsar/pulsar/send_async_producer_test.go
@@ -0,0 +1,126 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package pulsar
+
+import (
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/apache/skywalking-go/plugins/core"
+ "github.com/apache/skywalking-go/plugins/core/reporter"
+ "github.com/apache/skywalking-go/plugins/core/tracing"
+)
+
+type fakeMessageID struct{ id string }
+
+func (f *fakeMessageID) String() string { return f.id }
+
+func findReportedSpan(t *testing.T, name string) reporter.ReportedSpan {
+ t.Helper()
+ deadline := time.Now().Add(2 * time.Second)
+ for {
+ for _, s := range core.GetReportedSpans() {
+ if s.OperationName() == name {
+ return s
+ }
+ }
+ if time.Now().After(deadline) {
+ t.Fatalf("span %q was never reported", name)
+ }
+ time.Sleep(20 * time.Millisecond)
+ }
+}
+
+func reportedTagValue(s reporter.ReportedSpan, key string) (string, bool) {
+ for _, tag := range s.Tags() {
+ if tag.Key == key {
+ return tag.Value, true
+ }
+ }
+ return "", false
+}
+
+// TestAsyncCallbackFailedSendIsSafe pins the failed-send branch: the message
+// id is nil (the old code called id.String() and killed the process - no
+// recover exists on SDK goroutines) and the error must land on the CALLBACK
+// local span, never on the already-ended exit span.
+func TestAsyncCallbackFailedSendIsSafe(t *testing.T) {
+ core.ResetTracingContext()
+ defer core.ResetTracingContext()
+
+ span, err := tracing.CreateExitSpan("Pulsar/persistent://public/default/t1/AsyncProducer", "broker:6650",
+ func(k, v string) error { return nil })
+ if err != nil {
+ t.Fatal(err)
+ }
+ snapshot := tracing.CaptureContext()
+ span.End() // AfterInvoke ends the exit span before the callback fires
+
+ done := make(chan struct{})
+ go func() { // the SDK callback goroutine
+ defer close(done)
+ traceAsyncSendCallback(snapshot, "t1", "persistent://public/default/t1",
+ "broker:6650", "broker:6650", nil, errors.New("send failed"))
+ }()
+ <-done
+
+ local := findReportedSpan(t, "Pulsar/t1/Producer/Callback")
+ if !local.IsError() {
+ t.Fatal("send error must be recorded on the callback local span")
+ }
+ if _, ok := reportedTagValue(local, tracing.TagMQMsgID); ok {
+ t.Fatal("failed send must not carry a message id tag")
+ }
+ if v, _ := reportedTagValue(local, tracing.TagMQTopic); v != "persistent://public/default/t1" {
+ t.Fatalf("topic tag mismatch: %q", v)
+ }
+}
+
+// TestAsyncCallbackSuccessTagsResult covers the happy path tag set.
+func TestAsyncCallbackSuccessTagsResult(t *testing.T) {
+ core.ResetTracingContext()
+ defer core.ResetTracingContext()
+
+ span, err := tracing.CreateExitSpan("Pulsar/t1/AsyncProducer", "broker:6650",
+ func(k, v string) error { return nil })
+ if err != nil {
+ t.Fatal(err)
+ }
+ snapshot := tracing.CaptureContext()
+ span.End()
+
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ traceAsyncSendCallback(snapshot, "t1", "t1", "broker:6650", "broker:6650",
+ &fakeMessageID{id: "ledger:1:entry:2"}, nil)
+ }()
+ <-done
+
+ local := findReportedSpan(t, "Pulsar/t1/Producer/Callback")
+ if local.IsError() {
+ t.Fatal("successful send must not be an error span")
+ }
+ if v, _ := reportedTagValue(local, tracing.TagMQMsgID); v != "ledger:1:entry:2" {
+ t.Fatalf("msg id tag mismatch: %q", v)
+ }
+ if v, _ := reportedTagValue(local, tracing.TagMQBroker); v != "broker:6650" {
+ t.Fatalf("broker tag mismatch: %q", v)
+ }
+}
diff --git a/plugins/rocketmq/consumer/consumer.go b/plugins/rocketmq/consumer/consumer.go
index 18cb80c..5ce5caf 100644
--- a/plugins/rocketmq/consumer/consumer.go
+++ b/plugins/rocketmq/consumer/consumer.go
@@ -43,34 +43,55 @@ func (c *SwConsumerInterceptor) BeforeInvoke(invocation
operator.Invocation) err
pushConsumer := invocation.CallerInstance().(*nativepushConsumer)
peer := strings.Join(pushConsumer.client.GetNameSrv().AddrList(),
semicolon)
subMsgs := invocation.Args()[1].([]*primitive.MessageExt)
+ span, err := createConsumerEntrySpan(subMsgs, peer)
+ if err != nil || span == nil {
+ return err
+ }
+ invocation.SetContext(span)
+ return nil
+}
+
+// createConsumerEntrySpan creates ONE entry span for the whole batch from the
+// first message and attaches every remaining message as an extra segment
+// reference, mirroring the Java agent. One span per message must be avoided:
+// the reuse rule would hand back the same span N times while AfterInvoke
+// calls End only once, so the span would never be reported.
+//
+// An empty batch yields (nil, nil) and no span is created. peer is recorded
+// on the span via SetPeer.
+func createConsumerEntrySpan(subMsgs []*primitive.MessageExt, peer string)
(tracing.Span, error) {
if len(subMsgs) == 0 {
- return nil
+ return nil, nil
+ }
+ // the first message drives the operation name, the topic tag, the
+ // propagated context of the entry span and the broker tag
+ first := subMsgs[0]
+ topic := first.Topic
+ msgIDs := make([]string, 0, len(subMsgs))
+ offsetMsgIDs := make([]string, 0, len(subMsgs))
+ // ids of every message in the batch are collected and stored as
+ // semicolon-joined tags on the single span
+ for _, msg := range subMsgs {
+ msgIDs = append(msgIDs, msg.MsgId)
+ offsetMsgIDs = append(offsetMsgIDs, msg.OffsetMsgId)
}
- topic, addr := subMsgs[0].Topic, subMsgs[0].StoreHost
- operationName := rmqConsumerPrefix + topic + rmqConsumerSuffix
- var (
- span tracing.Span
- err error
+ span, err :=
tracing.CreateEntrySpan(rmqConsumerPrefix+topic+rmqConsumerSuffix,
func(headerKey string) (string, error) {
+ return first.GetProperty(headerKey), nil
+ },
+ tracing.WithLayer(tracing.SpanLayerMQ),
+ tracing.WithComponent(rmqConsumerComponentID),
+ tracing.WithTag(tracing.TagMQTopic, topic),
+ tracing.WithTag(tagMQMsgID, strings.Join(msgIDs, semicolon)),
+ tracing.WithTag(tagMQOffsetMsgID, strings.Join(offsetMsgIDs,
semicolon)),
)
- for _, msg := range subMsgs {
- span, err = tracing.CreateEntrySpan(operationName,
func(headerKey string) (string, error) {
- return msg.GetProperty(headerKey), nil
- },
- tracing.WithLayer(tracing.SpanLayerMQ),
- tracing.WithComponent(rmqConsumerComponentID),
- tracing.WithTag(tracing.TagMQTopic, topic),
- tracing.WithTag(tagMQMsgID, msg.MsgId),
- tracing.WithTag(tagMQOffsetMsgID, msg.OffsetMsgId),
- )
- if err != nil {
- return err
- }
+ if err != nil {
+ return nil, err
}
- span.Tag(tracing.TagMQBroker, addr)
+ // each remaining message contributes its upstream context as an extra
+ // segment reference on the same span
+ for _, msg := range subMsgs[1:] {
+ // per-iteration copy captured by the closure below (redundant under
+ // Go 1.22+ loop-variable semantics, harmless otherwise)
+ extractMsg := msg
+ // a broken header on a single message must not lose the batch span,
+ // so the error is intentionally ignored
+ _ = tracing.ExtractContext(func(headerKey string) (string,
error) {
+ return extractMsg.GetProperty(headerKey), nil
+ })
+ }
+ // NOTE(review): the broker tag uses the first message's StoreHost only;
+ // this assumes the whole batch comes from one broker - confirm
+ span.Tag(tracing.TagMQBroker, first.StoreHost)
span.SetPeer(peer)
- invocation.SetContext(span)
- return nil
+ return span, nil
}
func (c *SwConsumerInterceptor) AfterInvoke(invocation operator.Invocation,
result ...interface{}) error {
diff --git a/plugins/rocketmq/consumer/consumer_test.go
b/plugins/rocketmq/consumer/consumer_test.go
new file mode 100644
index 0000000..cf43f4e
--- /dev/null
+++ b/plugins/rocketmq/consumer/consumer_test.go
@@ -0,0 +1,138 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package consumer
+
+import (
+ "testing"
+ "time"
+
+ "github.com/apache/rocketmq-client-go/v2/consumer"
+ "github.com/apache/rocketmq-client-go/v2/primitive"
+
+ "github.com/apache/skywalking-go/plugins/core"
+ "github.com/apache/skywalking-go/plugins/core/operator"
+ "github.com/apache/skywalking-go/plugins/core/reporter"
+)
+
+// newTestMessage builds a consumed message for topic whose core.Header
+// property carries the sw8-encoded context of the given upstream trace and
+// segment, the way a producer-side agent would have stamped it.
+func newTestMessage(topic, msgID, traceID, segmentID string) *primitive.MessageExt {
+	msg := &primitive.MessageExt{MsgId: msgID, OffsetMsgId: "off-" + msgID}
+	msg.Topic = topic
+
+	upstream := core.SpanContext{
+		Sample:                1,
+		TraceID:               traceID,
+		ParentSegmentID:       segmentID,
+		ParentSpanID:          0,
+		ParentService:         "producer-service",
+		ParentServiceInstance: "producer-instance",
+		ParentEndpoint:        "/producer/send",
+		AddressUsedAtClient:   "mq.svc:9876",
+	}
+	msg.WithProperty(core.Header, upstream.EncodeSW8())
+	return msg
+}
+
+// waitReportedSpans polls the reporter every 20ms for up to two seconds and
+// returns the reported spans as soon as at least want of them are present,
+// failing the test on timeout.
+func waitReportedSpans(t *testing.T, want int) []reporter.ReportedSpan {
+	t.Helper()
+	deadline := time.Now().Add(2 * time.Second)
+	for spans := core.GetReportedSpans(); ; spans = core.GetReportedSpans() {
+		if len(spans) >= want {
+			return spans
+		}
+		if time.Now().After(deadline) {
+			t.Fatalf("expected %d reported spans, got %d", want, len(spans))
+		}
+		time.Sleep(20 * time.Millisecond)
+	}
+}
+
+// tagValue returns the value of the first tag named key on s, or "" when
+// the span carries no such tag.
+func tagValue(s reporter.ReportedSpan, key string) string {
+	tags := s.Tags()
+	for i := range tags {
+		if tags[i].Key == key {
+			return tags[i].Value
+		}
+	}
+	return ""
+}
+
+// TestBatchConsumeReportsSingleSpanWithAllRefs pins the Java-aligned batch
+// semantics: a batch produces ONE entry span that is reported exactly once and
+// carries one segment reference per message, in message order. The former
+// per-message CreateEntrySpan loop reused the span N times but ended it once,
+// so a multi-message batch was never reported.
+func TestBatchConsumeReportsSingleSpanWithAllRefs(t *testing.T) {
+	core.ResetTracingContext()
+	defer core.ResetTracingContext()
+
+	upstream := []struct{ msgID, traceID, segmentID string }{
+		{"m1", "11d1aaaaaaaaaaaaaaaaaaaaaaaaaaaa", "11c1aaaaaaaaaaaaaaaaaaaaaaaaaaaa"},
+		{"m2", "22d2bbbbbbbbbbbbbbbbbbbbbbbbbbbb", "22c2bbbbbbbbbbbbbbbbbbbbbbbbbbbb"},
+		{"m3", "33d3cccccccccccccccccccccccccccc", "33c3cccccccccccccccccccccccccccc"},
+	}
+	msgs := make([]*primitive.MessageExt, 0, len(upstream))
+	for _, u := range upstream {
+		msgs = append(msgs, newTestMessage("TopicTest", u.msgID, u.traceID, u.segmentID))
+	}
+
+	span, err := createConsumerEntrySpan(msgs, "nameserver:9876")
+	switch {
+	case err != nil:
+		t.Fatal(err)
+	case span == nil:
+		t.Fatal("no span created for non-empty batch")
+	}
+
+	invocation := operator.NewInvocation(nil)
+	invocation.SetContext(span)
+	if err := (&SwConsumerInterceptor{}).AfterInvoke(invocation, consumer.ConsumeSuccess, nil); err != nil {
+		t.Fatal(err)
+	}
+
+	spans := waitReportedSpans(t, 1)
+	if got := len(spans); got != 1 {
+		t.Fatalf("batch must report exactly one span, got %d", got)
+	}
+	reported := spans[0]
+	if name := reported.OperationName(); name != "RocketMQ/TopicTest/Consumer" {
+		t.Fatalf("unexpected operation name %s", name)
+	}
+
+	refs := reported.Refs()
+	if len(refs) != len(upstream) {
+		t.Fatalf("every message of the batch must become a segment reference, got %d", len(refs))
+	}
+	gotTraces := make([]string, len(refs))
+	for i, ref := range refs {
+		gotTraces[i] = ref.GetTraceID()
+	}
+	for i, u := range upstream {
+		if gotTraces[i] != u.traceID {
+			t.Fatalf("refs do not carry the upstream traces in order: %v", gotTraces)
+		}
+	}
+
+	if got := tagValue(reported, tagMQMsgID); got != "m1;m2;m3" {
+		t.Fatalf("message ids must be aggregated, got %q", got)
+	}
+	if got := tagValue(reported, tagMQOffsetMsgID); got != "off-m1;off-m2;off-m3" {
+		t.Fatalf("offset message ids must be aggregated, got %q", got)
+	}
+}
+
+// TestEmptyBatchCreatesNoSpan checks that an empty consume callback neither
+// creates a span nor reports an error.
+func TestEmptyBatchCreatesNoSpan(t *testing.T) {
+	core.ResetTracingContext()
+	defer core.ResetTracingContext()
+
+	if span, err := createConsumerEntrySpan(nil, "nameserver:9876"); err != nil || span != nil {
+		t.Fatalf("empty batch must be a no-op, span=%v err=%v", span, err)
+	}
+}
diff --git a/plugins/rocketmq/producer/async_producer.go
b/plugins/rocketmq/producer/async_producer.go
index 1ef0237..58b5635 100644
--- a/plugins/rocketmq/producer/async_producer.go
+++ b/plugins/rocketmq/producer/async_producer.go
@@ -63,34 +63,13 @@ func (sa *SendASyncInterceptor) BeforeInvoke(invocation
operator.Invocation) err
continueSnapShot := tracing.CaptureContext()
zuper := invocation.Args()[1].(func(ctx context.Context, result
*primitive.SendResult, err error))
- // enhance async callback method
+ // enhance async callback method: the agent part is fully isolated
inside
+ // traceAsyncSendCallback (see its doc), the user callback runs after it
callbackFunc := func(ctx context.Context, sendResult
*primitive.SendResult, err error) {
- defer tracing.CleanContext()
- tracing.ContinueContext(continueSnapShot)
- operationName = rmqASyncSendPrefix + topic + rmqCallbackSuffix
-
- localSpan, localErr := tracing.CreateLocalSpan(operationName,
- tracing.WithComponent(rmqASyncComponentID),
- tracing.WithLayer(tracing.SpanLayerMQ),
- tracing.WithTag(tracing.TagMQTopic, topic),
- )
- if localErr != nil {
- zuper(ctx, sendResult, err)
- return
- }
- if err != nil {
- span.Error(err.Error())
- }
- localSpan.Tag(tracing.TagMQStatus,
SendStatusStr(sendResult.Status))
- localSpan.Tag(tracing.TagMQQueue, fmt.Sprintf("%d",
sendResult.MessageQueue.QueueId))
- localSpan.Tag(tracing.TagMQBroker,
defaultProducer.client.GetNameSrv().
-
FindBrokerAddrByName(sendResult.MessageQueue.BrokerName))
- localSpan.Tag(tracing.TagMQMsgID, sendResult.MsgID)
- localSpan.Tag(aSyncTagMQOffsetMsgID, sendResult.OffsetMsgID)
-
+ traceAsyncSendCallback(continueSnapShot, topic, peer,
sendResult, err, func(brokerName string) string {
+ return
defaultProducer.client.GetNameSrv().FindBrokerAddrByName(brokerName)
+ })
zuper(ctx, sendResult, err)
- localSpan.SetPeer(peer)
- localSpan.End()
}
span.SetPeer(peer)
@@ -99,6 +78,43 @@ func (sa *SendASyncInterceptor) BeforeInvoke(invocation
operator.Invocation) err
return nil
}
+// traceAsyncSendCallback records the async send result on a NEW local span -
+// never on the exit span, already ended by AfterInvoke. It runs on an SDK
+// goroutine without framework recover, so the agent logic is fully wrapped in
+// its own recover; the user callback runs outside, never swallowed.
+//
+// If the agent logic panics after the local span was created (for example in
+// the brokerAddr lookup), the span is still ended on the recovery path: a span
+// that is never ended is never reported.
+//
+// snapshot is the context captured before the send; topic and peer describe
+// the send; sendResult is nil when the send failed and sendErr then holds the
+// SDK error; brokerAddr resolves a broker name to the address used as tag.
+func traceAsyncSendCallback(snapshot tracing.ContextSnapshot, topic, peer string,
+	sendResult *primitive.SendResult, sendErr error, brokerAddr func(brokerName string) string) {
+	// registered first so it runs last, after the span is closed
+	defer tracing.CleanContext()
+
+	var (
+		localSpan tracing.Span
+		ended     bool
+	)
+	defer func() {
+		if r := recover(); r == nil {
+			return
+		}
+		// no logging channel exists on this goroutine, so the panic value is
+		// dropped on purpose; only a span that is still open gets closed
+		if localSpan == nil || ended {
+			return
+		}
+		// End itself must never panic out of the recovery path
+		defer func() { _ = recover() }()
+		localSpan.End()
+	}()
+	tracing.ContinueContext(snapshot)
+
+	var err error
+	localSpan, err = tracing.CreateLocalSpan(rmqASyncSendPrefix+topic+rmqCallbackSuffix,
+		tracing.WithComponent(rmqASyncComponentID),
+		tracing.WithLayer(tracing.SpanLayerMQ),
+		tracing.WithTag(tracing.TagMQTopic, topic),
+	)
+	if err != nil {
+		return
+	}
+	if sendErr != nil {
+		localSpan.Error(sendErr.Error())
+	}
+	if sendResult != nil { // nil when the send failed
+		localSpan.Tag(tracing.TagMQStatus, SendStatusStr(sendResult.Status))
+		if sendResult.MessageQueue != nil {
+			localSpan.Tag(tracing.TagMQQueue, fmt.Sprintf("%d", sendResult.MessageQueue.QueueId))
+			localSpan.Tag(tracing.TagMQBroker, brokerAddr(sendResult.MessageQueue.BrokerName))
+		}
+		localSpan.Tag(tracing.TagMQMsgID, sendResult.MsgID)
+		localSpan.Tag(aSyncTagMQOffsetMsgID, sendResult.OffsetMsgID)
+	}
+	localSpan.SetPeer(peer)
+	// mark before ending so a panic inside End is not followed by a second End
+	ended = true
+	localSpan.End()
+}
+
func (sa *SendASyncInterceptor) AfterInvoke(invocation operator.Invocation,
result ...interface{}) error {
if invocation.GetContext() == nil {
return nil
diff --git a/plugins/rocketmq/producer/async_producer_test.go
b/plugins/rocketmq/producer/async_producer_test.go
new file mode 100644
index 0000000..25cb680
--- /dev/null
+++ b/plugins/rocketmq/producer/async_producer_test.go
@@ -0,0 +1,168 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package producer
+
+import (
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/apache/rocketmq-client-go/v2/primitive"
+
+ "github.com/apache/skywalking-go/plugins/core"
+ "github.com/apache/skywalking-go/plugins/core/reporter"
+ "github.com/apache/skywalking-go/plugins/core/tracing"
+)
+
+// findCallbackSpan polls the reported spans every 20ms for up to two seconds
+// and returns the first one whose operation name equals name, failing the
+// test when it never shows up.
+func findCallbackSpan(t *testing.T, name string) reporter.ReportedSpan {
+	t.Helper()
+	lookup := func() (reporter.ReportedSpan, bool) {
+		for _, s := range core.GetReportedSpans() {
+			if s.OperationName() == name {
+				return s, true
+			}
+		}
+		var none reporter.ReportedSpan
+		return none, false
+	}
+	deadline := time.Now().Add(2 * time.Second)
+	for {
+		if s, ok := lookup(); ok {
+			return s
+		}
+		if time.Now().After(deadline) {
+			t.Fatalf("callback span %q was never reported", name)
+		}
+		time.Sleep(20 * time.Millisecond)
+	}
+}
+
+// callbackTagValue looks up the first tag named key on s and reports whether
+// it was present at all.
+func callbackTagValue(s reporter.ReportedSpan, key string) (string, bool) {
+	tags := s.Tags()
+	for i := range tags {
+		if tags[i].Key == key {
+			return tags[i].Value, true
+		}
+	}
+	return "", false
+}
+
+// runOnSDKGoroutine runs fn on a fresh goroutine and blocks until it has
+// returned, reproducing how the SDK fires the send callback on its own
+// goroutine after the caller already ended the exit span.
+func runOnSDKGoroutine(fn func()) {
+	finished := make(chan struct{})
+	go func() {
+		defer close(finished)
+		fn()
+	}()
+	<-finished
+}
+
+// TestAsyncCallbackFailedSendIsSafe pins the failed-send branch. sendResult is
+// nil there (the old code dereferenced it and killed the process, since SDK
+// goroutines have no recover), and the error must be recorded on the CALLBACK
+// local span rather than on the exit span that has already ended.
+func TestAsyncCallbackFailedSendIsSafe(t *testing.T) {
+	core.ResetTracingContext()
+	defer core.ResetTracingContext()
+
+	exit, err := tracing.CreateExitSpan("RocketMQ/TopicTest/AsyncProducer", "nameserver:9876",
+		func(k, v string) error { return nil })
+	if err != nil {
+		t.Fatal(err)
+	}
+	snapshot := tracing.CaptureContext()
+	exit.End() // AfterInvoke ends the exit span before the callback fires
+
+	forbiddenLookup := func(string) string {
+		t.Error("broker lookup must not run for a failed send")
+		return ""
+	}
+	runOnSDKGoroutine(func() {
+		traceAsyncSendCallback(snapshot, "TopicTest", "nameserver:9876", nil,
+			errors.New("send to broker failed"), forbiddenLookup)
+	})
+
+	local := findCallbackSpan(t, "RocketMQ/TopicTest/Producer/Callback")
+	if !local.IsError() {
+		t.Fatal("send error must be recorded on the callback local span")
+	}
+	if _, hasStatus := callbackTagValue(local, tracing.TagMQStatus); hasStatus {
+		t.Fatal("failed send must not carry a status tag")
+	}
+}
+
+// TestAsyncCallbackSuccessTagsResult covers the full happy-path tag set
+// written by traceAsyncSendCallback: topic, status, queue, broker, message id
+// and offset message id. A missing tag and a wrong value fail separately.
+func TestAsyncCallbackSuccessTagsResult(t *testing.T) {
+	core.ResetTracingContext()
+	defer core.ResetTracingContext()
+
+	span, err := tracing.CreateExitSpan("RocketMQ/TopicTest/AsyncProducer", "nameserver:9876",
+		func(k, v string) error { return nil })
+	if err != nil {
+		t.Fatal(err)
+	}
+	snapshot := tracing.CaptureContext()
+	span.End()
+
+	result := &primitive.SendResult{
+		Status:      primitive.SendOK,
+		MsgID:       "msg-1",
+		OffsetMsgID: "off-1",
+		MessageQueue: &primitive.MessageQueue{
+			Topic:      "TopicTest",
+			BrokerName: "broker-a",
+			QueueId:    3,
+		},
+	}
+	runOnSDKGoroutine(func() {
+		traceAsyncSendCallback(snapshot, "TopicTest", "nameserver:9876", result, nil,
+			func(brokerName string) string { return brokerName + ":10911" })
+	})
+
+	local := findCallbackSpan(t, "RocketMQ/TopicTest/Producer/Callback")
+	if local.IsError() {
+		t.Fatal("successful send must not be an error span")
+	}
+	wantTags := []struct{ key, want string }{
+		{tracing.TagMQTopic, "TopicTest"},
+		{tracing.TagMQStatus, SendStatusStr(primitive.SendOK)},
+		{tracing.TagMQQueue, "3"},
+		{tracing.TagMQBroker, "broker-a:10911"},
+		{tracing.TagMQMsgID, "msg-1"},
+		{aSyncTagMQOffsetMsgID, "off-1"},
+	}
+	for _, w := range wantTags {
+		got, ok := callbackTagValue(local, w.key)
+		if !ok {
+			t.Fatalf("tag %q missing on the callback span", w.key)
+		}
+		if got != w.want {
+			t.Fatalf("tag %q mismatch: got %q, want %q", w.key, got, w.want)
+		}
+	}
+}
+
+// TestAsyncCallbackAgentPanicIsIsolated proves a panic inside the agent logic
+// (here: the broker lookup) never escapes to the user callback. An escaping
+// panic is recovered on the SDK goroutine and reported as a test failure, so
+// it does not crash the whole test binary.
+func TestAsyncCallbackAgentPanicIsIsolated(t *testing.T) {
+	core.ResetTracingContext()
+	defer core.ResetTracingContext()
+
+	span, err := tracing.CreateExitSpan("RocketMQ/TopicTest/AsyncProducer", "nameserver:9876",
+		func(k, v string) error { return nil })
+	if err != nil {
+		t.Fatal(err)
+	}
+	snapshot := tracing.CaptureContext()
+	span.End()
+
+	result := &primitive.SendResult{
+		Status:       primitive.SendOK,
+		MessageQueue: &primitive.MessageQueue{BrokerName: "broker-a"},
+	}
+	runOnSDKGoroutine(func() {
+		// t.Errorf (unlike t.Fatal) may be called from a non-test goroutine
+		defer func() {
+			if r := recover(); r != nil {
+				t.Errorf("agent panic escaped traceAsyncSendCallback: %v", r)
+			}
+		}()
+		// must return normally even though the agent logic panics inside
+		traceAsyncSendCallback(snapshot, "TopicTest", "nameserver:9876", result, nil,
+			func(string) string { panic("name server gone") })
+	})
+}