This is an automated email from the ASF dual-hosted git repository.

wusheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/skywalking-python.git


The following commit(s) were added to refs/heads/master by this push:
     new 9a419fd  [Core][Defect] Validate carrier before using it (#29)
9a419fd is described below

commit 9a419fd88393a65118b493c3c175795f265bcf61
Author: kezhenxu94 <kezhenx...@apache.org>
AuthorDate: Thu Jul 2 07:38:27 2020 +0800

    [Core][Defect] Validate carrier before using it (#29)
---
 skywalking/trace/carrier/__init__.py | 19 ++++++++++++++++---
 skywalking/trace/context/__init__.py |  2 +-
 skywalking/trace/segment/__init__.py |  2 +-
 skywalking/trace/span/__init__.py    |  4 ++--
 4 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/skywalking/trace/carrier/__init__.py b/skywalking/trace/carrier/__init__.py
index 0b95861..468997b 100644
--- a/skywalking/trace/carrier/__init__.py
+++ b/skywalking/trace/carrier/__init__.py
@@ -47,7 +47,7 @@ class Carrier(CarrierItem):
         super(Carrier, self).__init__(key='sw8')
         self.trace_id = ''  # type: str
         self.segment_id = ''  # type: str
-        self.span_id = -1  # type: int
+        self.span_id = ''  # type: str
         self.service = ''  # type: str
         self.service_instance = ''  # type: str
         self.endpoint = ''  # type: str
@@ -61,7 +61,7 @@ class Carrier(CarrierItem):
             '1',
             b64encode(self.trace_id),
             b64encode(self.segment_id),
-            str(self.span_id),
+            self.span_id,
             b64encode(self.service),
             b64encode(self.service_instance),
             b64encode(self.endpoint),
@@ -74,14 +74,27 @@ class Carrier(CarrierItem):
         if not val:
             return
         parts = val.split('-')
+        if len(parts) != 8:
+            return
         self.trace_id = b64decode(parts[1])
         self.segment_id = b64decode(parts[2])
-        self.span_id = int(parts[3])
+        self.span_id = parts[3]
         self.service = b64decode(parts[4])
         self.service_instance = b64decode(parts[5])
         self.endpoint = b64decode(parts[6])
         self.client_address = b64decode(parts[7])
 
+    @property
+    def is_valid(self):
+        # type: () -> bool
+        return len(self.trace_id) > 0 and \
+               len(self.segment_id) > 0 and \
+               len(self.service) > 0 and \
+               len(self.service_instance) > 0 and \
+               len(self.endpoint) > 0 and \
+               len(self.client_address) > 0 and \
+               self.span_id.isnumeric()
+
     def __iter__(self):
         self.__iter_index = 0
         return self
diff --git a/skywalking/trace/context/__init__.py b/skywalking/trace/context/__init__.py
index dce8d68..190121f 100644
--- a/skywalking/trace/context/__init__.py
+++ b/skywalking/trace/context/__init__.py
@@ -55,7 +55,7 @@ class SpanContext(object):
         )
         span.op = op
 
-        if carrier is not None:
+        if carrier is not None and carrier.is_valid:
             span.extract(carrier=carrier)
 
         return span
diff --git a/skywalking/trace/segment/__init__.py b/skywalking/trace/segment/__init__.py
index 44f5b87..bf1a71e 100644
--- a/skywalking/trace/segment/__init__.py
+++ b/skywalking/trace/segment/__init__.py
@@ -31,7 +31,7 @@ class SegmentRef(object):
         self.ref_type = 'CrossProcess'  # type: str
         self.trace_id = carrier.trace_id  # type: str
         self.segment_id = carrier.segment_id  # type: str
-        self.span_id = carrier.span_id  # type: int
+        self.span_id = int(carrier.span_id)  # type: int
         self.service = carrier.service  # type: str
         self.service_instance = carrier.service_instance  # type: str
         self.endpoint = carrier.endpoint  # type: str
diff --git a/skywalking/trace/span/__init__.py b/skywalking/trace/span/__init__.py
index e448424..ec77a9b 100644
--- a/skywalking/trace/span/__init__.py
+++ b/skywalking/trace/span/__init__.py
@@ -165,7 +165,7 @@ class EntrySpan(StackedSpan):
     def extract(self, carrier: 'Carrier') -> 'Span':
         Span.extract(self, carrier)
 
-        if carrier is None:
+        if carrier is None or not carrier.is_valid:
             return self
 
         ref = SegmentRef(carrier=carrier)
@@ -203,7 +203,7 @@ class ExitSpan(StackedSpan):
     def inject(self, carrier: 'Carrier') -> 'Span':
         carrier.trace_id = str(self.context.segment.related_traces[0])
         carrier.segment_id = str(self.context.segment.segment_id)
-        carrier.span_id = self.sid
+        carrier.span_id = str(self.sid)
         carrier.service = config.service_name
         carrier.service_instance = config.service_instance
         carrier.endpoint = self.op

Reply via email to