This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new a4fb844  Secure GRPC channel for SDK worker (#4984)
a4fb844 is described below

commit a4fb844df82051ef93bb7e2d47967e143eabcc5c
Author: ananvay <ananvay2...@yahoo.com>
AuthorDate: Wed Apr 4 17:50:04 2018 -0700

    Secure GRPC channel for SDK worker (#4984)
---
 .../apache_beam/runners/worker/data_plane.py       | 25 +++++++++++++++-------
 .../apache_beam/runners/worker/sdk_worker.py       | 15 ++++++++++---
 2 files changed, 29 insertions(+), 11 deletions(-)

diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py
index 7c79c4c..1ff60aa 100644
--- a/sdks/python/apache_beam/runners/worker/data_plane.py
+++ b/sdks/python/apache_beam/runners/worker/data_plane.py
@@ -295,9 +295,13 @@ class GrpcClientDataChannelFactory(DataChannelFactory):
   Caches the created channels by ``data descriptor url``.
   """
 
-  def __init__(self):
+  def __init__(self, credentials=None):
     self._data_channel_cache = {}
     self._lock = threading.Lock()
+    self._credentials = None
+    if credentials is not None:
+      logging.info('Using secure channel creds.')
+      self._credentials = credentials
 
   def create_data_channel(self, remote_grpc_port):
     url = remote_grpc_port.api_service_descriptor.url
@@ -305,18 +309,23 @@ class GrpcClientDataChannelFactory(DataChannelFactory):
       with self._lock:
         if url not in self._data_channel_cache:
           logging.info('Creating channel for %s', url)
-          grpc_channel = grpc.insecure_channel(
-              url,
-              # Options to have no limits (-1) on the size of the messages
-              # received or sent over the data plane. The actual buffer size is
-              # controlled in a layer above.
-              options=[("grpc.max_receive_message_length", -1),
-                       ("grpc.max_send_message_length", -1)])
+          # Options to have no limits (-1) on the size of the messages
+          # received or sent over the data plane. The actual buffer size
+          # is controlled in a layer above.
+          channel_options = [("grpc.max_receive_message_length", -1),
+                             ("grpc.max_send_message_length", -1)]
+          grpc_channel = None
+          if self._credentials is None:
+            grpc_channel = grpc.insecure_channel(url, options=channel_options)
+          else:
+            grpc_channel = grpc.secure_channel(
+                url, self._credentials, options=channel_options)
           # Add workerId to the grpc channel
           grpc_channel = grpc.intercept_channel(grpc_channel,
                                                 WorkerIdInterceptor())
           self._data_channel_cache[url] = GrpcClientDataChannel(
               beam_fn_api_pb2_grpc.BeamFnDataStub(grpc_channel))
+
     return self._data_channel_cache[url]
 
   def close(self):
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index c77659b..3b6ed65 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -40,12 +40,21 @@ from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
 class SdkHarness(object):
   REQUEST_METHOD_PREFIX = '_request_'
 
-  def __init__(self, control_address, worker_count):
+  def __init__(self, control_address, worker_count, credentials=None):
     self._worker_count = worker_count
     self._worker_index = 0
+    if credentials is None:
+      logging.info('Creating insecure channel.')
+      self._control_channel = grpc.insecure_channel(control_address)
+    else:
+      logging.info('Creating secure channel.')
+      self._control_channel = grpc.secure_channel(control_address, credentials)
+      grpc.channel_ready_future(self._control_channel).result()
+      logging.info('Secure channel established.')
     self._control_channel = grpc.intercept_channel(
-        grpc.insecure_channel(control_address), WorkerIdInterceptor())
-    self._data_channel_factory = data_plane.GrpcClientDataChannelFactory()
+        self._control_channel, WorkerIdInterceptor())
+    self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
+        credentials)
     self.workers = queue.Queue()
     # one thread is enough for getting the progress report.
     # Assumption:

-- 
To stop receiving notification emails like this one, please contact
rober...@apache.org.

Reply via email to