This is an automated email from the ASF dual-hosted git repository.

wusheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/skywalking-python.git


The following commit(s) were added to refs/heads/master by this push:
     new 19715ff  Feature: support authentication in grpc (#3)
19715ff is described below

commit 19715ff183f4a563b1391437960ace0a0e00a82c
Author: kezhenxu94 <[email protected]>
AuthorDate: Wed May 6 20:10:36 2020 +0800

    Feature: support authentication in grpc (#3)
---
 skywalking/agent/protocol/grpc/__init__.py     |  6 +++
 skywalking/agent/protocol/grpc/interceptors.py | 70 ++++++++++++++++++++++++++
 skywalking/config/__init__.py                  |  5 ++
 3 files changed, 81 insertions(+)

diff --git a/skywalking/agent/protocol/grpc/__init__.py b/skywalking/agent/protocol/grpc/__init__.py
index bed8197..d495cdb 100644
--- a/skywalking/agent/protocol/grpc/__init__.py
+++ b/skywalking/agent/protocol/grpc/__init__.py
@@ -24,6 +24,8 @@ from common.Common_pb2 import KeyStringValuePair
 from language_agent.Tracing_pb2 import SegmentObject, SpanObject, Log
 from skywalking import config
 from skywalking.agent import Protocol
+from skywalking.agent.protocol.grpc import interceptors
+from skywalking.agent.protocol.grpc.interceptors import 
header_adder_interceptor
 from skywalking.client.grpc import GrpcServiceManagementClient, 
GrpcTraceSegmentReportService
 from skywalking.trace.segment import Segment
 
@@ -34,6 +36,10 @@ class GrpcProtocol(Protocol):
     def __init__(self):
         self.state = None
         self.channel = grpc.insecure_channel(config.collector_address)
+        if config.authentication:
+            self.channel = grpc.intercept_channel(
+                self.channel, header_adder_interceptor('authentication', 
config.authentication)
+            )
 
         def cb(state):
             logger.debug('grpc channel connectivity changed, [%s -> %s]', 
self.state, state)
diff --git a/skywalking/agent/protocol/grpc/interceptors.py b/skywalking/agent/protocol/grpc/interceptors.py
new file mode 100644
index 0000000..814e663
--- /dev/null
+++ b/skywalking/agent/protocol/grpc/interceptors.py
@@ -0,0 +1,70 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import namedtuple
+
+import grpc
+
+
class _ClientInterceptor(
    grpc.UnaryUnaryClientInterceptor,
    grpc.UnaryStreamClientInterceptor,
    grpc.StreamUnaryClientInterceptor,
    grpc.StreamStreamClientInterceptor
):
    """Adapt one generic interception callable to all four gRPC call shapes.

    The wrapped callable is invoked as
    ``fn(client_call_details, request_iterator, request_streaming, response_streaming)``
    and must return a 3-tuple ``(new_details, new_request_iterator, postprocess)``.
    When ``postprocess`` is truthy it is applied to whatever the continuation
    returns; otherwise the continuation's result is returned unchanged.
    """

    def __init__(self, interceptor_function):
        self._fn = interceptor_function

    def _intercept(self, continuation, client_call_details, requests,
                   request_streaming, response_streaming):
        # Shared implementation for the four public hooks below.
        details, request_iterator, postprocess = self._fn(
            client_call_details, requests, request_streaming, response_streaming,
        )
        # Unary requests arrive wrapped in a one-element iterator (so the
        # interception callable always sees an iterator); unwrap them again
        # before handing them to the continuation.
        if request_streaming:
            payload = request_iterator
        else:
            payload = next(request_iterator)
        result = continuation(details, payload)
        if not postprocess:
            return result
        return postprocess(result)

    def intercept_unary_unary(self, continuation, client_call_details, request):
        return self._intercept(continuation, client_call_details, iter((request,)), False, False)

    def intercept_unary_stream(self, continuation, client_call_details, request):
        return self._intercept(continuation, client_call_details, iter((request,)), False, True)

    def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        return self._intercept(continuation, client_call_details, request_iterator, True, False)

    def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        return self._intercept(continuation, client_call_details, request_iterator, True, True)
+
+
def create(intercept_call):
    """Wrap a generic interception callable into a gRPC client interceptor.

    :param intercept_call: callable taking ``(client_call_details,
        request_iterator, request_streaming, response_streaming)`` and
        returning ``(new_details, new_request_iterator, postprocess)``.
    :return: a ``_ClientInterceptor`` instance that can be passed to
        ``grpc.intercept_channel``.
    """
    interceptor = _ClientInterceptor(interceptor_function=intercept_call)
    return interceptor
+
+
class ClientCallDetails(
    namedtuple('ClientCallDetails', ('method', 'timeout', 'metadata', 'credentials')),
    grpc.ClientCallDetails,
):
    """Immutable ``grpc.ClientCallDetails`` used by interceptors to rewrite a call.

    Interceptors hand an instance of this class to the continuation in place of
    the details they received. Also subclassing ``grpc.ClientCallDetails`` (as
    the official gRPC interceptor examples do) keeps the object an instance of
    the documented interface, so ``isinstance`` checks in other interceptors
    still hold. The tuple fields, positional constructor and repr are unchanged
    from the plain namedtuple.

    NOTE(review): newer grpcio releases also define ``wait_for_ready`` and
    ``compression`` on call details; only the four fields above are carried
    here. Confirm they do not need to be propagated.
    """
+
+
def header_adder_interceptor(header, value):
    """Build a client interceptor that attaches one metadata entry to every call.

    The returned interceptor handles unary and streaming calls alike. It copies
    the call's existing metadata, appends ``(header, value)``, forwards the
    request(s) unchanged and does not post-process the response.

    :param header: metadata key. gRPC rejects metadata keys that contain
        upper-case characters (HTTP/2 header names are lower-case), so the key
        is lower-cased here instead of failing every call at send time.
    :param value: metadata value sent with every call.
    :return: an interceptor suitable for ``grpc.intercept_channel``.
    """
    # Normalise once, not on every intercepted call.
    key = header.lower()

    def intercept_call(client_call_details, request_iterator, request_streaming, response_streaming):
        # Build a fresh list: incoming metadata may be None or an immutable
        # sequence, and copying leaves the caller's object untouched.
        metadata = list(client_call_details.metadata or ())
        metadata.append((key, value))
        new_details = ClientCallDetails(
            client_call_details.method,
            client_call_details.timeout,
            metadata,
            client_call_details.credentials,
        )
        # No postprocess step: the response is returned as-is.
        return new_details, request_iterator, None

    return create(intercept_call)
diff --git a/skywalking/config/__init__.py b/skywalking/config/__init__.py
index 9910da7..6fd2823 100644
--- a/skywalking/config/__init__.py
+++ b/skywalking/config/__init__.py
@@ -22,6 +22,7 @@ service_name = os.getenv('SW_AGENT_NAME') or 'Python Service 
Name'  # type: str
 service_instance = os.getenv('SW_AGENT_INSTANCE') or 
str(uuid.uuid1()).replace('-', '')  # type: str
 collector_address = os.getenv('SW_AGENT_COLLECTOR_BACKEND_SERVICES') or 
'127.0.0.1:11800'  # type: str
 protocol = os.getenv('SW_AGENT_PROTOCOL') or 'grpc'  # type: str
+authentication = os.getenv('SW_AGENT_AUTHENTICATION')
 
 
 def init(
@@ -29,6 +30,7 @@ def init(
         instance: str = None,
         collector: str = None,
         protocol_type: str = 'grpc',
+        token: str = None,
 ):
     global service_name
     service_name = service or service_name
@@ -41,3 +43,6 @@ def init(
 
     global protocol
     protocol = protocol_type or protocol
+
+    global authentication
+    authentication = token or authentication

Reply via email to