diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/ServerFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/ServerFactory.java index 5970f85c1745..38cae631a975 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/ServerFactory.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/ServerFactory.java @@ -25,6 +25,7 @@ import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.List; +import java.util.concurrent.TimeUnit; import java.util.function.Supplier; import org.apache.beam.model.pipeline.v1.Endpoints; import org.apache.beam.sdk.fn.channel.SocketAddressFactory; @@ -42,6 +43,9 @@ /** A {@link Server gRPC server} factory. */ public abstract class ServerFactory { + + private static final int KEEP_ALIVE_TIME_SEC = 20; + /** Create a default {@link InetSocketAddressServerFactory}. */ public static ServerFactory createDefault() { return new InetSocketAddressServerFactory(UrlFactory.createDefault()); @@ -144,7 +148,8 @@ private static Server createServer(List<BindableService> services, InetSocketAdd NettyServerBuilder.forPort(socket.getPort()) // Set the message size to max value here. The actual size is governed by the // buffer size in the layers above. - .maxMessageSize(Integer.MAX_VALUE); + .maxMessageSize(Integer.MAX_VALUE) + .permitKeepAliveTime(KEEP_ALIVE_TIME_SEC, TimeUnit.SECONDS); services .stream() .forEach( @@ -200,7 +205,8 @@ private static Server createServer( .channelType(EpollServerDomainSocketChannel.class) .workerEventLoopGroup(new EpollEventLoopGroup()) .bossEventLoopGroup(new EpollEventLoopGroup()) - .maxMessageSize(Integer.MAX_VALUE); + .maxMessageSize(Integer.MAX_VALUE) + .permitKeepAliveTime(KEEP_ALIVE_TIME_SEC, TimeUnit.SECONDS); for (BindableService service : services) { // Wrap the service to extract headers builder.addService( @@ -249,7 +255,8 @@ private static Server createServer(List<BindableService> services, InetSocketAdd .channelType(EpollServerSocketChannel.class) .workerEventLoopGroup(new EpollEventLoopGroup()) .bossEventLoopGroup(new EpollEventLoopGroup()) - .maxMessageSize(Integer.MAX_VALUE); + .maxMessageSize(Integer.MAX_VALUE) + .permitKeepAliveTime(KEEP_ALIVE_TIME_SEC, TimeUnit.SECONDS); for (BindableService service : services) { // Wrap the service to extract headers builder.addService( diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index a39a996478d6..1272b0e31b52 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -61,6 +61,7 @@ from apache_beam.runners.worker import bundle_processor from apache_beam.runners.worker import data_plane from apache_beam.runners.worker import sdk_worker +from apache_beam.runners.worker.channel_factory import GRPCChannelFactory from apache_beam.transforms import trigger from apache_beam.transforms.window import GlobalWindows from apache_beam.utils import profiler @@ -830,7 +831,8 @@ def __init__(self, external_payload, state): def start_worker(self): stub = beam_fn_api_pb2_grpc.BeamFnExternalWorkerPoolStub( - grpc.insecure_channel(self._external_payload.endpoint.url)) + GRPCChannelFactory.insecure_channel( + self._external_payload.endpoint.url)) response = stub.NotifyRunnerAvailable( beam_fn_api_pb2.NotifyRunnerAvailableRequest( control_endpoint=endpoints_pb2.ApiServiceDescriptor( diff --git a/sdks/python/apache_beam/runners/portability/portable_runner.py b/sdks/python/apache_beam/runners/portability/portable_runner.py index f6108ed030d0..d2bf31b2c61c 100644 --- a/sdks/python/apache_beam/runners/portability/portable_runner.py +++ b/sdks/python/apache_beam/runners/portability/portable_runner.py @@ -44,6 +44,7 @@ from apache_beam.runners.portability.job_server import DockerizedJobServer from apache_beam.runners.worker import sdk_worker from apache_beam.runners.worker import sdk_worker_main +from apache_beam.runners.worker.channel_factory import GRPCChannelFactory __all__ = ['PortableRunner'] @@ -188,7 +189,7 @@ def run_pipeline(self, pipeline, options): for k, v in options.get_all_options().items() if v is not None} - channel = grpc.insecure_channel(job_endpoint) + channel = GRPCChannelFactory.insecure_channel(job_endpoint) grpc.channel_ready_future(channel).result() job_service = beam_job_api_pb2_grpc.JobServiceStub(channel) @@ -212,7 +213,8 @@ def send_prepare_request(max_retries=5): prepare_response = send_prepare_request() if prepare_response.artifact_staging_endpoint.url: stager = portable_stager.PortableStager( - grpc.insecure_channel(prepare_response.artifact_staging_endpoint.url), + GRPCChannelFactory.insecure_channel( + prepare_response.artifact_staging_endpoint.url), prepare_response.staging_session_token) retrieval_token, _ = stager.stage_job_resources( options, diff --git a/sdks/python/apache_beam/runners/portability/portable_runner_test.py b/sdks/python/apache_beam/runners/portability/portable_runner_test.py index a37149d633bc..80cd7bf8be2f 100644 --- a/sdks/python/apache_beam/runners/portability/portable_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/portable_runner_test.py @@ -43,6 +43,7 @@ from apache_beam.runners.portability import portable_runner from apache_beam.runners.portability.local_job_service import LocalJobServicer from apache_beam.runners.portability.portable_runner import PortableRunner +from apache_beam.runners.worker.channel_factory import GRPCChannelFactory class PortableRunnerTest(fn_api_runner_test.FnApiRunnerTest): @@ -93,7 +94,7 @@ def _start_local_runner_subprocess_job_service(cls): cls._subprocess = subprocess.Popen(cls._subprocess_command(port)) address = 'localhost:%d' % port job_service = beam_job_api_pb2_grpc.JobServiceStub( - grpc.insecure_channel(address)) + GRPCChannelFactory.insecure_channel(address)) logging.info('Waiting for server to be ready...') start = time.time() timeout = 30 diff --git a/sdks/python/apache_beam/runners/worker/channel_factory.py b/sdks/python/apache_beam/runners/worker/channel_factory.py new file mode 100644 index 000000000000..d0823fa54842 --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/channel_factory.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Factory to create grpc channel.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import grpc + + +class GRPCChannelFactory(grpc.StreamStreamClientInterceptor): + DEFAULT_OPTIONS = [("grpc.keepalive_time_ms", 20000)] + + def __init__(self): + pass + + @staticmethod + def insecure_channel(target, options=None): + if options is None: + options = [] + return grpc.insecure_channel( + target, options=options + GRPCChannelFactory.DEFAULT_OPTIONS) + + @staticmethod + def secure_channel(target, credentials, options=None): + if options is None: + options = [] + return grpc.secure_channel( + target, credentials, + options=options + GRPCChannelFactory.DEFAULT_OPTIONS) diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index 12bd3764d7c8..fc8f9cca887a 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -37,6 +37,7 @@ from apache_beam.coders import coder_impl from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.portability.api import beam_fn_api_pb2_grpc +from apache_beam.runners.worker.channel_factory import GRPCChannelFactory from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor # This module is experimental. No backwards-compatibility guarantees. @@ -343,9 +344,10 @@ def create_data_channel(self, remote_grpc_port): ("grpc.max_send_message_length", -1)] grpc_channel = None if self._credentials is None: - grpc_channel = grpc.insecure_channel(url, options=channel_options) + grpc_channel = GRPCChannelFactory.insecure_channel( + url, options=channel_options) else: - grpc_channel = grpc.secure_channel( + grpc_channel = GRPCChannelFactory.secure_channel( url, self._credentials, options=channel_options) # Add workerId to the grpc channel grpc_channel = grpc.intercept_channel(grpc_channel, diff --git a/sdks/python/apache_beam/runners/worker/log_handler.py b/sdks/python/apache_beam/runners/worker/log_handler.py index f72c7b00779b..cbd68f5de99d 100644 --- a/sdks/python/apache_beam/runners/worker/log_handler.py +++ b/sdks/python/apache_beam/runners/worker/log_handler.py @@ -30,6 +30,7 @@ from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.portability.api import beam_fn_api_pb2_grpc +from apache_beam.runners.worker.channel_factory import GRPCChannelFactory from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor # This module is experimental. No backwards-compatibility guarantees. @@ -61,7 +62,7 @@ def __init__(self, log_service_descriptor): self._dropped_logs = 0 self._log_entry_queue = queue.Queue(maxsize=self._QUEUE_SIZE) - ch = grpc.insecure_channel(log_service_descriptor.url) + ch = GRPCChannelFactory.insecure_channel(log_service_descriptor.url) # Make sure the channel is ready to avoid [BEAM-4649] grpc.channel_ready_future(ch).result(timeout=60) self._log_channel = grpc.intercept_channel(ch, WorkerIdInterceptor()) diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index 2d0b61d82263..b23cf68b1dff 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -41,6 +41,7 @@ from apache_beam.portability.api import beam_fn_api_pb2_grpc from apache_beam.runners.worker import bundle_processor from apache_beam.runners.worker import data_plane +from apache_beam.runners.worker.channel_factory import GRPCChannelFactory from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor @@ -57,10 +58,12 @@ def __init__( self._worker_id = worker_id if credentials is None: logging.info('Creating insecure control channel for %s.', control_address) - self._control_channel = grpc.insecure_channel(control_address) + self._control_channel = GRPCChannelFactory.insecure_channel( + control_address) else: logging.info('Creating secure control channel for %s.', control_address) - self._control_channel = grpc.secure_channel(control_address, credentials) + self._control_channel = GRPCChannelFactory.secure_channel( + control_address, credentials) grpc.channel_ready_future(self._control_channel).result(timeout=60) logging.info('Control channel established.') @@ -355,7 +358,7 @@ def create_state_handler(self, api_service_descriptor): with self._lock: if url not in self._state_handler_cache: logging.info('Creating insecure state channel for %s', url) - grpc_channel = grpc.insecure_channel( + grpc_channel = GRPCChannelFactory.insecure_channel( url, # Options to have no limits (-1) on the size of the messages # received or sent over the data plane. The actual buffer size is
With regards, Apache Git Services