This is an automated email from the ASF dual-hosted git repository.

kezhenxu94 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/skywalking-python.git


The following commit(s) were added to refs/heads/master by this push:
     new 6b0c2e8  Support propagating correlation context (#55)
6b0c2e8 is described below

commit 6b0c2e84de173f685230b1d73a7fc1c701d89fc1
Author: huawei <[email protected]>
AuthorDate: Sat Aug 1 13:44:50 2020 +0800

    Support propagating correlation context (#55)
---
 README.md                                  |  2 ++
 skywalking/config/__init__.py              |  2 ++
 skywalking/trace/carrier/__init__.py       | 33 +++++++++++++++++++++++++++++-
 skywalking/trace/context/__init__.py       | 25 ++++++++++++++++++++++
 skywalking/trace/span/__init__.py          |  8 +++++++-
 tests/plugin/sw_flask/services/consumer.py |  2 ++
 tests/plugin/sw_flask/services/provider.py |  3 ++-
 tests/plugin/sw_flask/test_flask.py        |  5 ++++-
 8 files changed, 76 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 0dcf7bc..fd8521c 100755
--- a/README.md
+++ b/README.md
@@ -64,6 +64,8 @@ Environment Variable | Description | Default
 | `SW_FLASK_COLLECT_HTTP_PARAMS`| This config item controls that whether the Flask plugin should collect the parameters of the request.| `false` |
 | `SW_DJANGO_COLLECT_HTTP_PARAMS`| This config item controls that whether the Django plugin should collect the parameters of the request.| `false` |
 | `SW_HTTP_PARAMS_LENGTH_THRESHOLD`| When `COLLECT_HTTP_PARAMS` is enabled, how many characters to keep and send to the OAP backend, use negative values to keep and send the complete parameters, NB. this config item is added for the sake of performance.  | `1024` |
+| `SW_CORRELATION_ELEMENT_MAX_NUMBER`| Maximum number of elements in the correlation context.| `3` |
+| `SW_CORRELATION_VALUE_MAX_LENGTH`| Maximum length of each correlation context value.| `128` |
 
 
 ## Supported Libraries
diff --git a/skywalking/config/__init__.py b/skywalking/config/__init__.py
index e6e0920..cc2544e 100644
--- a/skywalking/config/__init__.py
+++ b/skywalking/config/__init__.py
@@ -36,6 +36,8 @@ flask_collect_http_params = True if os.getenv('SW_FLASK_COLLECT_HTTP_PARAMS') an
 http_params_length_threshold = int(os.getenv('SW_HTTP_PARAMS_LENGTH_THRESHOLD') or '1024')  # type: int
 django_collect_http_params = True if os.getenv('SW_DJANGO_COLLECT_HTTP_PARAMS') and \
                                     os.getenv('SW_DJANGO_COLLECT_HTTP_PARAMS') == 'True' else False   # type: bool
+correlation_element_max_number = int(os.getenv('SW_CORRELATION_ELEMENT_MAX_NUMBER') or '3')  # type: int
+correlation_value_max_length = int(os.getenv('SW_CORRELATION_VALUE_MAX_LENGTH') or '128')  # type: int
 
 
 def init(
diff --git a/skywalking/trace/carrier/__init__.py b/skywalking/trace/carrier/__init__.py
index 468997b..d485ca4 100644
--- a/skywalking/trace/carrier/__init__.py
+++ b/skywalking/trace/carrier/__init__.py
@@ -17,6 +17,7 @@
 
 from typing import List
 
+from skywalking import config
 from skywalking.utils.lang import b64encode, b64decode
 
 
@@ -52,7 +53,8 @@ class Carrier(CarrierItem):
         self.service_instance = ''  # type: str
         self.endpoint = ''  # type: str
         self.client_address = ''  # type: str
-        self.items = [self]  # type: List[CarrierItem]
+        self.correlation_carrier = SW8CorrelationCarrier()
+        self.items = [self.correlation_carrier, self]  # type: List[CarrierItem]
         self.__iter_index = 0  # type: int
 
     @property
@@ -105,3 +107,47 @@ class Carrier(CarrierItem):
         n = self.items[self.__iter_index]
         self.__iter_index += 1
         return n
+
+
+class SW8CorrelationCarrier(CarrierItem):
+    """Carrier item for the ``sw8-correlation`` header.
+
+    The header value is a comma-separated list of ``base64(key):base64(value)``
+    entries; ``correlation`` holds the decoded key/value pairs.
+    """
+
+    def __init__(self):
+        super(SW8CorrelationCarrier, self).__init__(key='sw8-correlation')
+        self.correlation = {}  # type: dict
+
+    @property
+    def val(self) -> str:
+        """Encode ``correlation`` as the header value ("" when empty)."""
+        if not self.correlation:
+            return ""
+
+        return ','.join([
+            b64encode(k) + ":" + b64encode(v)
+            for k, v in self.correlation.items()
+        ])
+
+    @val.setter
+    def val(self, val: str):
+        """Decode a header value into ``correlation``.
+
+        Entries that do not split into exactly two ':'-separated parts are
+        skipped; decoding stops once ``config.correlation_element_max_number``
+        entries are held.
+        """
+        if not val:
+            return
+        for per in val.split(","):
+            # Stop once the configured element limit has been reached.
+            if len(self.correlation) >= config.correlation_element_max_number:
+                break
+            parts = per.split(":")
+            if len(parts) != 2:
+                continue
+            # NOTE(review): b64decode may raise on malformed input from an
+            # untrusted upstream header; confirm whether callers tolerate that.
+            self.correlation[b64decode(parts[0])] = b64decode(parts[1])
diff --git a/skywalking/trace/context/__init__.py b/skywalking/trace/context/__init__.py
index a6acdcf..99fe8af 100644
--- a/skywalking/trace/context/__init__.py
+++ b/skywalking/trace/context/__init__.py
@@ -33,6 +33,7 @@ class SpanContext(object):
         self.spans = []  # type: List[Span]
         self.segment = Segment()  # type: Segment
         self._sid = Counter()
+        self._correlation = {}  # type: dict
 
     def new_local_span(self, op: str) -> Span:
         span = self.ignore_check(op, Kind.Local)
@@ -118,12 +119,38 @@ class SpanContext(object):
 
         return None
 
+    def get_correlation(self, key):
+        """Return the correlation value stored under ``key``, or None."""
+        return self._correlation.get(key)
+
+    def put_correlation(self, key, value):
+        """Set, update or (when ``value`` is None) remove a correlation entry.
+
+        The call is silently ignored when ``key`` is None, when ``value`` is
+        longer than ``config.correlation_value_max_length``, or when adding a
+        new key would exceed ``config.correlation_element_max_number`` entries.
+        """
+        if key is None:
+            return
+        if value is None:
+            self._correlation.pop(key, None)
+            return
+        if len(value) > config.correlation_value_max_length:
+            return
+        # Updating an existing key does not grow the context, so only new keys
+        # are subject to the element-count limit.
+        is_new_key = key not in self._correlation
+        if is_new_key and len(self._correlation) >= config.correlation_element_max_number:
+            return
+
+        self._correlation[key] = value
+
 
 class NoopContext(SpanContext):
     def __init__(self):
         super().__init__()
         self._depth = 0
         self._noop_span = NoopSpan(self, kind=Kind.Local)
 
     def new_local_span(self, op: str) -> Span:
         self._depth += 1
@@ -131,10 +158,14 @@ class NoopContext(SpanContext):
 
     def new_entry_span(self, op: str, carrier: 'Carrier' = None) -> Span:
         self._depth += 1
+        if carrier is not None:
+            self._noop_span.extract(carrier)
         return self._noop_span
 
     def new_exit_span(self, op: str, peer: str, carrier: 'Carrier' = None) -> Span:
         self._depth += 1
+        if carrier is not None:
+            self._noop_span.inject(carrier)
         return self._noop_span
 
     def stop(self, span: Span) -> bool:
diff --git a/skywalking/trace/span/__init__.py b/skywalking/trace/span/__init__.py
index 87ba9ec..587cec3 100644
--- a/skywalking/trace/span/__init__.py
+++ b/skywalking/trace/span/__init__.py
@@ -109,7 +109,8 @@ class Span(ABC):
             return self
 
         self.context.segment.relate(ID(carrier.trace_id))
-
+        # Copy, so the context never shares a mutable dict with the carrier.
+        self.context._correlation = dict(carrier.correlation_carrier.correlation)
         return self
 
     def __enter__(self):
@@ -213,6 +214,8 @@ class ExitSpan(StackedSpan):
         carrier.service_instance = config.service_instance
         carrier.endpoint = self.op
         carrier.client_address = self.peer
+        # Snapshot, so later put_correlation() calls do not mutate this carrier.
+        carrier.correlation_carrier.correlation = dict(self.context._correlation)
         return self
 
     def start(self):
@@ -225,5 +228,12 @@ class NoopSpan(Span):
     def __init__(self, context: 'SpanContext' = None, kind: 'Kind' = None):
         Span.__init__(self, context=context, kind=kind)
 
+    def extract(self, carrier: 'Carrier') -> 'Span':
+        # Keep propagating correlation data even though this span is a no-op.
+        if carrier is not None:
+            self.context._correlation = dict(carrier.correlation_carrier.correlation)
+        return self
+
     def inject(self, carrier: 'Carrier') -> 'Span':
+        carrier.correlation_carrier.correlation = dict(self.context._correlation)
         return self
diff --git a/tests/plugin/sw_flask/services/consumer.py b/tests/plugin/sw_flask/services/consumer.py
index 45e7e80..129222a 100644
--- a/tests/plugin/sw_flask/services/consumer.py
+++ b/tests/plugin/sw_flask/services/consumer.py
@@ -31,6 +31,8 @@ if __name__ == '__main__':
 
     @app.route("/users", methods=["POST", "GET"])
     def application():
+        from skywalking.trace.context import get_context
+        get_context().put_correlation("correlation", "correlation")
         res = requests.post("http://provider:9091/users")
         return jsonify(res.json())
 
diff --git a/tests/plugin/sw_flask/services/provider.py b/tests/plugin/sw_flask/services/provider.py
index 11f2c0b..4bf4722 100644
--- a/tests/plugin/sw_flask/services/provider.py
+++ b/tests/plugin/sw_flask/services/provider.py
@@ -30,8 +30,9 @@ if __name__ == '__main__':
 
     @app.route("/users", methods=["POST", "GET"])
     def application():
+        from skywalking.trace.context import get_context
         time.sleep(0.5)
-        return jsonify({"song": "Despacito", "artist": "Luis Fonsi"})
+        return jsonify({"correlation": get_context().get_correlation("correlation")})
 
     PORT = 9091
     app.run(host='0.0.0.0', port=PORT, debug=True)
diff --git a/tests/plugin/sw_flask/test_flask.py b/tests/plugin/sw_flask/test_flask.py
index 79827f5..1ff07d5 100644
--- a/tests/plugin/sw_flask/test_flask.py
+++ b/tests/plugin/sw_flask/test_flask.py
@@ -19,6 +19,7 @@ import time
 import unittest
 from os.path import dirname
 
+import requests
 from testcontainers.compose import DockerCompose
 
 from tests.plugin import BasePluginTest
@@ -29,13 +30,15 @@ class TestPlugin(BasePluginTest):
     def setUpClass(cls):
         cls.compose = DockerCompose(filepath=dirname(inspect.getfile(cls)))
         cls.compose.start()
-
         cls.compose.wait_for(cls.url(('consumer', '9090'), 'users?test=test1&test=test2&test2=test2'))
 
     def test_plugin(self):
         time.sleep(3)
 
         self.validate()
+        response = requests.get(TestPlugin.url(('consumer', '9090'), 'users'))
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(response.json()["correlation"], "correlation")
 
 
 if __name__ == '__main__':

Reply via email to