[
https://issues.apache.org/jira/browse/BEAM-6094?focusedWorklogId=174496&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-174496
]
ASF GitHub Bot logged work on BEAM-6094:
----------------------------------------
Author: ASF GitHub Bot
Created on: 12/Dec/18 12:02
Start Date: 12/Dec/18 12:02
Worklog Time Spent: 10m
Work Description: robertwb closed pull request #7078: [BEAM-6094] Implement external environment for portable BeamPython.
URL: https://github.com/apache/beam/pull/7078
This is a PR merged from a forked repository. As GitHub hides the original
diff on merge, it is displayed below for the sake of provenance:
diff --git a/model/fn-execution/src/main/proto/beam_fn_api.proto b/model/fn-execution/src/main/proto/beam_fn_api.proto
index e7586ea07a3c..a602f0088c05 100644
--- a/model/fn-execution/src/main/proto/beam_fn_api.proto
+++ b/model/fn-execution/src/main/proto/beam_fn_api.proto
@@ -976,3 +976,23 @@ service BeamFnLogging {
stream LogControl
) {}
}
+
+
+message NotifyRunnerAvailableRequest {
+ string worker_id = 1;
+ org.apache.beam.model.pipeline.v1.ApiServiceDescriptor control_endpoint = 2;
+ org.apache.beam.model.pipeline.v1.ApiServiceDescriptor logging_endpoint = 3;
+ org.apache.beam.model.pipeline.v1.ApiServiceDescriptor artifact_endpoint = 4;
+ org.apache.beam.model.pipeline.v1.ApiServiceDescriptor provision_endpoint = 5;
+ map<string, string> params = 10;
+}
+
+message NotifyRunnerAvailableResponse {
+ string error = 1;
+}
+
+service BeamFnExternalWorkerPool {
+ rpc NotifyRunnerAvailable(NotifyRunnerAvailableRequest)
+ returns (NotifyRunnerAvailableResponse) {
+ }
+}
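
For context, the intended flow of this new service: the runner brings up its
control, logging, artifact, and provision servers, then calls
NotifyRunnerAvailable on the address carried in the environment's payload;
the external worker pool responds by starting an SDK harness that dials back
to the supplied control endpoint. Below is a minimal Python worker pool,
condensed from the BeamFnExternalWorkerPoolServicer added to
portable_runner.py later in this diff (the port and threading details are
illustrative):

from concurrent import futures
import threading

import grpc

from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.runners.worker import sdk_worker


class MinimalWorkerPool(beam_fn_api_pb2_grpc.BeamFnExternalWorkerPoolServicer):
  """Starts one in-process SDK harness per NotifyRunnerAvailable call."""

  def NotifyRunnerAvailable(self, request, context):
    try:
      worker = sdk_worker.SdkHarness(
          request.control_endpoint.url,
          worker_count=1,
          worker_id=request.worker_id)
      thread = threading.Thread(target=worker.run)
      thread.daemon = True
      thread.start()
      # An empty error field signals success.
      return beam_fn_api_pb2.NotifyRunnerAvailableResponse()
    except Exception as exn:
      return beam_fn_api_pb2.NotifyRunnerAvailableResponse(error=str(exn))


server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
beam_fn_api_pb2_grpc.add_BeamFnExternalWorkerPoolServicer_to_server(
    MinimalWorkerPool(), server)
port = server.add_insecure_port('[::]:0')  # Advertise this in ExternalPayload.
server.start()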
diff --git a/model/pipeline/src/main/proto/beam_runner_api.proto b/model/pipeline/src/main/proto/beam_runner_api.proto
index a5131c7e8ec5..42a197025f74 100644
--- a/model/pipeline/src/main/proto/beam_runner_api.proto
+++ b/model/pipeline/src/main/proto/beam_runner_api.proto
@@ -29,6 +29,7 @@ option go_package = "pipeline_v1";
option java_package = "org.apache.beam.model.pipeline.v1";
option java_outer_classname = "RunnerApi";
+import "endpoints.proto";
import "google/protobuf/any.proto";
import "google/protobuf/descriptor.proto";
@@ -982,6 +983,11 @@ message ProcessPayload {
map<string, string> env = 4; // Environment variables
}
+message ExternalPayload {
+ ApiServiceDescriptor endpoint = 1;
+ map<string, string> params = 2; // Arbitrary extra parameters to pass
+}
+
// A specification of a user defined function.
//
message SdkFunctionSpec {
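
On the submission side, an external environment is just an Environment proto
whose payload is a serialized ExternalPayload naming the worker pool's
address. A sketch mirroring the EXTERNAL branch added to portable_runner.py
below (the URL and params values are placeholders):

from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.portability.api import endpoints_pb2

external_env = beam_runner_api_pb2.Environment(
    urn=common_urns.environments.EXTERNAL.urn,
    payload=beam_runner_api_pb2.ExternalPayload(
        # Address of a running BeamFnExternalWorkerPool service.
        endpoint=endpoints_pb2.ApiServiceDescriptor(url='localhost:50000'),
        params={'example_param': 'value'},  # Arbitrary extra parameters.
    ).SerializeToString())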
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDefaultExecutableStageContext.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDefaultExecutableStageContext.java
index b7bbd84f5c8c..adde82208727 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDefaultExecutableStageContext.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDefaultExecutableStageContext.java
@@ -34,6 +34,7 @@
import org.apache.beam.runners.fnexecution.control.StageBundleFactory;
import org.apache.beam.runners.fnexecution.environment.DockerEnvironmentFactory;
import org.apache.beam.runners.fnexecution.environment.EmbeddedEnvironmentFactory;
+import org.apache.beam.runners.fnexecution.environment.ExternalEnvironmentFactory;
import org.apache.beam.runners.fnexecution.environment.ProcessEnvironmentFactory;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.sdk.options.PortablePipelineOptions;
@@ -52,6 +53,8 @@ private static FlinkDefaultExecutableStageContext create(JobInfo jobInfo) {
PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions())),
BeamUrns.getUrn(StandardEnvironments.Environments.PROCESS),
new ProcessEnvironmentFactory.Provider(),
+ BeamUrns.getUrn(StandardEnvironments.Environments.EXTERNAL),
+ new ExternalEnvironmentFactory.Provider(),
Environments.ENVIRONMENT_EMBEDDED, // Non Public urn for testing.
new EmbeddedEnvironmentFactory.Provider()));
return new FlinkDefaultExecutableStageContext(jobBundleFactory);
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/ExternalEnvironmentFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/ExternalEnvironmentFactory.java
new file mode 100644
index 000000000000..1e23f9f56907
--- /dev/null
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/ExternalEnvironmentFactory.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.environment;
+
+import com.google.common.base.Preconditions;
+import java.time.Duration;
+import java.util.concurrent.TimeoutException;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.model.fnexecution.v1.BeamFnExternalWorkerPoolGrpc;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.model.pipeline.v1.RunnerApi.Environment;
+import org.apache.beam.runners.core.construction.BeamUrns;
+import org.apache.beam.runners.fnexecution.GrpcFnServer;
+import org.apache.beam.runners.fnexecution.artifact.ArtifactRetrievalService;
+import org.apache.beam.runners.fnexecution.control.ControlClientPool;
+import org.apache.beam.runners.fnexecution.control.FnApiControlClientPoolService;
+import org.apache.beam.runners.fnexecution.control.InstructionRequestHandler;
+import org.apache.beam.runners.fnexecution.logging.GrpcLoggingService;
+import org.apache.beam.runners.fnexecution.provisioning.StaticGrpcProvisionService;
+import org.apache.beam.sdk.fn.IdGenerator;
+import org.apache.beam.sdk.fn.channel.ManagedChannelFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** An {@link EnvironmentFactory} which requests workers via the given URL in the Environment. */
+public class ExternalEnvironmentFactory implements EnvironmentFactory {
+
+ private static final Logger LOG = LoggerFactory.getLogger(ExternalEnvironmentFactory.class);
+
+ public static ExternalEnvironmentFactory create(
+ GrpcFnServer<FnApiControlClientPoolService> controlServiceServer,
+ GrpcFnServer<GrpcLoggingService> loggingServiceServer,
+ GrpcFnServer<ArtifactRetrievalService> retrievalServiceServer,
+ GrpcFnServer<StaticGrpcProvisionService> provisioningServiceServer,
+ ControlClientPool.Source clientSource,
+ IdGenerator idGenerator) {
+ return new ExternalEnvironmentFactory(
+ controlServiceServer,
+ loggingServiceServer,
+ retrievalServiceServer,
+ provisioningServiceServer,
+ idGenerator,
+ clientSource);
+ }
+
+ private final GrpcFnServer<FnApiControlClientPoolService> controlServiceServer;
+ private final GrpcFnServer<GrpcLoggingService> loggingServiceServer;
+ private final GrpcFnServer<ArtifactRetrievalService> retrievalServiceServer;
+ private final GrpcFnServer<StaticGrpcProvisionService> provisioningServiceServer;
+ private final IdGenerator idGenerator;
+ private final ControlClientPool.Source clientSource;
+
+ private ExternalEnvironmentFactory(
+ GrpcFnServer<FnApiControlClientPoolService> controlServiceServer,
+ GrpcFnServer<GrpcLoggingService> loggingServiceServer,
+ GrpcFnServer<ArtifactRetrievalService> retrievalServiceServer,
+ GrpcFnServer<StaticGrpcProvisionService> provisioningServiceServer,
+ IdGenerator idGenerator,
+ ControlClientPool.Source clientSource) {
+ this.controlServiceServer = controlServiceServer;
+ this.loggingServiceServer = loggingServiceServer;
+ this.retrievalServiceServer = retrievalServiceServer;
+ this.provisioningServiceServer = provisioningServiceServer;
+ this.idGenerator = idGenerator;
+ this.clientSource = clientSource;
+ }
+
+ /** Creates a new, active {@link RemoteEnvironment} backed by an unmanaged worker. */
+ @Override
+ public RemoteEnvironment createEnvironment(Environment environment) throws Exception {
+ Preconditions.checkState(
+ environment
+ .getUrn()
+ .equals(BeamUrns.getUrn(RunnerApi.StandardEnvironments.Environments.EXTERNAL)),
+ "The passed environment does not contain an ExternalPayload.");
+ final RunnerApi.ExternalPayload externalPayload =
+ RunnerApi.ExternalPayload.parseFrom(environment.getPayload());
+ final String workerId = idGenerator.getId();
+
+ BeamFnApi.NotifyRunnerAvailableRequest notifyRunnerAvailableRequest =
+ BeamFnApi.NotifyRunnerAvailableRequest.newBuilder()
+ .setWorkerId(workerId)
+ .setControlEndpoint(controlServiceServer.getApiServiceDescriptor())
+ .setLoggingEndpoint(loggingServiceServer.getApiServiceDescriptor())
+ .setArtifactEndpoint(retrievalServiceServer.getApiServiceDescriptor())
+ .setProvisionEndpoint(provisioningServiceServer.getApiServiceDescriptor())
+ .putAllParams(externalPayload.getParamsMap())
+ .build();
+
+ LOG.debug("Requesting worker ID {}", workerId);
+ BeamFnApi.NotifyRunnerAvailableResponse notifyRunnerAvailableResponse =
+ BeamFnExternalWorkerPoolGrpc.newBlockingStub(
+ ManagedChannelFactory.createDefault().forDescriptor(externalPayload.getEndpoint()))
+ .notifyRunnerAvailable(notifyRunnerAvailableRequest);
+ if (!notifyRunnerAvailableResponse.getError().isEmpty()) {
+ throw new RuntimeException(notifyRunnerAvailableResponse.getError());
+ }
+
+ // Wait on a client from the gRPC server.
+ InstructionRequestHandler instructionHandler = null;
+ while (instructionHandler == null) {
+ try {
+ instructionHandler = clientSource.take(workerId, Duration.ofMinutes(2));
+ } catch (TimeoutException timeoutEx) {
+ LOG.info(
+ "Still waiting for startup of environment from {} for worker id
{}",
+ externalPayload.getEndpoint().getUrl(),
+ workerId);
+ } catch (InterruptedException interruptEx) {
+ Thread.currentThread().interrupt();
+ throw new RuntimeException(interruptEx);
+ }
+ }
+ final InstructionRequestHandler finalInstructionHandler = instructionHandler;
+
+ return new RemoteEnvironment() {
+ @Override
+ public Environment getEnvironment() {
+ return environment;
+ }
+
+ @Override
+ public InstructionRequestHandler getInstructionRequestHandler() {
+ return finalInstructionHandler;
+ }
+ };
+ }
+
+ /** Provider of ExternalEnvironmentFactory. */
+ public static class Provider implements EnvironmentFactory.Provider {
+ @Override
+ public EnvironmentFactory createEnvironmentFactory(
+ GrpcFnServer<FnApiControlClientPoolService> controlServiceServer,
+ GrpcFnServer<GrpcLoggingService> loggingServiceServer,
+ GrpcFnServer<ArtifactRetrievalService> retrievalServiceServer,
+ GrpcFnServer<StaticGrpcProvisionService> provisioningServiceServer,
+ ControlClientPool clientPool,
+ IdGenerator idGenerator) {
+ return create(
+ controlServiceServer,
+ loggingServiceServer,
+ retrievalServiceServer,
+ provisioningServiceServer,
+ clientPool.getSource(),
+ idGenerator);
+ }
+ }
+}
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/ProcessEnvironmentFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/ProcessEnvironmentFactory.java
index 41f577c8806a..a0d285b980e3 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/ProcessEnvironmentFactory.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/ProcessEnvironmentFactory.java
@@ -36,7 +36,7 @@
import org.slf4j.LoggerFactory;
/**
- * An {@link EnvironmentFactory} which forks processes based on the given URL in the Environment.
+ * An {@link EnvironmentFactory} which forks processes based on the parameters in the Environment.
* The returned {@link ProcessEnvironment} has to make sure to stop the processes.
*/
public class ProcessEnvironmentFactory implements EnvironmentFactory {
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index 4618c2c6cbd9..577e773e2851 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -606,13 +606,18 @@ def visit_value(self, value, _):
return Visitor.ok
def to_runner_api(
- self, return_context=False, context=None, use_fake_coders=False):
+ self, return_context=False, context=None, use_fake_coders=False,
+ default_environment=None):
"""For internal use only; no backwards-compatibility guarantees."""
from apache_beam.runners import pipeline_context
from apache_beam.portability.api import beam_runner_api_pb2
if context is None:
context = pipeline_context.PipelineContext(
- use_fake_coders=use_fake_coders)
+ use_fake_coders=use_fake_coders,
+ default_environment=default_environment)
+ elif default_environment is not None:
+ raise ValueError(
'Only one of context or default_environment may be specified.')
# The RunnerAPI spec requires certain transforms to have KV inputs
# (and corresponding outputs).
diff --git a/sdks/python/apache_beam/portability/python_urns.py b/sdks/python/apache_beam/portability/python_urns.py
index 9f22313ec537..980ca68d086a 100644
--- a/sdks/python/apache_beam/portability/python_urns.py
+++ b/sdks/python/apache_beam/portability/python_urns.py
@@ -32,3 +32,19 @@
IMPULSE_READ_TRANSFORM = "beam:transform:read_from_impulse_python:v1"
GENERIC_COMPOSITE_TRANSFORM = "beam:transform:generic_composite:v1"
+
+# Invoke UserFns in process, via direct function calls.
+# Payload: None.
+EMBEDDED_PYTHON = "beam:env:embedded_python:v1"
+
+# Invoke UserFns in process, but over GRPC channels.
+# Payload: (optional) Number of worker threads, as a decimal string.
+# (Used for testing.)
+EMBEDDED_PYTHON_GRPC = "beam:env:embedded_python_grpc:v1"
+
+# Instantiate SDK harness via a command line provided in the payload.
+# This differs from the standard process environment in that it starts
+# up the SDK harness directly, rather than going through the
+# bootstrapping and artifact-fetching code.
+# (Used for testing.)
+SUBPROCESS_SDK = "beam:env:harness_subprocess_python:v1"
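
These URNs plug into the FnApiRunner as default environments; for instance,
the multi-threaded gRPC test below passes EMBEDDED_PYTHON_GRPC with payload
b'2'. A usage sketch (the pipeline itself is illustrative):

import apache_beam as beam
from apache_beam.portability import python_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.portability import fn_api_runner

runner = fn_api_runner.FnApiRunner(
    default_environment=beam_runner_api_pb2.Environment(
        urn=python_urns.EMBEDDED_PYTHON_GRPC,
        payload=b'2'))  # Two worker threads, per the comment above.

with beam.Pipeline(runner=runner) as p:
  p | beam.Create([1, 2, 3]) | beam.Map(lambda x: x + 1)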
diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py
index 437fe0194994..74156a1f146e 100644
--- a/sdks/python/apache_beam/runners/pipeline_context.py
+++ b/sdks/python/apache_beam/runners/pipeline_context.py
@@ -124,7 +124,7 @@ def __init__(
self, cls, getattr(proto, name, None)))
if default_environment:
self._default_environment_id = self.environments.get_id(
- Environment(default_environment))
+ Environment(default_environment), label='default_environment')
else:
self._default_environment_id = None
self.use_fake_coders = use_fake_coders
diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
index 4ccf89de8d33..f90ce5fe437e 100644
--- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
@@ -87,9 +87,10 @@ def create_options(self):
options = super(FlinkRunnerTest, self).create_options()
options.view_as(DebugOptions).experiments = ['beam_fn_api']
options.view_as(FlinkOptions).parallelism = 1
- # Default environment is Docker.
if environment_type == 'process':
options.view_as(PortableOptions).environment_type = 'PROCESS'
+ else:
+ options.view_as(PortableOptions).environment_type = 'DOCKER'
if environment_config:
options.view_as(PortableOptions).environment_config = environment_config
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index 341139e2db2d..677b0e1af33e 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -212,15 +212,15 @@ def encoded_items(self):
class FnApiRunner(runner.PipelineRunner):
- def __init__(self, use_grpc=False, sdk_harness_factory=None, bundle_repeat=0,
- use_state_iterables=False):
+ def __init__(
+ self,
+ default_environment=None,
+ bundle_repeat=0,
+ use_state_iterables=False):
"""Creates a new Fn API Runner.
Args:
- use_grpc: whether to use grpc or simply make in-process calls
- defaults to False
- sdk_harness_factory: callable used to instantiate customized sdk harnesses
- typcially not set by users
+ default_environment: the default environment to use for UserFns.
bundle_repeat: replay every bundle this many extra times, for profiling
and debugging
use_state_iterables: Intentionally split gbk iterables over state API
@@ -228,10 +228,9 @@ def __init__(self, use_grpc=False, sdk_harness_factory=None, bundle_repeat=0,
"""
super(FnApiRunner, self).__init__()
self._last_uid = -1
- self._use_grpc = use_grpc
- if sdk_harness_factory and not use_grpc:
- raise ValueError('GRPC must be used if a harness factory is provided.')
- self._sdk_harness_factory = sdk_harness_factory
+ self._default_environment = (
+ default_environment
+ or beam_runner_api_pb2.Environment(urn=python_urns.EMBEDDED_PYTHON))
self._bundle_repeat = bundle_repeat
self._progress_frequency = None
self._profiler_factory = None
@@ -253,7 +252,8 @@ def run_pipeline(self, pipeline, options):
pipeline_options.DirectOptions).direct_runner_bundle_repeat
self._profiler_factory = profiler.Profile.factory_from_options(
options.view_as(pipeline_options.ProfilingOptions))
- return self.run_via_runner_api(pipeline.to_runner_api())
+ return self.run_via_runner_api(pipeline.to_runner_api(
+ default_environment=self._default_environment))
def run_via_runner_api(self, pipeline_proto):
return self.run_stages(*self.create_stages(pipeline_proto))
@@ -860,10 +860,8 @@ def window_pcollection_coders(stages):
return pipeline_components, stages, safe_coders
def run_stages(self, pipeline_components, stages, safe_coders):
- if self._use_grpc:
- controller = FnApiRunner.GrpcController(self._sdk_harness_factory)
- else:
- controller = FnApiRunner.DirectController()
+ worker_handler_manager = WorkerHandlerManager(
+ pipeline_components.environments)
metrics_by_stage = {}
monitoring_infos_by_stage = {}
@@ -872,18 +870,26 @@ def run_stages(self, pipeline_components, stages, safe_coders):
pcoll_buffers = collections.defaultdict(list)
for stage in stages:
stage_results = self.run_stage(
- controller, pipeline_components, stage,
- pcoll_buffers, safe_coders)
+ worker_handler_manager.get_worker_handler,
+ pipeline_components,
+ stage,
+ pcoll_buffers,
+ safe_coders)
metrics_by_stage[stage.name] = stage_results.process_bundle.metrics
monitoring_infos_by_stage[stage.name] = (
stage_results.process_bundle.monitoring_infos)
finally:
- controller.close()
+ worker_handler_manager.close_all()
return RunnerResult(
runner.PipelineState.DONE, monitoring_infos_by_stage, metrics_by_stage)
def run_stage(
- self, controller, pipeline_components, stage, pcoll_buffers, safe_coders):
+ self,
+ worker_handler_factory,
+ pipeline_components,
+ stage,
+ pcoll_buffers,
+ safe_coders):
def iterable_state_write(values, element_coder_impl):
token = unique_name(None, 'iter').encode('ascii')
@@ -896,6 +902,7 @@ def iterable_state_write(values, element_coder_impl):
out.get())
return token
+ controller = worker_handler_factory(stage.environment)
context = pipeline_context.PipelineContext(
pipeline_components, iterable_state_write=iterable_state_write)
data_api_service_descriptor = controller.data_api_service_descriptor()
@@ -972,7 +979,7 @@ def extract_endpoints(stage):
side_input_id=tag,
window=window,
key=key))
- controller.state_handler.blocking_append(state_key, elements_data)
+ controller.state.blocking_append(state_key, elements_data)
def get_buffer(buffer_id):
kind, name = split_buffer_id(buffer_id)
@@ -1005,12 +1012,12 @@ def get_buffer(buffer_id):
for k in range(self._bundle_repeat):
try:
- controller.state_handler.checkpoint()
+ controller.state.checkpoint()
BundleManager(
controller, lambda pcoll_id: [], process_bundle_descriptor,
self._progress_frequency, k).process_bundle(data_input, data_output)
finally:
- controller.state_handler.restore()
+ controller.state.restore()
result = BundleManager(
controller, get_buffer, process_bundle_descriptor,
@@ -1138,8 +1145,10 @@ def blocking_clear(self, state_key):
def _to_key(state_key):
return state_key.SerializeToString()
- class GrpcStateServicer(
- StateServicer, beam_fn_api_pb2_grpc.BeamFnStateServicer):
+ class GrpcStateServicer(beam_fn_api_pb2_grpc.BeamFnStateServicer):
+ def __init__(self, state):
+ self._state = state
+
def State(self, request_stream, context=None):
# Note that this eagerly mutates state, assuming any failures are fatal.
# Thus it is safe to ignore instruction_reference.
@@ -1149,14 +1158,14 @@ def State(self, request_stream, context=None):
yield beam_fn_api_pb2.StateResponse(
id=request.id,
get=beam_fn_api_pb2.StateGetResponse(
- data=self.blocking_get(request.state_key)))
+ data=self._state.blocking_get(request.state_key)))
elif request_type == 'append':
- self.blocking_append(request.state_key, request.append.data)
+ self._state.blocking_append(request.state_key, request.append.data)
yield beam_fn_api_pb2.StateResponse(
id=request.id,
append=beam_fn_api_pb2.StateAppendResponse())
elif request_type == 'clear':
- self.blocking_clear(request.state_key)
+ self._state.blocking_clear(request.state_key)
yield beam_fn_api_pb2.StateResponse(
id=request.id,
clear=beam_fn_api_pb2.StateClearResponse())
@@ -1177,111 +1186,224 @@ def close(self):
"""Does nothing."""
pass
- class DirectController(object):
- """An in-memory controller for fn API control, state and data planes."""
- def __init__(self):
- self.control_handler = self
- self.data_plane_handler = data_plane.InMemoryDataChannel()
- self.state_handler = FnApiRunner.StateServicer()
- self.worker = sdk_worker.SdkWorker(
- FnApiRunner.SingletonStateHandlerFactory(self.state_handler),
- data_plane.InMemoryDataChannelFactory(
- self.data_plane_handler.inverse()), {})
- self._uid_counter = 0
-
- def push(self, request):
- if not request.instruction_id:
- self._uid_counter += 1
- request.instruction_id = 'control_%s' % self._uid_counter
- logging.debug('CONTROL REQUEST %s', request)
- response = self.worker.do_instruction(request)
- logging.debug('CONTROL RESPONSE %s', response)
- return ControlFuture(request.instruction_id, response)
+class WorkerHandler(object):
- def done(self):
- pass
+ _registered_environments = {}
- def close(self):
- pass
+ def __init__(self, control_handler, data_plane_handler, state):
+ self.control_handler = control_handler
+ self.data_plane_handler = data_plane_handler
+ self.state = state
- def data_api_service_descriptor(self):
- return None
-
- def state_api_service_descriptor(self):
- return None
-
- class GrpcController(object):
- """An grpc based controller for fn API control, state and data planes."""
-
- def __init__(self, sdk_harness_factory=None):
- self.sdk_harness_factory = sdk_harness_factory
- self.control_server = grpc.server(
- futures.ThreadPoolExecutor(max_workers=10))
- self.control_port = self.control_server.add_insecure_port('[::]:0')
-
- # Options to have no limits (-1) on the size of the messages
- # received or sent over the data plane. The actual buffer size
- # is controlled in a layer above.
- no_max_message_sizes = [("grpc.max_receive_message_length", -1),
- ("grpc.max_send_message_length", -1)]
- self.data_server = grpc.server(
- futures.ThreadPoolExecutor(max_workers=10),
- options=no_max_message_sizes)
- self.data_port = self.data_server.add_insecure_port('[::]:0')
-
- self.state_server = grpc.server(
- futures.ThreadPoolExecutor(max_workers=10),
- options=no_max_message_sizes)
- self.state_port = self.state_server.add_insecure_port('[::]:0')
-
- self.control_handler = BeamFnControlServicer()
- beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server(
- self.control_handler, self.control_server)
-
- self.data_plane_handler = data_plane.GrpcServerDataChannel()
- beam_fn_api_pb2_grpc.add_BeamFnDataServicer_to_server(
- self.data_plane_handler, self.data_server)
-
- self.state_handler = FnApiRunner.GrpcStateServicer()
- beam_fn_api_pb2_grpc.add_BeamFnStateServicer_to_server(
- self.state_handler, self.state_server)
-
- logging.info('starting control server on port %s', self.control_port)
- logging.info('starting data server on port %s', self.data_port)
- self.state_server.start()
- self.data_server.start()
- self.control_server.start()
-
- self.worker = self.sdk_harness_factory(
- 'localhost:%s' % self.control_port
- ) if self.sdk_harness_factory else sdk_worker.SdkHarness(
- 'localhost:%s' % self.control_port, worker_count=1)
-
- self.worker_thread = threading.Thread(
- name='run_worker', target=self.worker.run)
- logging.info('starting worker')
- self.worker_thread.start()
-
- def data_api_service_descriptor(self):
- url = 'localhost:%s' % self.data_port
- api_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
- api_service_descriptor.url = url
- return api_service_descriptor
-
- def state_api_service_descriptor(self):
- url = 'localhost:%s' % self.state_port
- api_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
- api_service_descriptor.url = url
- return api_service_descriptor
+ def close(self):
+ self.stop_worker()
- def close(self):
- self.control_handler.done()
- self.worker_thread.join()
- self.data_plane_handler.close()
- self.control_server.stop(5).wait()
- self.data_server.stop(5).wait()
- self.state_server.stop(5).wait()
+ def start_worker(self):
+ raise NotImplementedError
+
+ def stop_worker(self):
+ raise NotImplementedError
+
+ def data_api_service_descriptor(self):
+ raise NotImplementedError
+
+ def state_api_service_descriptor(self):
+ raise NotImplementedError
+
+ @classmethod
+ def register_environment(cls, urn, payload_type):
+ def wrapper(constructor):
+ cls._registered_environments[urn] = constructor, payload_type
+ return constructor
+ return wrapper
+
+ @classmethod
+ def create(cls, environment, state):
+ constructor, payload_type = cls._registered_environments[environment.urn]
+ return constructor(
+ proto_utils.parse_Bytes(environment.payload, payload_type), state)
+
+
+@WorkerHandler.register_environment(python_urns.EMBEDDED_PYTHON, None)
+class EmbeddedWorkerHandler(WorkerHandler):
+ """An in-memory controller for fn API control, state and data planes."""
+
+ def __init__(self, unused_payload, state):
+ super(EmbeddedWorkerHandler, self).__init__(
+ self, data_plane.InMemoryDataChannel(), state)
+ self.worker = sdk_worker.SdkWorker(
+ FnApiRunner.SingletonStateHandlerFactory(self.state),
+ data_plane.InMemoryDataChannelFactory(
+ self.data_plane_handler.inverse()), {})
+ self._uid_counter = 0
+
+ def push(self, request):
+ if not request.instruction_id:
+ self._uid_counter += 1
+ request.instruction_id = 'control_%s' % self._uid_counter
+ logging.debug('CONTROL REQUEST %s', request)
+ response = self.worker.do_instruction(request)
+ logging.debug('CONTROL RESPONSE %s', response)
+ return ControlFuture(request.instruction_id, response)
+
+ def start_worker(self):
+ pass
+
+ def stop_worker(self):
+ pass
+
+ def done(self):
+ pass
+
+ def data_api_service_descriptor(self):
+ return None
+
+ def state_api_service_descriptor(self):
+ return None
+
+
+class GrpcWorkerHandler(WorkerHandler):
+ """An grpc based controller for fn API control, state and data planes."""
+
+ def __init__(self, state=None):
+ self.control_server = grpc.server(
+ futures.ThreadPoolExecutor(max_workers=10))
+ self.control_port = self.control_server.add_insecure_port('[::]:0')
+ self.control_address = 'localhost:%s' % self.control_port
+
+ # Options to have no limits (-1) on the size of the messages
+ # received or sent over the data plane. The actual buffer size
+ # is controlled in a layer above.
+ no_max_message_sizes = [("grpc.max_receive_message_length", -1),
+ ("grpc.max_send_message_length", -1)]
+ self.data_server = grpc.server(
+ futures.ThreadPoolExecutor(max_workers=10),
+ options=no_max_message_sizes)
+ self.data_port = self.data_server.add_insecure_port('[::]:0')
+
+ self.state_server = grpc.server(
+ futures.ThreadPoolExecutor(max_workers=10),
+ options=no_max_message_sizes)
+ self.state_port = self.state_server.add_insecure_port('[::]:0')
+
+ self.control_handler = BeamFnControlServicer()
+ beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server(
+ self.control_handler, self.control_server)
+
+ self.data_plane_handler = data_plane.GrpcServerDataChannel()
+ beam_fn_api_pb2_grpc.add_BeamFnDataServicer_to_server(
+ self.data_plane_handler, self.data_server)
+
+ self.state = state
+ beam_fn_api_pb2_grpc.add_BeamFnStateServicer_to_server(
+ FnApiRunner.GrpcStateServicer(state),
+ self.state_server)
+
+ logging.info('starting control server on port %s', self.control_port)
+ logging.info('starting data server on port %s', self.data_port)
+ self.state_server.start()
+ self.data_server.start()
+ self.control_server.start()
+
+ def data_api_service_descriptor(self):
+ return endpoints_pb2.ApiServiceDescriptor(
+ url='localhost:%s' % self.data_port)
+
+ def state_api_service_descriptor(self):
+ return endpoints_pb2.ApiServiceDescriptor(
+ url='localhost:%s' % self.state_port)
+
+ def close(self):
+ self.control_handler.done()
+ self.data_plane_handler.close()
+ self.control_server.stop(5).wait()
+ self.data_server.stop(5).wait()
+ self.state_server.stop(5).wait()
+ super(GrpcWorkerHandler, self).close()
+
+
+@WorkerHandler.register_environment(
+ common_urns.environments.EXTERNAL.urn, beam_runner_api_pb2.ExternalPayload)
+class ExternalWorkerHandler(GrpcWorkerHandler):
+ def __init__(self, external_payload, state):
+ super(ExternalWorkerHandler, self).__init__(state)
+ self._external_payload = external_payload
+
+ def start_worker(self):
+ stub = beam_fn_api_pb2_grpc.BeamFnExternalWorkerPoolStub(
+ grpc.insecure_channel(self._external_payload.endpoint.url))
+ response = stub.NotifyRunnerAvailable(
+ beam_fn_api_pb2.NotifyRunnerAvailableRequest(
+ control_endpoint=endpoints_pb2.ApiServiceDescriptor(
+ url=self.control_address),
+ params=self._external_payload.params))
+ if response.error:
+ raise RuntimeError("Error starting worker: %s" % response.error)
+
+ def stop_worker(self):
+ pass
+
+
+@WorkerHandler.register_environment(python_urns.EMBEDDED_PYTHON_GRPC, bytes)
+class EmbeddedGrpcWorkerHandler(GrpcWorkerHandler):
+ def __init__(self, num_workers_payload, state):
+ super(EmbeddedGrpcWorkerHandler, self).__init__(state)
+ self._num_threads = int(num_workers_payload) if num_workers_payload else 1
+
+ def start_worker(self):
+ self.worker = sdk_worker.SdkHarness(
+ self.control_address, worker_count=self._num_threads)
+ self.worker_thread = threading.Thread(
+ name='run_worker', target=self.worker.run)
+ self.worker_thread.start()
+
+ def stop_worker(self):
+ self.worker_thread.join()
+
+
+@WorkerHandler.register_environment(python_urns.SUBPROCESS_SDK, bytes)
+class SubprocessSdkWorkerHandler(GrpcWorkerHandler):
+ def __init__(self, worker_command_line, state):
+ super(SubprocessSdkWorkerHandler, self).__init__(state)
+ self._worker_command_line = worker_command_line
+
+ def start_worker(self):
+ from apache_beam.runners.portability import local_job_service
+ self.worker = local_job_service.SubprocessSdkWorker(
+ self._worker_command_line, self.control_address)
+ self.worker_thread = threading.Thread(
+ name='run_worker', target=self.worker.run)
+ self.worker_thread.start()
+
+ def stop_worker(self):
+ self.worker_thread.join()
+
+
+class WorkerHandlerManager(object):
+ def __init__(self, environments):
+ self._environments = environments
+ self._cached_handlers = {}
+ self._state = FnApiRunner.StateServicer() # rename?
+
+ def get_worker_handler(self, environment_id):
+ if environment_id is None:
+ # Any environment will do, pick one arbitrarily.
+ environment_id = next(iter(self._environments.keys()))
+ environment = self._environments[environment_id]
+
+ worker_handler = self._cached_handlers.get(environment_id)
+ if worker_handler is None:
+ worker_handler = self._cached_handlers[
+ environment_id] = WorkerHandler.create(
+ environment, self._state)
+ worker_handler.start_worker()
+ return worker_handler
+
+ def close_all(self):
+ for controller in set(self._cached_handlers.values()):
+ controller.close()
+ self._cached_handlers = {}
class BundleManager(object):
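
The register_environment decorator is the extension point in this
refactoring: each URN maps to a (constructor, payload_type) pair, and
WorkerHandler.create dispatches on the stage environment's URN. A
hypothetical handler for a custom URN would hook in roughly like this
(sketch only; the URN and class are invented):

from apache_beam.runners.portability import fn_api_runner


@fn_api_runner.WorkerHandler.register_environment(
    'beam:env:my_custom:v1', bytes)  # Hypothetical URN and payload type.
class CustomWorkerHandler(fn_api_runner.GrpcWorkerHandler):

  def __init__(self, payload, state):
    super(CustomWorkerHandler, self).__init__(state)
    self._payload = payload

  def start_worker(self):
    # Launch a worker that connects back to self.control_address.
    pass

  def stop_worker(self):
    # Tear down whatever start_worker launched.
    pass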
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index 9a59b8bbc2ba..3fecbe5a85e8 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -17,7 +17,6 @@
from __future__ import absolute_import
from __future__ import print_function
-import functools
import logging
import os
import sys
@@ -32,9 +31,10 @@
from apache_beam.metrics.execution import MetricKey
from apache_beam.metrics.execution import MetricsEnvironment
from apache_beam.metrics.metricbase import MetricName
+from apache_beam.portability import python_urns
+from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.portability import fn_api_runner
from apache_beam.runners.worker import data_plane
-from apache_beam.runners.worker import sdk_worker
from apache_beam.runners.worker import statesampler
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
@@ -50,8 +50,7 @@
class FnApiRunnerTest(unittest.TestCase):
def create_pipeline(self):
- return beam.Pipeline(
- runner=fn_api_runner.FnApiRunner(use_grpc=False))
+ return beam.Pipeline(runner=fn_api_runner.FnApiRunner())
def test_assert_that(self):
# TODO: figure out a way for fn_api_runner to parse and raise the
@@ -684,7 +683,9 @@ class FnApiRunnerTestWithGrpc(FnApiRunnerTest):
def create_pipeline(self):
return beam.Pipeline(
- runner=fn_api_runner.FnApiRunner(use_grpc=True))
+ runner=fn_api_runner.FnApiRunner(
+ default_environment=beam_runner_api_pb2.Environment(
+ urn=python_urns.EMBEDDED_PYTHON_GRPC)))
class FnApiRunnerTestWithGrpcMultiThreaded(FnApiRunnerTest):
@@ -692,16 +693,16 @@ class FnApiRunnerTestWithGrpcMultiThreaded(FnApiRunnerTest):
def create_pipeline(self):
return beam.Pipeline(
runner=fn_api_runner.FnApiRunner(
- use_grpc=True,
- sdk_harness_factory=functools.partial(
- sdk_worker.SdkHarness, worker_count=2)))
+ default_environment=beam_runner_api_pb2.Environment(
+ urn=python_urns.EMBEDDED_PYTHON_GRPC,
+ payload=b'2')))
class FnApiRunnerTestWithBundleRepeat(FnApiRunnerTest):
def create_pipeline(self):
return beam.Pipeline(
- runner=fn_api_runner.FnApiRunner(use_grpc=False, bundle_repeat=3))
+ runner=fn_api_runner.FnApiRunner(bundle_repeat=3))
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
index 0d3fdc76b7f4..615bf1fa13c6 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
@@ -20,6 +20,7 @@
from __future__ import absolute_import
from __future__ import print_function
+import functools
from builtins import object
from apache_beam.portability import common_urns
@@ -39,13 +40,18 @@ class Stage(object):
"""A set of Transforms that can be sent to the worker for processing."""
def __init__(self, name, transforms,
downstream_side_inputs=None, must_follow=frozenset(),
- parent=None):
+ parent=None, environment=None):
self.name = name
self.transforms = transforms
self.downstream_side_inputs = downstream_side_inputs
self.must_follow = must_follow
self.timer_pcollections = []
self.parent = parent
+ if environment is None:
+ environment = functools.reduce(
+ self._merge_environments,
+ (self._extract_environment(t) for t in transforms))
+ self.environment = environment
def __repr__(self):
must_follow = ', '.join(prev.name for prev in self.must_follow)
@@ -61,9 +67,45 @@ def __repr__(self):
must_follow,
downstream_side_inputs)
+ @staticmethod
+ def _extract_environment(transform):
+ if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
+ pardo_payload = proto_utils.parse_Bytes(
+ transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
+ return pardo_payload.do_fn.environment_id
+ elif transform.spec.urn in (
+ common_urns.composites.COMBINE_PER_KEY.urn,
+ common_urns.combine_components.COMBINE_PGBKCV.urn,
+ common_urns.combine_components.COMBINE_MERGE_ACCUMULATORS.urn,
+ common_urns.combine_components.COMBINE_EXTRACT_OUTPUTS.urn):
+ combine_payload = proto_utils.parse_Bytes(
+ transform.spec.payload, beam_runner_api_pb2.CombinePayload)
+ return combine_payload.combine_fn.environment_id
+ else:
+ return None
+
+ @staticmethod
+ def _merge_environments(env1, env2):
+ if env1 is None:
+ return env2
+ elif env2 is None:
+ return env1
+ else:
+ if env1 != env2:
+ raise ValueError("Incompatible environments: '%s' != '%s'" % (
+ str(env1).replace('\n', ' '),
+ str(env2).replace('\n', ' ')))
+ return env1
+
def can_fuse(self, consumer):
+ try:
+ self._merge_environments(self.environment, consumer.environment)
+ except ValueError:
+ return False
+
def no_overlap(a, b):
return not a.intersection(b)
+
return (
not self in consumer.must_follow
and not self.is_flatten() and not consumer.is_flatten()
@@ -74,7 +116,9 @@ def fuse(self, other):
"(%s)+(%s)" % (self.name, other.name),
self.transforms + other.transforms,
union(self.downstream_side_inputs, other.downstream_side_inputs),
- union(self.must_follow, other.must_follow))
+ union(self.must_follow, other.must_follow),
+ environment=self._merge_environments(
+ self.environment, other.environment))
def is_flatten(self):
return any(transform.spec.urn == common_urns.primitives.FLATTEN.urn
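
The practical consequence for fusion: two stages fuse only when their
environments agree, or one side is unset. Roughly (with hypothetical
environment ids):

# _merge_environments(None, 'env_a')     -> 'env_a'  (unset side defers)
# _merge_environments('env_a', 'env_a')  -> 'env_a'  (equal ids merge)
# _merge_environments('env_a', 'env_b')  -> ValueError, so can_fuse
#     returns False and a stage boundary is kept between the two.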
diff --git a/sdks/python/apache_beam/runners/portability/local_job_service.py b/sdks/python/apache_beam/runners/portability/local_job_service.py
index 5819832aa9ac..cb8c55260d37 100644
--- a/sdks/python/apache_beam/runners/portability/local_job_service.py
+++ b/sdks/python/apache_beam/runners/portability/local_job_service.py
@@ -16,7 +16,6 @@
#
from __future__ import absolute_import
-import functools
import logging
import os
import queue
@@ -61,9 +60,7 @@ class LocalJobServicer(beam_job_api_pb2_grpc.JobServiceServicer):
subprocesses for the runner and worker(s).
"""
- def __init__(self, worker_command_line=None, use_grpc=True):
- self._worker_command_line = worker_command_line
- self._use_grpc = use_grpc or bool(worker_command_line)
+ def __init__(self):
self._jobs = {}
def start_grpc_server(self, port=0):
@@ -78,17 +75,10 @@ def Prepare(self, request, context=None):
# For now, just use the job name as the job id.
logging.debug('Got Prepare request.')
preparation_id = '%s-%s' % (request.job_name, uuid.uuid4())
- if self._worker_command_line:
- sdk_harness_factory = functools.partial(SubprocessSdkWorker,
- self._worker_command_line)
- else:
- sdk_harness_factory = None
self._jobs[preparation_id] = BeamJob(
preparation_id,
request.pipeline_options,
- request.pipeline,
- use_grpc=self._use_grpc,
- sdk_harness_factory=sdk_harness_factory)
+ request.pipeline)
logging.debug("Prepared job '%s' as '%s'", request.job_name,
preparation_id)
# TODO(angoenka): Pass an appropriate staging_session_token. The token can
# be obtained in PutArtifactResponse from JobService
@@ -186,15 +176,11 @@ class BeamJob(threading.Thread):
def __init__(self,
job_id,
pipeline_options,
- pipeline_proto,
- use_grpc=True,
- sdk_harness_factory=None):
+ pipeline_proto):
super(BeamJob, self).__init__()
self._job_id = job_id
self._pipeline_options = pipeline_options
self._pipeline_proto = pipeline_proto
- self._use_grpc = use_grpc
- self._sdk_harness_factory = sdk_harness_factory
self._log_queue = queue.Queue()
self._state_change_callbacks = [
lambda new_state: self._log_queue.put(
@@ -227,10 +213,7 @@ def state(self, new_state):
def run(self):
with JobLogHandler(self._log_queue):
try:
- fn_api_runner.FnApiRunner(
- use_grpc=self._use_grpc,
- sdk_harness_factory=self._sdk_harness_factory).run_via_runner_api(
- self._pipeline_proto)
+ fn_api_runner.FnApiRunner().run_via_runner_api(self._pipeline_proto)
logging.info('Successfully completed job.')
self.state = beam_job_api_pb2.JobState.DONE
except: # pylint: disable=bare-except
diff --git a/sdks/python/apache_beam/runners/portability/local_job_service_main.py b/sdks/python/apache_beam/runners/portability/local_job_service_main.py
index dc70f4556883..f3be4fe79ca0 100644
--- a/sdks/python/apache_beam/runners/portability/local_job_service_main.py
+++ b/sdks/python/apache_beam/runners/portability/local_job_service_main.py
@@ -32,10 +32,8 @@ def run(argv):
parser.add_argument('-p', '--port',
type=int,
help='port on which to serve the job api')
- parser.add_argument('--worker_command_line',
- help='command line for starting up a worker process')
options = parser.parse_args(argv)
- job_servicer = local_job_service.LocalJobServicer(options.worker_command_line)
+ job_servicer = local_job_service.LocalJobServicer()
port = job_servicer.start_grpc_server(options.port)
while True:
logging.info("Listening for jobs at %d", port)
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner.py b/sdks/python/apache_beam/runners/portability/portable_runner.py
index 98bc6609615e..dc21a28d0923 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner.py
@@ -17,10 +17,12 @@
from __future__ import absolute_import
+import functools
import json
import logging
import os
import threading
+from concurrent import futures
import grpc
@@ -29,15 +31,18 @@
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.portability import common_urns
+from apache_beam.portability.api import beam_fn_api_pb2
+from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.portability.api import beam_job_api_pb2
from apache_beam.portability.api import beam_job_api_pb2_grpc
from apache_beam.portability.api import beam_runner_api_pb2
-from apache_beam.runners import pipeline_context
+from apache_beam.portability.api import endpoints_pb2
from apache_beam.runners import runner
from apache_beam.runners.job import utils as job_utils
from apache_beam.runners.portability import fn_api_runner_transforms
from apache_beam.runners.portability import portable_stager
from apache_beam.runners.portability.job_server import DockerizedJobServer
+from apache_beam.runners.worker import sdk_worker
__all__ = ['PortableRunner']
@@ -85,6 +90,14 @@ def _create_environment(options):
environment_urn = common_urns.environments.DOCKER.urn
elif portable_options.environment_type == 'PROCESS':
environment_urn = common_urns.environments.PROCESS.urn
+ elif portable_options.environment_type in ('EXTERNAL', 'LOOPBACK'):
+ environment_urn = common_urns.environments.EXTERNAL.urn
+ elif portable_options.environment_type:
+ if portable_options.environment_type.startswith('beam:env:'):
+ environment_urn = portable_options.environment_type
+ else:
+ raise ValueError(
+ 'Unknown environment type: %s' % portable_options.environment_type)
if environment_urn == common_urns.environments.DOCKER.urn:
docker_image = (
@@ -106,6 +119,18 @@ def _create_environment(options):
command=config.get('command'),
env=(config.get('env') or '')
).SerializeToString())
+ elif environment_urn == common_urns.environments.EXTERNAL.urn:
+ return beam_runner_api_pb2.Environment(
+ urn=common_urns.environments.EXTERNAL.urn,
+ payload=beam_runner_api_pb2.ExternalPayload(
+ endpoint=endpoints_pb2.ApiServiceDescriptor(
+ url=portable_options.environment_config)
+ ).SerializeToString())
+ else:
+ return beam_runner_api_pb2.Environment(
+ urn=environment_urn,
+ payload=(portable_options.environment_config.encode('ascii')
+ if portable_options.environment_config else None))
def run_pipeline(self, pipeline, options):
portable_options = options.view_as(PortableOptions)
@@ -120,10 +145,18 @@ def run_pipeline(self, pipeline, options):
docker = DockerizedJobServer()
job_endpoint = docker.start()
- proto_context = pipeline_context.PipelineContext(
+ # This is needed as we start a worker server if one is requested
+ # but none is provided.
+ if portable_options.environment_type == 'LOOPBACK':
+ portable_options.environment_config, server = (
+ BeamFnExternalWorkerPoolServicer.start())
+ cleanup_callbacks = [functools.partial(server.stop, 1)]
+ else:
+ cleanup_callbacks = []
+
+ proto_pipeline = pipeline.to_runner_api(
default_environment=PortableRunner._create_environment(
portable_options))
- proto_pipeline = pipeline.to_runner_api(context=proto_context)
# Some runners won't detect the GroupByKey transform unless it has no
# subtransforms. Remove all sub-transforms until BEAM-4605 is resolved.
@@ -186,7 +219,7 @@ def send_prepare_request(max_retries=5):
beam_job_api_pb2.RunJobRequest(
preparation_id=prepare_response.preparation_id,
retrieval_token=retrieval_token))
- return PipelineResult(job_service, run_response.job_id)
+ return PipelineResult(job_service, run_response.job_id, cleanup_callbacks)
class PortableMetrics(metrics.metric.MetricResults):
@@ -201,15 +234,19 @@ def query(self, filter=None):
class PipelineResult(runner.PipelineResult):
- def __init__(self, job_service, job_id):
+ def __init__(self, job_service, job_id, cleanup_callbacks=()):
super(PipelineResult, self).__init__(beam_job_api_pb2.JobState.UNSPECIFIED)
self._job_service = job_service
self._job_id = job_id
self._messages = []
+ self._cleanup_callbacks = cleanup_callbacks
def cancel(self):
- self._job_service.Cancel(beam_job_api_pb2.CancelJobRequest(
- job_id=self._job_id))
+ try:
+ self._job_service.Cancel(beam_job_api_pb2.CancelJobRequest(
+ job_id=self._job_id))
+ finally:
+ self._cleanup()
@property
def state(self):
@@ -262,16 +299,59 @@ def read_messages():
t.daemon = True
t.start()
- for state_response in self._job_service.GetStateStream(
- beam_job_api_pb2.GetJobStateRequest(job_id=self._job_id)):
- self._state = self._runner_api_state_to_pipeline_state(
- state_response.state)
- if state_response.state in TERMINAL_STATES:
- # Wait for any last messages.
- t.join(10)
- break
- if self._state != runner.PipelineState.DONE:
- raise RuntimeError(
- 'Pipeline %s failed in state %s: %s' % (
- self._job_id, self._state, self._last_error_message()))
- return self._state
+ try:
+ for state_response in self._job_service.GetStateStream(
+ beam_job_api_pb2.GetJobStateRequest(job_id=self._job_id)):
+ self._state = self._runner_api_state_to_pipeline_state(
+ state_response.state)
+ if state_response.state in TERMINAL_STATES:
+ # Wait for any last messages.
+ t.join(10)
+ break
+ if self._state != runner.PipelineState.DONE:
+ raise RuntimeError(
+ 'Pipeline %s failed in state %s: %s' % (
+ self._job_id, self._state, self._last_error_message()))
+ return self._state
+ finally:
+ self._cleanup()
+
+ def _cleanup(self):
+ has_exception = None
+ for callback in self._cleanup_callbacks:
+ try:
+ callback()
+ except Exception:
+ has_exception = True
+ self._cleanup_callbacks = ()
+ if has_exception:
+ raise
+
+
+class BeamFnExternalWorkerPoolServicer(
+ beam_fn_api_pb2_grpc.BeamFnExternalWorkerPoolServicer):
+
+ @classmethod
+ def start(cls):
+ worker_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+ worker_address = 'localhost:%s' % worker_server.add_insecure_port('[::]:0')
+ beam_fn_api_pb2_grpc.add_BeamFnExternalWorkerPoolServicer_to_server(
+ cls(), worker_server)
+ worker_server.start()
+ return worker_address, worker_server
+
+ def NotifyRunnerAvailable(self, start_worker_request, context):
+ try:
+ worker = sdk_worker.SdkHarness(
+ start_worker_request.control_endpoint.url,
+ worker_count=1,
+ worker_id=start_worker_request.worker_id)
+ worker_thread = threading.Thread(
+ name='run_worker_%s' % start_worker_request.worker_id,
+ target=worker.run)
+ worker_thread.daemon = True
+ worker_thread.start()
+ return beam_fn_api_pb2.NotifyRunnerAvailableResponse()
+ except Exception as exn:
+ return beam_fn_api_pb2.NotifyRunnerAvailableResponse(
+ error=str(exn))
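
Taken together, LOOPBACK mode lets a portable pipeline run its workers inside
the submitting process: run_pipeline starts a BeamFnExternalWorkerPoolServicer
locally and passes its address as the environment config. A hedged
end-to-end sketch (the job_endpoint value assumes a job service is already
listening there):

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

options = PipelineOptions([
    '--runner=PortableRunner',
    '--job_endpoint=localhost:8099',  # Assumed running job service.
    '--environment_type=LOOPBACK',
])

with beam.Pipeline(options=options) as p:
  p | beam.Create(['a', 'b']) | beam.Map(lambda x: x.upper())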
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner_test.py b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
index 5f178799feda..62214d3abc6c 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
@@ -35,6 +35,7 @@
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import PortableOptions
from apache_beam.portability import common_urns
+from apache_beam.portability import python_urns
from apache_beam.portability.api import beam_job_api_pb2
from apache_beam.portability.api import beam_job_api_pb2_grpc
from apache_beam.portability.api import beam_runner_api_pb2
@@ -48,7 +49,7 @@ class PortableRunnerTest(fn_api_runner_test.FnApiRunnerTest):
TIMEOUT_SECS = 30
- _use_grpc = False
+ # Controls job service interaction, not sdk harness interaction.
_use_subprocesses = False
def setUp(self):
@@ -127,13 +128,8 @@ def _get_job_endpoint(cls):
def _create_job_endpoint(cls):
if cls._use_subprocesses:
return cls._start_local_runner_subprocess_job_service()
- elif cls._use_grpc:
- # Use GRPC for workers.
- cls._servicer = LocalJobServicer(use_grpc=True)
- return 'localhost:%d' % cls._servicer.start_grpc_server()
else:
- # Do not use GRPC for worker.
- cls._servicer = LocalJobServicer(use_grpc=False)
+ cls._servicer = LocalJobServicer()
return 'localhost:%d' % cls._servicer.start_grpc_server()
@classmethod
@@ -162,6 +158,9 @@ def get_pipeline_name():
'job_name': get_pipeline_name() + '_' + str(time.time())
})
options.view_as(PortableOptions).job_endpoint = self._get_job_endpoint()
+ # Override the default environment type for testing.
+ options.view_as(PortableOptions).environment_type = (
+ python_urns.EMBEDDED_PYTHON)
return options
def create_pipeline(self):
@@ -170,23 +169,43 @@ def create_pipeline(self):
# Inherits all tests from fn_api_runner_test.FnApiRunnerTest
-class PortableRunnerTestWithGrpc(PortableRunnerTest):
- _use_grpc = True
+class PortableRunnerTestWithExternalEnv(PortableRunnerTest):
+
+ @classmethod
+ def setUpClass(cls):
+ cls._worker_address, cls._worker_server = (
+ portable_runner.BeamFnExternalWorkerPoolServicer.start())
+
+ @classmethod
+ def tearDownClass(cls):
+ cls._worker_server.stop(1)
+
+ def create_options(self):
+ options = super(PortableRunnerTestWithExternalEnv, self).create_options()
+ options.view_as(PortableOptions).environment_type = 'EXTERNAL'
+ options.view_as(PortableOptions).environment_config = self._worker_address
+ return options
@unittest.skip("BEAM-3040")
class PortableRunnerTestWithSubprocesses(PortableRunnerTest):
- _use_grpc = True
_use_subprocesses = True
+ def create_options(self):
+ options = super(PortableRunnerTestWithSubprocesses, self).create_options()
+ options.view_as(PortableOptions).environment_type = (
+ python_urns.SUBPROCESS_SDK)
+ options.view_as(PortableOptions).environment_config = (
+ b'%s -m apache_beam.runners.worker.sdk_worker_main' %
+ sys.executable.encode('ascii'))
+ return options
+
@classmethod
def _subprocess_command(cls, port):
return [
sys.executable,
'-m', 'apache_beam.runners.portability.local_job_service_main',
'-p', str(port),
- '--worker_command_line',
- '%s -m apache_beam.runners.worker.sdk_worker_main' % sys.executable,
]
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index c3f2b91693e2..967928bf651e 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -48,11 +48,13 @@ class SdkHarness(object):
REQUEST_METHOD_PREFIX = '_request_'
SCHEDULING_DELAY_THRESHOLD_SEC = 5*60 # 5 Minutes
- def __init__(self, control_address, worker_count, credentials=None,
- profiler_factory=None):
+ def __init__(
+ self, control_address, worker_count, credentials=None, worker_id=None,
+ profiler_factory=None):
self._alive = True
self._worker_count = worker_count
self._worker_index = 0
+ self._worker_id = worker_id
if credentials is None:
logging.info('Creating insecure control channel.')
self._control_channel = grpc.insecure_channel(control_address)
@@ -63,7 +65,7 @@ def __init__(self, control_address, worker_count, credentials=None,
logging.info('Control channel established.')
self._control_channel = grpc.intercept_channel(
- self._control_channel, WorkerIdInterceptor())
+ self._control_channel, WorkerIdInterceptor(self._worker_id))
self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
credentials)
self._state_handler_factory = GrpcStateHandlerFactory()
diff --git a/sdks/python/apache_beam/runners/worker/worker_id_interceptor.py b/sdks/python/apache_beam/runners/worker/worker_id_interceptor.py
index 7b2c9cfbf4fe..6c9a605c3a5c 100644
--- a/sdks/python/apache_beam/runners/worker/worker_id_interceptor.py
+++ b/sdks/python/apache_beam/runners/worker/worker_id_interceptor.py
@@ -40,8 +40,9 @@ class WorkerIdInterceptor(grpc.StreamStreamClientInterceptor):
# Unique worker Id for this worker.
_worker_id = os.environ.get('WORKER_ID')
- def __init__(self):
- pass
+ def __init__(self, worker_id=None):
+ if worker_id:
+ self._worker_id = worker_id
def intercept_stream_stream(self, continuation, client_call_details,
request_iterator):
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
Issue Time Tracking
-------------------
Worklog Id: (was: 174496)
Time Spent: 3h 20m (was: 3h 10m)
> Implement External environment for Portable Beam
> ------------------------------------------------
>
> Key: BEAM-6094
> URL: https://issues.apache.org/jira/browse/BEAM-6094
> Project: Beam
> Issue Type: Improvement
> Components: beam-model
> Reporter: Robert Bradshaw
> Assignee: Robert Bradshaw
> Priority: Major
> Time Spent: 3h 20m
> Remaining Estimate: 0h
>
--
This message was sent by Atlassian JIRA
(v7.6.3#76005)