This is an automated email from the ASF dual-hosted git repository.

tompytel pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/skywalking-python.git


The following commit(s) were added to refs/heads/master by this push:
     new 25a5e7d  Python celery plugin (#125)
25a5e7d is described below

commit 25a5e7d875a3a71125a8ecf4f3a5e86e9705c7d4
Author: Tomasz Pytel <[email protected]>
AuthorDate: Wed Jul 7 08:59:08 2021 -0300

    Python celery plugin (#125)
    
    * WIP celery plugin and required core changes
    
    * tweaks and minor fixes
    
    * added requests package to setup.py
    
    * updated sw_sanic plugin rules
    
    * doc update
    
    Co-authored-by: kezhenxu94 <[email protected]>
---
 docs/EnvVars.md                       |   1 +
 docs/Plugins.md                       |   7 ++-
 requirements.txt                      |   1 +
 setup.py                              |   3 +-
 skywalking/__init__.py                |   1 +
 skywalking/agent/__init__.py          |  65 ++++++++++++++-----
 skywalking/agent/protocol/__init__.py |  11 +++-
 skywalking/agent/protocol/http.py     |  22 +++++--
 skywalking/client/http.py             |  22 ++++---
 skywalking/config.py                  |   3 +-
 skywalking/plugins/sw_celery.py       | 114 ++++++++++++++++++++++++++++++++++
 skywalking/plugins/sw_sanic.py        |   2 +-
 skywalking/trace/context.py           |   2 +-
 skywalking/trace/tags.py              |   1 +
 14 files changed, 222 insertions(+), 33 deletions(-)

diff --git a/docs/EnvVars.md b/docs/EnvVars.md
index ddaa240..4675433 100644
--- a/docs/EnvVars.md
+++ b/docs/EnvVars.md
@@ -26,3 +26,4 @@ Environment Variable | Description | Default
 | `SW_KAFKA_REPORTER_TOPIC_MANAGEMENT` | Specifying Kafka topic name for 
service instance reporting and registering. | `skywalking-managements` |
 | `SW_KAFKA_REPORTER_TOPIC_SEGMENT` | Specifying Kafka topic name for Tracing 
data. | `skywalking-segments` |
 | `SW_KAFKA_REPORTER_CONFIG_key` | The configs to init KafkaProducer. it 
support the basic arguments (whose type is either `str`, `bool`, or `int`) 
listed 
[here](https://kafka-python.readthedocs.io/en/master/apidoc/KafkaProducer.html#kafka.KafkaProducer)
 | unset |
+| `SW_CELERY_PARAMETERS_LENGTH` | The maximum length of the recorded `celery` 
task parameters; longer values are truncated, and `0` disables parameter collection | `512` |
diff --git a/docs/Plugins.md b/docs/Plugins.md
index 28688e0..23f7473 100644
--- a/docs/Plugins.md
+++ b/docs/Plugins.md
@@ -3,7 +3,7 @@
 Library | Versions | Plugin Name
 | :--- | :--- | :--- |
 | [http.server](https://docs.python.org/3/library/http.server.html) | Python 
3.5 ~ 3.9 | `sw_http_server` |
-| [urllib.request](https://docs.python.org/3/library/urllib.request.html) | 
Python 3.5 ~ 3.8 | `sw_urllib_request` |
+| [urllib.request](https://docs.python.org/3/library/urllib.request.html) | 
Python 3.5 ~ 3.9 | `sw_urllib_request` |
 | [requests](https://requests.readthedocs.io/en/master/) | >= 2.9.0 < 2.15.0, 
>= 2.17.0 <= 2.24.0 | `sw_requests` |
 | [Flask](https://flask.palletsprojects.com/en/1.1.x/) | >=1.0.4 <= 1.1.2 | 
`sw_flask` |
 | [PyMySQL](https://pymysql.readthedocs.io/en/latest/) | 0.10.0 | `sw_pymysql` 
|
@@ -18,6 +18,9 @@ Library | Versions | Plugin Name
 | [sanic](https://sanic.readthedocs.io/en/latest/) | >= 20.3.0 <= 20.9.1 | 
`sw_sanic` |
 | [aiohttp](https://sanic.readthedocs.io/en/latest/) | >= 3.7.3 | `sw_aiohttp` 
|
 | [pyramid](https://trypyramid.com) | >= 1.9 | `sw_pyramid` |
-| [psycopg2](https://www.psycopg.org/) | 2.8.6 | `sw_psycopg2` |
+| [psycopg2](https://www.psycopg.org/) | >= 2.8.6 | `sw_psycopg2` |
+| [celery](https://docs.celeryproject.org/) | >= 4.2.1 | `sw_celery` |
+
+* Note: A celery worker started with `celery -A ...` should use the `http` 
protocol, because celery uses multiprocessing by default, which is currently 
incompatible with SkyWalking's gRPC protocol implementation. Celery clients 
can use any protocol.
 
 The column `Versions` only indicates that the versions are tested, if you 
found the newer versions are also supported, welcome to add the newer version 
into the table.
diff --git a/requirements.txt b/requirements.txt
index 91a06d2..71d5e80 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,7 @@ aiofiles==0.6.0
 aiohttp==3.7.3
 attrs==19.3.0
 blindspin==2.0.1
+celery==4.4.7
 certifi==2020.6.20
 chardet==3.0.4
 click==7.1.2
diff --git a/setup.py b/setup.py
index ec24ee2..779baf1 100644
--- a/setup.py
+++ b/setup.py
@@ -33,12 +33,13 @@ setup(
     author="Apache",
     author_email="[email protected]",
     license="Apache 2.0",
-    packages=find_packages(exclude=("tests",)),
+    packages=find_packages(exclude=("tests", "tests.*")),
     include_package_data=True,
     install_requires=[
         "grpcio",
         "grpcio-tools",
         "packaging",
+        "requests",
         "wrapt",
     ],
     extras_require={
diff --git a/skywalking/__init__.py b/skywalking/__init__.py
index 75f9db2..95461ce 100644
--- a/skywalking/__init__.py
+++ b/skywalking/__init__.py
@@ -42,6 +42,7 @@ class Component(Enum):
     AioHttp = 7008
     Pyramid = 7009
     Psycopg = 7010
+    Celery = 7011
 
 
 class Layer(Enum):
diff --git a/skywalking/agent/__init__.py b/skywalking/agent/__init__.py
index 356a28a..64ca4c1 100644
--- a/skywalking/agent/__init__.py
+++ b/skywalking/agent/__init__.py
@@ -16,6 +16,7 @@
 #
 
 import atexit
+import os
 from queue import Queue, Full
 from threading import Thread, Event
 from typing import TYPE_CHECKING
@@ -28,6 +29,11 @@ if TYPE_CHECKING:
     from skywalking.trace.context import Segment
 
 
+__started = False
+__protocol = Protocol()  # type: Protocol
+__heartbeat_thread = __report_thread = __queue = __finished = None
+
+
 def __heartbeat():
     while not __finished.is_set():
         if connected():
@@ -39,21 +45,26 @@ def __heartbeat():
 def __report():
     while not __finished.is_set():
         if connected():
-            __protocol.report(__queue)  # is blocking actually
+            __protocol.report(__queue)  # is blocking actually, blocks for max 
config.QUEUE_TIMEOUT seconds
 
         __finished.wait(1)
 
 
-__heartbeat_thread = Thread(name='HeartbeatThread', target=__heartbeat, 
daemon=True)
-__report_thread = Thread(name='ReportThread', target=__report, daemon=True)
-__queue = Queue(maxsize=10000)
-__finished = Event()
-__protocol = Protocol()  # type: Protocol
-__started = False
+def __init_threading():
+    global __heartbeat_thread, __report_thread, __queue, __finished
+
+    __queue = Queue(maxsize=10000)
+    __finished = Event()
+    __heartbeat_thread = Thread(name='HeartbeatThread', target=__heartbeat, 
daemon=True)
+    __report_thread = Thread(name='ReportThread', target=__report, daemon=True)
+
+    __heartbeat_thread.start()
+    __report_thread.start()
 
 
 def __init():
     global __protocol
+
     if config.protocol == 'grpc':
         from skywalking.agent.protocol.grpc import GrpcProtocol
         __protocol = GrpcProtocol()
@@ -65,14 +76,40 @@ def __init():
         __protocol = KafkaProtocol()
 
     plugins.install()
+    __init_threading()
 
 
 def __fini():
     __protocol.report(__queue, False)
     __queue.join()
+    __finished.set()
+
+
+def __fork_before():
+    if config.protocol != 'http':
+        logger.warning('fork() not currently supported with %s protocol' % 
config.protocol)
+
+    # TODO: handle __queue and __finished correctly (locks, mutexes, etc...), 
need to lock before fork and unlock after
+    # if possible, or ensure they are not locked in threads (end threads and 
restart after fork?)
+
+    __protocol.fork_before()
+
+
+def __fork_after_in_parent():
+    __protocol.fork_after_in_parent()
+
+
+def __fork_after_in_child():
+    __protocol.fork_after_in_child()
+    __init_threading()
 
 
 def start():
+    global __started
+    if __started:
+        return
+    __started = True
+
     flag = False
     try:
         from gevent import monkey
@@ -82,22 +119,22 @@ def start():
     if flag:
         import grpc.experimental.gevent as grpc_gevent
         grpc_gevent.init_gevent()
-    global __started
-    if __started:
-        raise RuntimeError('the agent can only be started once')
+
     loggings.init()
     config.finalize()
-    __started = True
+
     __init()
-    __heartbeat_thread.start()
-    __report_thread.start()
+
     atexit.register(__fini)
 
+    if (hasattr(os, 'register_at_fork')):
+        os.register_at_fork(before=__fork_before, 
after_in_parent=__fork_after_in_parent,
+                            after_in_child=__fork_after_in_child)
+
 
 def stop():
     atexit.unregister(__fini)
     __fini()
-    __finished.set()
 
 
 def started():
diff --git a/skywalking/agent/protocol/__init__.py 
b/skywalking/agent/protocol/__init__.py
index 0f6e62e..3202734 100644
--- a/skywalking/agent/protocol/__init__.py
+++ b/skywalking/agent/protocol/__init__.py
@@ -20,8 +20,17 @@ from queue import Queue
 
 
 class Protocol(ABC):
+    def fork_before(self):
+        pass
+
+    def fork_after_in_parent(self):
+        pass
+
+    def fork_after_in_child(self):
+        pass
+
     def connected(self):
-        raise NotImplementedError()
+        return False
 
     def heartbeat(self):
         raise NotImplementedError()
diff --git a/skywalking/agent/protocol/http.py 
b/skywalking/agent/protocol/http.py
index 89d43bf..809d1f8 100644
--- a/skywalking/agent/protocol/http.py
+++ b/skywalking/agent/protocol/http.py
@@ -17,7 +17,9 @@
 
 from skywalking.loggings import logger
 from queue import Queue, Empty
+from time import time
 
+from skywalking import config
 from skywalking.agent import Protocol
 from skywalking.client.http import HttpServiceManagementClient, 
HttpTraceSegmentReportService
 from skywalking.trace.segment import Segment
@@ -29,20 +31,27 @@ class HttpProtocol(Protocol):
         self.service_management = HttpServiceManagementClient()
         self.traces_reporter = HttpTraceSegmentReportService()
 
+    def fork_after_in_child(self):
+        self.service_management.fork_after_in_child()
+        self.traces_reporter.fork_after_in_child()
+
+    def connected(self):
+        return True
+
     def heartbeat(self):
         if not self.properties_sent:
             self.service_management.send_instance_props()
             self.properties_sent = True
         self.service_management.send_heart_beat()
 
-    def connected(self):
-        return True
-
     def report(self, queue: Queue, block: bool = True):
+        start = time()
+
         def generator():
             while True:
                 try:
-                    segment = queue.get(block=block)  # type: Segment
+                    timeout = max(0, config.QUEUE_TIMEOUT - int(time() - 
start))  # type: int
+                    segment = queue.get(block=block, timeout=timeout)  # type: 
Segment
                 except Empty:
                     return
 
@@ -52,4 +61,7 @@ class HttpProtocol(Protocol):
 
                 queue.task_done()
 
-        self.traces_reporter.report(generator=generator())
+        try:
+            self.traces_reporter.report(generator=generator())
+        except Exception:
+            pass
diff --git a/skywalking/client/http.py b/skywalking/client/http.py
index 87c1c08..a614a00 100644
--- a/skywalking/client/http.py
+++ b/skywalking/client/http.py
@@ -25,10 +25,14 @@ from skywalking.client import ServiceManagementClient, 
TraceSegmentReportService
 
 class HttpServiceManagementClient(ServiceManagementClient):
     def __init__(self):
-        self.session = requests.session()
+        self.session = requests.Session()
+
+    def fork_after_in_child(self):
+        self.session.close()
+        self.session = requests.Session()
 
     def send_instance_props(self):
-        url = config.collector_address.rstrip('/') + 
'/v3/management/reportProperties'
+        url = 'http://' + config.collector_address.rstrip('/') + 
'/v3/management/reportProperties'
         res = self.session.post(url, json={
             'service': config.service_name,
             'serviceInstance': config.service_instance,
@@ -44,7 +48,7 @@ class HttpServiceManagementClient(ServiceManagementClient):
             config.service_name,
             config.service_instance,
         )
-        url = config.collector_address.rstrip('/') + '/v3/management/keepAlive'
+        url = 'http://' + config.collector_address.rstrip('/') + 
'/v3/management/keepAlive'
         res = self.session.post(url, json={
             'service': config.service_name,
             'serviceInstance': config.service_instance,
@@ -54,10 +58,14 @@ class HttpServiceManagementClient(ServiceManagementClient):
 
 class HttpTraceSegmentReportService(TraceSegmentReportService):
     def __init__(self):
-        self.session = requests.session()
+        self.session = requests.Session()
+
+    def fork_after_in_child(self):
+        self.session.close()
+        self.session = requests.Session()
 
     def report(self, generator):
-        url = config.collector_address.rstrip('/') + '/v3/segment'
+        url = 'http://' + config.collector_address.rstrip('/') + '/v3/segment'
         for segment in generator:
             res = self.session.post(url, json={
                 'traceId': str(segment.related_traces[0]),
@@ -76,10 +84,10 @@ class 
HttpTraceSegmentReportService(TraceSegmentReportService):
                     'componentId': span.component.value,
                     'isError': span.error_occurred,
                     'logs': [{
-                        'time': log.timestamp * 1000,
+                        'time': int(log.timestamp * 1000),
                         'data': [{
                             'key': item.key,
-                            'value': item.val
+                            'value': item.val,
                         } for item in log.items],
                     } for log in span.logs],
                     'tags': [{
diff --git a/skywalking/config.py b/skywalking/config.py
index b8a8027..b1ff2e4 100644
--- a/skywalking/config.py
+++ b/skywalking/config.py
@@ -59,13 +59,14 @@ elasticsearch_trace_dsl = True if 
os.getenv('SW_ELASTICSEARCH_TRACE_DSL') and \
 kafka_bootstrap_servers = os.getenv('SW_KAFKA_REPORTER_BOOTSTRAP_SERVERS') or 
"localhost:9092"  # type: str
 kafka_topic_management = os.getenv('SW_KAFKA_REPORTER_TOPIC_MANAGEMENT') or 
"skywalking-managements"  # type: str
 kafka_topic_segment = os.getenv('SW_KAFKA_REPORTER_TOPIC_SEGMENT') or 
"skywalking-segments"  # type: str
+celery_parameters_length = int(os.getenv('SW_CELERY_PARAMETERS_LENGTH') or 
'512')
 
 
 def init(
         service: str = None,
         instance: str = None,
         collector: str = None,
-        protocol_type: str = 'grpc',
+        protocol_type: str = None,
         token: str = None,
 ):
     global service_name
diff --git a/skywalking/plugins/sw_celery.py b/skywalking/plugins/sw_celery.py
new file mode 100644
index 0000000..a7cff41
--- /dev/null
+++ b/skywalking/plugins/sw_celery.py
@@ -0,0 +1,114 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from skywalking import Layer, Component, config
+from skywalking.trace import tags
+from skywalking.trace.carrier import Carrier
+from skywalking.trace.context import get_context
+from skywalking.trace.tags import Tag
+
+
+def install():
+    from urllib.parse import urlparse
+    from celery import Celery
+
+    def send_task(self, name, args=None, kwargs=None, **options):
+        # NOTE: Lines commented out below left for documentation purposes if 
sometime in the future exchange / queue
+        # names are wanted. Currently these do not match between producer and 
consumer so would need some work.
+
+        broker_url = self.conf['broker_url']
+        # exchange = options['exchange']
+        # queue = options['routing_key']
+        # op = 'celery/{}/{}/{}'.format(exchange or '', queue or '', name)
+        op = 'celery/' + name
+
+        if broker_url:
+            url = urlparse(broker_url)
+            peer = '{}:{}'.format(url.hostname, url.port)
+        else:
+            peer = '???'
+
+        with get_context().new_exit_span(op=op, peer=peer) as span:
+            span.layer = Layer.MQ
+            span.component = Component.Celery
+
+            span.tag(Tag(key=tags.MqBroker, val=broker_url))
+            # span.tag(Tag(key=tags.MqTopic, val=exchange))
+            # span.tag(Tag(key=tags.MqQueue, val=queue))
+
+            if config.celery_parameters_length:
+                params = '*{}, **{}'.format(args, 
kwargs)[:config.celery_parameters_length]
+                span.tag(Tag(key=tags.CeleryParameters, val=params))
+
+            options = {**options}
+            headers = options.get('headers')
+            headers = {**headers} if headers else {}
+            options['headers'] = headers
+
+            for item in span.inject():
+                headers[item.key] = item.val
+
+            return _send_task(self, name, args, kwargs, **options)
+
+    _send_task = Celery.send_task
+    Celery.send_task = send_task
+
+    def task_from_fun(self, _fun, name=None, **options):
+        def fun(*args, **kwargs):
+            req = task.request_stack.top
+            # di = req.get('delivery_info')
+            # exchange = di and di.get('exchange')
+            # queue = di and di.get('routing_key')
+            # op = 'celery/{}/{}/{}'.format(exchange or '', queue or '', name)
+            op = 'celery/' + name
+            carrier = Carrier()
+
+            for item in carrier:
+                val = req.get(item.key)
+
+                if val:
+                    item.val = val
+
+            context = get_context()
+
+            if req.get('sw8'):
+                span = context.new_entry_span(op=op, carrier=carrier)
+                span.peer = (req.get('hostname') or '???').split('@', 1)[-1]
+            else:
+                span = context.new_local_span(op=op)
+
+            with span:
+                span.layer = Layer.MQ
+                span.component = Component.Celery
+
+                span.tag(Tag(key=tags.MqBroker, 
val=task.app.conf['broker_url']))
+                # span.tag(Tag(key=tags.MqTopic, val=exchange))
+                # span.tag(Tag(key=tags.MqQueue, val=queue))
+
+                if config.celery_parameters_length:
+                    params = '*{}, **{}'.format(args, 
kwargs)[:config.celery_parameters_length]
+                    span.tag(Tag(key=tags.CeleryParameters, val=params))
+
+                return _fun(*args, **kwargs)
+
+        name = name or self.gen_task_name(_fun.__name__, _fun.__module__)
+        task = _task_from_fun(self, fun, name, **options)
+
+        return task
+
+    _task_from_fun = Celery._task_from_fun
+    Celery._task_from_fun = task_from_fun
diff --git a/skywalking/plugins/sw_sanic.py b/skywalking/plugins/sw_sanic.py
index 1ea6299..2cf68bf 100644
--- a/skywalking/plugins/sw_sanic.py
+++ b/skywalking/plugins/sw_sanic.py
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
 
 version_rule = {
     "name": "sanic",
-    "rules": [">=20.3.0"]
+    "rules": [">=20.3.0", "<21.0.0"]
 }
 
 
diff --git a/skywalking/trace/context.py b/skywalking/trace/context.py
index 51dffa3..f7b7440 100644
--- a/skywalking/trace/context.py
+++ b/skywalking/trace/context.py
@@ -119,7 +119,7 @@ class SpanContext(object):
         spans = _spans_dup()
         parent = spans[-1] if spans else None  # type: Span
 
-        span = parent if parent is not None and parent.kind.is_exit else 
ExitSpan(
+        span = ExitSpan(
             context=self,
             sid=self._sid.next(),
             pid=parent.sid if parent else -1,
diff --git a/skywalking/trace/tags.py b/skywalking/trace/tags.py
index 57a389d..f7c9abd 100644
--- a/skywalking/trace/tags.py
+++ b/skywalking/trace/tags.py
@@ -31,3 +31,4 @@ HttpParams = 'http.params'
 MqBroker = 'mq.broker'
 MqTopic = 'mq.topic'
 MqQueue = 'mq.queue'
+CeleryParameters = 'celery.parameters'

Reply via email to