This is an automated email from the ASF dual-hosted git repository.
mxm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 60131b8 [BEAM-6747] Adding ExternalTransform in JavaSDK
new 4c32210 Merge pull request #7954 from ihji/BEAM-6747
60131b8 is described below
commit 60131b84d38f1c223b3e1715c8964fb5cfdea54f
Author: Heejong Lee <[email protected]>
AuthorDate: Wed Feb 20 18:22:35 2019 -0800
[BEAM-6747] Adding ExternalTransform in JavaSDK
---
.../org/apache/beam/gradle/BeamModulePlugin.groovy | 7 +-
.../DefaultExpansionServiceClientFactory.java | 67 ++++++
.../core/construction/ExpansionServiceClient.java | 25 +++
.../ExpansionServiceClientFactory.java | 28 +++
.../beam/runners/core/construction/External.java | 250 +++++++++++++++++++++
.../core/construction/ExternalTranslation.java | 158 +++++++++++++
.../core/construction/PTransformTranslation.java | 2 +
.../runners/core/construction/SdkComponents.java | 30 ++-
.../construction/expansion/ExpansionService.java | 8 +-
.../runners/core/construction/ExternalTest.java | 197 ++++++++++++++++
runners/direct-java/build.gradle | 3 +
runners/flink/job-server/flink_job_server.gradle | 22 ++
runners/google-cloud-dataflow-java/build.gradle | 1 +
.../sdk/testing/UsesCrossLanguageTransforms.java | 24 ++
.../runners/portability/expansion_service.py | 24 +-
.../runners/portability/expansion_service_test.py | 166 ++++++++++++++
sdks/python/apache_beam/transforms/external.py | 11 +-
.../python/apache_beam/transforms/external_test.py | 91 +-------
18 files changed, 981 insertions(+), 133 deletions(-)
diff --git
a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
index 918df02..4852e5c 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
@@ -21,6 +21,7 @@ package org.apache.beam.gradle
import com.github.jengelman.gradle.plugins.shadow.tasks.ShadowJar
import groovy.json.JsonOutput
import groovy.json.JsonSlurper
+import java.util.concurrent.atomic.AtomicInteger
import org.gradle.api.GradleException
import org.gradle.api.Plugin
import org.gradle.api.Project
@@ -71,6 +72,7 @@ class BeamModulePlugin implements Plugin<Project> {
* limitations under the License.
*/
"""
+ static AtomicInteger startingExpansionPortNumber = new AtomicInteger(18091)
/** A class defining the set of configurable properties accepted by
applyJavaNature. */
class JavaNatureConfiguration {
@@ -1545,6 +1547,7 @@ class BeamModulePlugin implements Plugin<Project> {
*/
project.evaluationDependsOn(":beam-sdks-java-core")
project.evaluationDependsOn(":beam-runners-core-java")
+ project.evaluationDependsOn(":beam-runners-core-construction-java")
def config = it ? it as PortableValidatesRunnerConfiguration : new
PortableValidatesRunnerConfiguration()
def name = config.name
def beamTestPipelineOptions = [
@@ -1552,6 +1555,7 @@ class BeamModulePlugin implements Plugin<Project> {
"--jobServerDriver=${config.jobServerDriver}",
"--environmentCacheMillis=10000"
]
+ def expansionPort = startingExpansionPortNumber.getAndDecrement()
beamTestPipelineOptions.addAll(config.pipelineOpts)
if (config.environment ==
PortableValidatesRunnerConfiguration.Environment.EMBEDDED) {
beamTestPipelineOptions += "--defaultEnvironmentType=EMBEDDED"
@@ -1563,8 +1567,9 @@ class BeamModulePlugin implements Plugin<Project> {
group = "Verification"
description = "Validates the PortableRunner with JobServer
${config.jobServerDriver}"
systemProperty "beamTestPipelineOptions",
JsonOutput.toJson(beamTestPipelineOptions)
+ systemProperty "expansionPort", expansionPort
classpath = config.testClasspathConfiguration
- testClassesDirs =
project.files(project.project(":beam-sdks-java-core").sourceSets.test.output.classesDirs,
project.project(":beam-runners-core-java").sourceSets.test.output.classesDirs)
+ testClassesDirs =
project.files(project.project(":beam-sdks-java-core").sourceSets.test.output.classesDirs,
project.project(":beam-runners-core-java").sourceSets.test.output.classesDirs,
project.project(":beam-runners-core-construction-java").sourceSets.test.output.classesDirs)
maxParallelForks config.numParallelTests
useJUnit(config.testCategories)
// increase maxHeapSize as this is directly correlated to direct
memory,
diff --git
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/DefaultExpansionServiceClientFactory.java
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/DefaultExpansionServiceClientFactory.java
new file mode 100644
index 0000000..a1425cb
--- /dev/null
+++
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/DefaultExpansionServiceClientFactory.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.core.construction;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Function;
+import org.apache.beam.model.expansion.v1.ExpansionApi;
+import org.apache.beam.model.expansion.v1.ExpansionServiceGrpc;
+import org.apache.beam.model.pipeline.v1.Endpoints;
+import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ManagedChannel;
+
+/**
+ * Default factory for ExpansionServiceClient used by External transform.
+ *
+ * <p>A client is created the first time an endpoint is requested and is then cached in a {@link
+ * ConcurrentHashMap} keyed by the endpoint descriptor. Each client owns a channel obtained from
+ * the supplied channel factory and a blocking gRPC stub bound to that channel.
+ */
+public class DefaultExpansionServiceClientFactory implements ExpansionServiceClientFactory {
+  private final Map<Endpoints.ApiServiceDescriptor, ExpansionServiceClient> expansionServiceMap;
+  private final Function<Endpoints.ApiServiceDescriptor, ManagedChannel> channelFactory;
+
+  /**
+   * @param channelFactory builds the channel used to talk to a given endpoint; invoked once per
+   *     cached client
+   */
+  DefaultExpansionServiceClientFactory(
+      Function<Endpoints.ApiServiceDescriptor, ManagedChannel> channelFactory) {
+    this.expansionServiceMap = new ConcurrentHashMap<>();
+    this.channelFactory = channelFactory;
+  }
+
+  /**
+   * Closes every cached client.
+   *
+   * <p>All clients are attempted even when one of them fails, so a single failing channel cannot
+   * leave the others open. The first failure is rethrown once the loop finishes; later failures
+   * are attached to it as suppressed exceptions.
+   *
+   * @throws Exception the first exception raised by a client's {@code close()}
+   */
+  @Override
+  public void close() throws Exception {
+    Exception firstFailure = null;
+    for (ExpansionServiceClient client : expansionServiceMap.values()) {
+      try {
+        client.close();
+      } catch (Exception e) {
+        if (firstFailure == null) {
+          firstFailure = e;
+        } else {
+          firstFailure.addSuppressed(e);
+        }
+      }
+    }
+    if (firstFailure != null) {
+      throw firstFailure;
+    }
+  }
+
+  /**
+   * Returns the client for {@code endpoint}, creating and caching it on first use.
+   *
+   * @param endpoint descriptor of the expansion service to talk to
+   * @return the cached client for that descriptor
+   */
+  @Override
+  public ExpansionServiceClient getExpansionServiceClient(Endpoints.ApiServiceDescriptor endpoint) {
+    return expansionServiceMap.computeIfAbsent(
+        endpoint,
+        descriptor ->
+            new ExpansionServiceClient() {
+              private final ManagedChannel channel = channelFactory.apply(descriptor);
+              private final ExpansionServiceGrpc.ExpansionServiceBlockingStub service =
+                  ExpansionServiceGrpc.newBlockingStub(channel);
+
+              @Override
+              public ExpansionApi.ExpansionResponse expand(ExpansionApi.ExpansionRequest request) {
+                return service.expand(request);
+              }
+
+              @Override
+              public void close() throws Exception {
+                // Only requests shutdown; there is no awaitTermination call here.
+                channel.shutdown();
+              }
+            });
+  }
+}
diff --git
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExpansionServiceClient.java
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExpansionServiceClient.java
new file mode 100644
index 0000000..e678766
--- /dev/null
+++
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExpansionServiceClient.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.core.construction;
+
+import org.apache.beam.model.expansion.v1.ExpansionApi;
+
+/**
+ * A high-level client for a cross-language expansion service.
+ *
+ * <p>The single operation, {@code expand}, takes an {@code ExpansionApi.ExpansionRequest} and
+ * returns the corresponding {@code ExpansionApi.ExpansionResponse}. The interface extends {@link
+ * AutoCloseable}, so implementations release whatever they hold in {@code close()}.
+ */
+interface ExpansionServiceClient extends AutoCloseable {
+ ExpansionApi.ExpansionResponse expand(ExpansionApi.ExpansionRequest request);
+}
diff --git
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExpansionServiceClientFactory.java
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExpansionServiceClientFactory.java
new file mode 100644
index 0000000..a43cb26
--- /dev/null
+++
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExpansionServiceClientFactory.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.core.construction;
+
+import org.apache.beam.model.pipeline.v1.Endpoints;
+
+/**
+ * A factory for generating {@link ExpansionServiceClient} from {@link
+ * org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor}.
+ *
+ * <p>{@code getExpansionServiceClient} returns the client to use for the given endpoint
+ * descriptor. The factory is itself {@link AutoCloseable}; what closing it does to clients it has
+ * handed out is up to the implementation.
+ */
+interface ExpansionServiceClientFactory extends AutoCloseable {
+ ExpansionServiceClient
getExpansionServiceClient(Endpoints.ApiServiceDescriptor endpoint);
+}
diff --git
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/External.java
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/External.java
new file mode 100644
index 0000000..6fd5c00
--- /dev/null
+++
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/External.java
@@ -0,0 +1,250 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.core.construction;
+
+import static
org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import javax.annotation.Nullable;
+import org.apache.beam.model.expansion.v1.ExpansionApi;
+import org.apache.beam.model.pipeline.v1.Endpoints;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.runners.AppliedPTransform;
+import org.apache.beam.sdk.transforms.Impulse;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PBegin;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.PInput;
+import org.apache.beam.sdk.values.POutput;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.vendor.grpc.v1p13p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ManagedChannelBuilder;
+import
org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;
+
+/**
+ * Cross-language external transform.
+ *
+ * <p>{@link External} provides a cross-language transform via expansion services in foreign SDKs.
+ * In order to use {@link External} transform, a user should know 1) URN of the target transform 2)
+ * bytes encoding schema for configuration parameters 3) connection endpoint of the expansion
+ * service. Note that this is a low-level API and mainly for internal use. A user may want to use
+ * high-level wrapper classes rather than this one.
+ */
+public class External {
+  private static final String EXPANDED_TRANSFORM_BASE_NAME = "external";
+  private static final String IMPULSE_PREFIX = "IMPULSE";
+  // Source of the per-transform index used in both the transform name and the request namespace.
+  private static final AtomicInteger namespaceCounter = new AtomicInteger(0);
+
+  // Shared client factory; channels are built in plaintext from the endpoint URL.
+  private static final ExpansionServiceClientFactory DEFAULT =
+      new DefaultExpansionServiceClientFactory(
+          endPoint -> ManagedChannelBuilder.forTarget(endPoint.getUrl()).usePlaintext().build());
+
+  /** Returns the next namespace index and advances the counter. */
+  private static int getFreshNamespaceIndex() {
+    return namespaceCounter.getAndIncrement();
+  }
+
+  /**
+   * Creates an external transform that yields a single {@link PCollection}.
+   *
+   * @param urn URN of the transform to request from the expansion service
+   * @param payload bytes placed in the transform spec's payload
+   * @param endpoint URL of the expansion service
+   * @return a transform carrying a freshly allocated namespace index
+   */
+  public static <OutputT> SingleOutputExpandableTransform<OutputT> of(
+      String urn, byte[] payload, String endpoint) {
+    Endpoints.ApiServiceDescriptor apiDesc =
+        Endpoints.ApiServiceDescriptor.newBuilder().setUrl(endpoint).build();
+    return new SingleOutputExpandableTransform<>(urn, payload, apiDesc, getFreshNamespaceIndex());
+  }
+
+  /** Expandable transform for output type of PCollection. */
+  public static class SingleOutputExpandableTransform<OutputT>
+      extends ExpandableTransform<PCollection<OutputT>> {
+    SingleOutputExpandableTransform(
+        String urn,
+        byte[] payload,
+        Endpoints.ApiServiceDescriptor endpoint,
+        Integer namespaceIndex) {
+      super(urn, payload, endpoint, namespaceIndex);
+    }
+
+    /** Returns the sole output; rejects an empty map and, via getOnlyElement, more than one. */
+    @Override
+    PCollection<OutputT> toOutputCollection(Map<TupleTag<?>, PCollection> output) {
+      checkArgument(output.size() > 0, "output shouldn't be empty.");
+      return Iterables.getOnlyElement(output.values());
+    }
+
+    /** Returns a multi-output transform with the same urn, payload, endpoint and namespace. */
+    public MultiOutputExpandableTransform withMultiOutputs() {
+      return new MultiOutputExpandableTransform(
+          getUrn(), getPayload(), getEndpoint(), getNamespaceIndex());
+    }
+  }
+
+  /** Expandable transform for output type of PCollectionTuple. */
+  public static class MultiOutputExpandableTransform extends ExpandableTransform<PCollectionTuple> {
+    MultiOutputExpandableTransform(
+        String urn,
+        byte[] payload,
+        Endpoints.ApiServiceDescriptor endpoint,
+        Integer namespaceIndex) {
+      super(urn, payload, endpoint, namespaceIndex);
+    }
+
+    /** Packs every tagged output into one tuple rooted at the first output's pipeline. */
+    @Override
+    PCollectionTuple toOutputCollection(Map<TupleTag<?>, PCollection> output) {
+      checkArgument(output.size() > 0, "output shouldn't be empty.");
+      PCollection firstElem = Iterables.getFirst(output.values(), null);
+      PCollectionTuple pCollectionTuple = PCollectionTuple.empty(firstElem.getPipeline());
+      for (Map.Entry<TupleTag<?>, PCollection> entry : output.entrySet()) {
+        pCollectionTuple = pCollectionTuple.and(entry.getKey(), entry.getValue());
+      }
+      return pCollectionTuple;
+    }
+  }
+
+  /** Base Expandable Transform which calls ExpansionService to expand itself. */
+  public abstract static class ExpandableTransform<OutputT extends POutput>
+      extends PTransform<PInput, OutputT> {
+    private final String urn;
+    private final byte[] payload;
+    private final Endpoints.ApiServiceDescriptor endpoint;
+    private final Integer namespaceIndex;
+
+    // Populated by expand(); null until then.
+    @Nullable private transient RunnerApi.Components expandedComponents;
+    @Nullable private transient RunnerApi.PTransform expandedTransform;
+    @Nullable private transient Map<PCollection, String> externalPCollectionIdMap;
+
+    ExpandableTransform(
+        String urn,
+        byte[] payload,
+        Endpoints.ApiServiceDescriptor endpoint,
+        Integer namespaceIndex) {
+      this.urn = urn;
+      this.payload = payload;
+      this.endpoint = endpoint;
+      this.namespaceIndex = namespaceIndex;
+    }
+
+    /**
+     * Sends this transform to the expansion service and rehydrates its outputs.
+     *
+     * <p>Every input {@link PCollection} is registered in a fresh {@link SdkComponents}, wired in
+     * as an input of the request transform, and given a fake Impulse producer. Inputs that are not
+     * PCollections are ignored. The response's components and transform are kept for later
+     * translation, together with the ids used for the inputs and outputs.
+     *
+     * @param input the inputs this transform is applied to
+     * @return the rehydrated outputs, shaped by {@link #toOutputCollection(Map)}
+     * @throws RuntimeException wrapping the {@link IOException} if an input cannot be registered
+     *     or an output cannot be rehydrated
+     */
+    @Override
+    public OutputT expand(PInput input) {
+      Pipeline p = input.getPipeline();
+      SdkComponents components = SdkComponents.create(p.getOptions());
+      RunnerApi.PTransform.Builder ptransformBuilder =
+          RunnerApi.PTransform.newBuilder()
+              .setUniqueName(EXPANDED_TRANSFORM_BASE_NAME + namespaceIndex)
+              .setSpec(
+                  RunnerApi.FunctionSpec.newBuilder()
+                      .setUrn(urn)
+                      .setPayload(ByteString.copyFrom(payload))
+                      .build());
+      ImmutableMap.Builder<PCollection, String> externalPCollectionIdMapBuilder =
+          ImmutableMap.builder();
+      for (Map.Entry<TupleTag<?>, PValue> entry : input.expand().entrySet()) {
+        if (entry.getValue() instanceof PCollection<?>) {
+          try {
+            String id = components.registerPCollection((PCollection) entry.getValue());
+            externalPCollectionIdMapBuilder.put((PCollection) entry.getValue(), id);
+            ptransformBuilder.putInputs(entry.getKey().getId(), id);
+            AppliedPTransform<?, ?, ?> fakeImpulse =
+                AppliedPTransform.of(
+                    String.format("%s_%s", IMPULSE_PREFIX, entry.getKey().getId()),
+                    PBegin.in(p).expand(),
+                    ImmutableMap.of(entry.getKey(), entry.getValue()),
+                    Impulse.create(),
+                    p);
+            // using fake Impulses to provide inputs
+            components.registerPTransform(fakeImpulse, Collections.emptyList());
+          } catch (IOException e) {
+            // Keep the cause so the original stack trace is not lost.
+            throw new RuntimeException(
+                String.format("cannot register component: %s", e.getMessage()), e);
+          }
+        }
+      }
+
+      ExpansionApi.ExpansionRequest request =
+          ExpansionApi.ExpansionRequest.newBuilder()
+              .setComponents(components.toComponents())
+              .setTransform(ptransformBuilder.build())
+              .setNamespace(getNamespace())
+              .build();
+
+      ExpansionApi.ExpansionResponse response =
+          DEFAULT.getExpansionServiceClient(endpoint).expand(request);
+
+      expandedComponents = response.getComponents();
+      expandedTransform = response.getTransform();
+
+      RehydratedComponents rehydratedComponents =
+          RehydratedComponents.forComponents(expandedComponents).withPipeline(p);
+      ImmutableMap.Builder<TupleTag<?>, PCollection> outputMapBuilder = ImmutableMap.builder();
+      expandedTransform
+          .getOutputsMap()
+          .forEach(
+              (localId, pCollectionId) -> {
+                try {
+                  PCollection col = rehydratedComponents.getPCollection(pCollectionId);
+                  externalPCollectionIdMapBuilder.put(col, pCollectionId);
+                  outputMapBuilder.put(new TupleTag<>(localId), col);
+                } catch (IOException e) {
+                  // Keep the cause so the original stack trace is not lost.
+                  throw new RuntimeException("cannot rehydrate PCollection.", e);
+                }
+              });
+      externalPCollectionIdMap = externalPCollectionIdMapBuilder.build();
+
+      return toOutputCollection(outputMapBuilder.build());
+    }
+
+    /** Converts the tagged outputs rehydrated by {@link #expand(PInput)} into the output type. */
+    abstract OutputT toOutputCollection(Map<TupleTag<?>, PCollection> output);
+
+    /** Namespace sent with the expansion request: {@code "External_"} plus the index. */
+    String getNamespace() {
+      return String.format("External_%s", namespaceIndex);
+    }
+
+    /** Name prefix of the fake Impulse transforms registered for the inputs. */
+    String getImpulsePrefix() {
+      return IMPULSE_PREFIX;
+    }
+
+    /** Transform returned by the expansion service; null before {@link #expand(PInput)}. */
+    RunnerApi.PTransform getExpandedTransform() {
+      return expandedTransform;
+    }
+
+    /** Components returned by the expansion service; null before {@link #expand(PInput)}. */
+    RunnerApi.Components getExpandedComponents() {
+      return expandedComponents;
+    }
+
+    /** Ids recorded for input and output PCollections; null before {@link #expand(PInput)}. */
+    Map<PCollection, String> getExternalPCollectionIdMap() {
+      return externalPCollectionIdMap;
+    }
+
+    String getUrn() {
+      return urn;
+    }
+
+    byte[] getPayload() {
+      return payload;
+    }
+
+    Endpoints.ApiServiceDescriptor getEndpoint() {
+      return endpoint;
+    }
+
+    Integer getNamespaceIndex() {
+      return namespaceIndex;
+    }
+  }
+}
diff --git
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExternalTranslation.java
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExternalTranslation.java
new file mode 100644
index 0000000..9e1bafa
--- /dev/null
+++
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExternalTranslation.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.core.construction;
+
+import static
org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument;
+import static
org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkState;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import javax.annotation.Nullable;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.sdk.runners.AppliedPTransform;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PValue;
+import
org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
+
+/** Translating External transforms to proto. */
+public class ExternalTranslation {
+  public static final String EXTERNAL_TRANSFORM_URN = "urn:beam:transform:external:v1";
+
+  /** Translator for ExpandableTransform. */
+  public static class ExternalTranslator
+      implements PTransformTranslation.TransformTranslator<External.ExpandableTransform<?>> {
+    /** Returns a new translator instance. */
+    public static PTransformTranslation.TransformTranslator create() {
+      return new ExternalTranslator();
+    }
+
+    /** Always returns {@link #EXTERNAL_TRANSFORM_URN}, regardless of the transform passed in. */
+    @Nullable
+    @Override
+    public String getUrn(External.ExpandableTransform transform) {
+      return EXTERNAL_TRANSFORM_URN;
+    }
+
+    /** Accepts exactly the transforms that are {@link External.ExpandableTransform}s. */
+    @Override
+    public boolean canTranslate(PTransform<?, ?> pTransform) {
+      return pTransform instanceof External.ExpandableTransform;
+    }
+
+    /**
+     * Merges an already-expanded external transform into {@code components}.
+     *
+     * <p>Coders, windowing strategies, environments and PCollections from the expansion response
+     * are copied only when their id starts with the transform's namespace. Transforms whose id
+     * starts with the impulse prefix are dropped. Every other transform must start with the
+     * namespace and is copied with its input and output ids rewritten to the ids that {@code
+     * components} assigns to the applied transform's PCollections.
+     *
+     * @param appliedPTransform application of an {@link External.ExpandableTransform}
+     * @param subtransforms not used by this translator
+     * @param components receives the namespaced components of the expansion
+     * @return the root transform of the expansion with its outputs renamed
+     * @throws IOException if registering a PCollection fails
+     * @throws IllegalArgumentException if the transform is not an ExpandableTransform
+     * @throws IllegalStateException if the expansion contains a transform outside the namespace
+     * @throws RuntimeException if an input or output of the application is not a PCollection
+     */
+    @Override
+    public RunnerApi.PTransform translate(
+        AppliedPTransform<?, ?, ?> appliedPTransform,
+        List<AppliedPTransform<?, ?, ?>> subtransforms,
+        SdkComponents components)
+        throws IOException {
+      checkArgument(
+          canTranslate(appliedPTransform.getTransform()), "can only translate ExpandableTransform");
+
+      External.ExpandableTransform expandableTransform =
+          (External.ExpandableTransform) appliedPTransform.getTransform();
+      String nameSpace = expandableTransform.getNamespace();
+      String impulsePrefix = expandableTransform.getImpulsePrefix();
+      ImmutableMap.Builder<String, String> pColRenameMapBuilder = ImmutableMap.builder();
+      RunnerApi.PTransform expandedTransform = expandableTransform.getExpandedTransform();
+      RunnerApi.Components expandedComponents = expandableTransform.getExpandedComponents();
+      Map<PCollection, String> externalPCollectionIdMap =
+          expandableTransform.getExternalPCollectionIdMap();
+
+      // Map each id used in the expansion to the id this pipeline's components use.
+      for (PValue pcol : appliedPTransform.getInputs().values()) {
+        if (!(pcol instanceof PCollection)) {
+          throw new RuntimeException("unknown input type.");
+        }
+        pColRenameMapBuilder.put(
+            externalPCollectionIdMap.get(pcol), components.registerPCollection((PCollection) pcol));
+      }
+      for (PValue pcol : appliedPTransform.getOutputs().values()) {
+        if (!(pcol instanceof PCollection)) {
+          throw new RuntimeException("unknown output type.");
+        }
+        pColRenameMapBuilder.put(
+            externalPCollectionIdMap.get(pcol), components.registerPCollection((PCollection) pcol));
+      }
+
+      ImmutableMap<String, String> pColRenameMap = pColRenameMapBuilder.build();
+      RunnerApi.Components.Builder mergingComponentsBuilder = RunnerApi.Components.newBuilder();
+      // Only entries under this transform's namespace are carried over.
+      for (Map.Entry<String, RunnerApi.Coder> entry :
+          expandedComponents.getCodersMap().entrySet()) {
+        if (entry.getKey().startsWith(nameSpace)) {
+          mergingComponentsBuilder.putCoders(entry.getKey(), entry.getValue());
+        }
+      }
+      for (Map.Entry<String, RunnerApi.WindowingStrategy> entry :
+          expandedComponents.getWindowingStrategiesMap().entrySet()) {
+        if (entry.getKey().startsWith(nameSpace)) {
+          mergingComponentsBuilder.putWindowingStrategies(entry.getKey(), entry.getValue());
+        }
+      }
+      for (Map.Entry<String, RunnerApi.Environment> entry :
+          expandedComponents.getEnvironmentsMap().entrySet()) {
+        if (entry.getKey().startsWith(nameSpace)) {
+          mergingComponentsBuilder.putEnvironments(entry.getKey(), entry.getValue());
+        }
+      }
+      for (Map.Entry<String, RunnerApi.PCollection> entry :
+          expandedComponents.getPcollectionsMap().entrySet()) {
+        if (entry.getKey().startsWith(nameSpace)) {
+          mergingComponentsBuilder.putPcollections(entry.getKey(), entry.getValue());
+        }
+      }
+      for (Map.Entry<String, RunnerApi.PTransform> entry :
+          expandedComponents.getTransformsMap().entrySet()) {
+        // ignore dummy Impulses we added for fake inputs
+        if (entry.getKey().startsWith(impulsePrefix)) {
+          continue;
+        }
+        checkState(entry.getKey().startsWith(nameSpace), "unknown transform found");
+        RunnerApi.PTransform proto = entry.getValue();
+        RunnerApi.PTransform.Builder transformBuilder = RunnerApi.PTransform.newBuilder();
+        transformBuilder
+            .setUniqueName(proto.getUniqueName())
+            .setSpec(proto.getSpec())
+            .addAllSubtransforms(proto.getSubtransformsList());
+        for (Map.Entry<String, String> inputEntry : proto.getInputsMap().entrySet()) {
+          transformBuilder.putInputs(
+              inputEntry.getKey(),
+              pColRenameMap.getOrDefault(inputEntry.getValue(), inputEntry.getValue()));
+        }
+        for (Map.Entry<String, String> outputEntry : proto.getOutputsMap().entrySet()) {
+          transformBuilder.putOutputs(
+              outputEntry.getKey(),
+              pColRenameMap.getOrDefault(outputEntry.getValue(), outputEntry.getValue()));
+        }
+        mergingComponentsBuilder.putTransforms(entry.getKey(), transformBuilder.build());
+      }
+
+      // NOTE(review): root inputs are copied verbatim, while sub-transform inputs and root
+      // outputs go through pColRenameMap. Confirm that leaving them unrenamed is intended.
+      RunnerApi.PTransform.Builder rootTransformBuilder = RunnerApi.PTransform.newBuilder();
+      rootTransformBuilder
+          .setUniqueName(expandedTransform.getUniqueName())
+          .setSpec(expandedTransform.getSpec())
+          .addAllSubtransforms(expandedTransform.getSubtransformsList())
+          .putAllInputs(expandedTransform.getInputsMap());
+      for (Map.Entry<String, String> outputEntry : expandedTransform.getOutputsMap().entrySet()) {
+        rootTransformBuilder.putOutputs(
+            outputEntry.getKey(),
+            pColRenameMap.getOrDefault(outputEntry.getValue(), outputEntry.getValue()));
+      }
+      components.mergeFrom(mergingComponentsBuilder.build());
+
+      return rootTransformBuilder.build();
+    }
+  }
+}
diff --git
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
index a072ab3..b880283 100644
---
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
+++
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
@@ -35,6 +35,7 @@ import
org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
import org.apache.beam.model.pipeline.v1.RunnerApi.StandardPTransforms;
import
org.apache.beam.model.pipeline.v1.RunnerApi.StandardPTransforms.CombineComponents;
import
org.apache.beam.model.pipeline.v1.RunnerApi.StandardPTransforms.SplittableParDoComponents;
+import
org.apache.beam.runners.core.construction.ExternalTranslation.ExternalTranslator;
import
org.apache.beam.runners.core.construction.ParDoTranslation.ParDoTranslator;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.runners.AppliedPTransform;
@@ -125,6 +126,7 @@ public class PTransformTranslation {
.add(new RawPTransformTranslator())
.add(new KnownTransformPayloadTranslator())
.add(ParDoTranslator.create())
+ .add(ExternalTranslator.create())
.build();
}
diff --git
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java
index 2a1b335..8362bd6 100644
---
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java
+++
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java
@@ -54,9 +54,11 @@ public class SdkComponents {
private final Set<String> reservedIds = new HashSet<>();
+ private String defaultEnvironmentId;
+
/** Create a new {@link SdkComponents} with no components. */
public static SdkComponents create() {
- return new SdkComponents("");
+ return new SdkComponents(RunnerApi.Components.getDefaultInstance(), "");
}
/**
@@ -85,19 +87,16 @@ public class SdkComponents {
}
public static SdkComponents create(PipelineOptions options) {
- SdkComponents sdkComponents = new SdkComponents("");
+ SdkComponents sdkComponents = new
SdkComponents(RunnerApi.Components.getDefaultInstance(), "");
PortablePipelineOptions portablePipelineOptions =
options.as(PortablePipelineOptions.class);
- sdkComponents.registerEnvironment(
- Environments.createOrGetDefaultEnvironment(
- portablePipelineOptions.getDefaultEnvironmentType(),
- portablePipelineOptions.getDefaultEnvironmentConfig()));
+ sdkComponents.defaultEnvironmentId =
+ sdkComponents.registerEnvironment(
+ Environments.createOrGetDefaultEnvironment(
+ portablePipelineOptions.getDefaultEnvironmentType(),
+ portablePipelineOptions.getDefaultEnvironmentConfig()));
return sdkComponents;
}
- private SdkComponents(String newIdPrefix) {
- this.newIdPrefix = newIdPrefix;
- }
-
private SdkComponents(RunnerApi.Components components, String newIdPrefix) {
this.newIdPrefix = newIdPrefix;
@@ -105,6 +104,11 @@ public class SdkComponents {
return;
}
+ mergeFrom(components);
+ }
+
+ /** Merge Components proto into this SdkComponents instance. */
+ public void mergeFrom(RunnerApi.Components components) {
reservedIds.addAll(components.getTransformsMap().keySet());
reservedIds.addAll(components.getPcollectionsMap().keySet());
reservedIds.addAll(components.getWindowingStrategiesMap().keySet());
@@ -270,7 +274,11 @@ public class SdkComponents {
public String getOnlyEnvironmentId() {
// TODO Support multiple environments. The environment should be decided
by the translation.
- return
Iterables.getOnlyElement(componentsBuilder.getEnvironmentsMap().keySet());
+ if (defaultEnvironmentId != null) {
+ return defaultEnvironmentId;
+ } else {
+ return
Iterables.getOnlyElement(componentsBuilder.getEnvironmentsMap().keySet());
+ }
}
private String uniqify(String baseName, Set<String> existing) {
diff --git
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/expansion/ExpansionService.java
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/expansion/ExpansionService.java
index 75c7aba..2380808 100644
---
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/expansion/ExpansionService.java
+++
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/expansion/ExpansionService.java
@@ -49,7 +49,7 @@ import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PDone;
import org.apache.beam.sdk.values.PInput;
-import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.POutput;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.Server;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ServerBuilder;
@@ -242,10 +242,10 @@ public class ExpansionService extends
ExpansionServiceGrpc.ExpansionServiceImplB
* Provides a mapping of {@link RunnerApi.FunctionSpec} to a {@link
PTransform}, together with
* mappings of its inputs and outputs to maps of PCollections.
*
- * @param <InputT> input {@link PValue} type of the transform
- * @param <OutputT> output {@link PValue} type of the transform
+ * @param <InputT> input {@link PInput} type of the transform
+ * @param <OutputT> output {@link POutput} type of the transform
*/
- public interface TransformProvider<InputT extends PInput, OutputT extends
PValue> {
+ public interface TransformProvider<InputT extends PInput, OutputT extends
POutput> {
default InputT createInput(Pipeline p, Map<String, PCollection<?>> inputs)
{
if (inputs.size() == 0) {
diff --git
a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ExternalTest.java
b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ExternalTest.java
new file mode 100644
index 0000000..1e35bd4
--- /dev/null
+++
b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ExternalTest.java
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.core.construction;
+
+import com.google.auto.service.AutoService;
+import java.io.IOException;
+import java.io.Serializable;
+import java.nio.charset.StandardCharsets;
+import java.util.Map;
+import org.apache.beam.runners.core.construction.expansion.ExpansionService;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.UsesCrossLanguageTransforms;
+import org.apache.beam.sdk.testing.ValidatesRunner;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Filter;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
+import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ConnectivityState;
+import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ManagedChannelBuilder;
+import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.Server;
+import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ServerBuilder;
+import
org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Test External transforms. */
+@RunWith(JUnit4.class)
+public class ExternalTest implements Serializable {
+ @Rule public transient TestPipeline testPipeline = TestPipeline.create();
+
+ private static final String TEST_URN_SIMPLE = "simple";
+ private static final String TEST_URN_LE = "le";
+ private static final String TEST_URN_MULTI = "multi";
+
+ private static String pythonServerCommand;
+ private static Integer expansionPort;
+ private static String localExpansionAddr;
+ private static Server localExpansionServer;
+
+ @BeforeClass
+ public static void setUp() throws IOException {
+ pythonServerCommand = System.getProperty("pythonTestExpansionCommand");
+ expansionPort = Integer.valueOf(System.getProperty("expansionPort"));
+ int localExpansionPort = expansionPort + 100;
+ localExpansionAddr = String.format("localhost:%s", localExpansionPort);
+
+ localExpansionServer =
+ ServerBuilder.forPort(localExpansionPort).addService(new
ExpansionService()).build();
+ localExpansionServer.start();
+ }
+
+ @AfterClass
+ public static void tearDown() {
+ localExpansionServer.shutdownNow();
+ }
+
+ @Test
+ @Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class})
+ public void expandSingleTest() {
+ PCollection<Integer> col =
+ testPipeline
+ .apply(Create.of(1, 2, 3))
+ .apply(External.of(TEST_URN_SIMPLE, new byte[] {},
localExpansionAddr));
+ PAssert.that(col).containsInAnyOrder(2, 3, 4);
+ testPipeline.run();
+ }
+
+ @Test
+ @Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class})
+ public void expandMultipleTest() {
+ PCollection<Integer> pcol =
+ testPipeline
+ .apply(Create.of(1, 2, 3))
+ .apply("add one", External.of(TEST_URN_SIMPLE, new byte[] {},
localExpansionAddr))
+ .apply(
+ "filter <=3",
+ External.of(TEST_URN_LE, "3".getBytes(StandardCharsets.UTF_8),
localExpansionAddr));
+
+ PAssert.that(pcol).containsInAnyOrder(2, 3);
+ testPipeline.run();
+ }
+
+ @Test
+ @Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class})
+ public void expandMultiOutputTest() {
+ PCollectionTuple pTuple =
+ testPipeline
+ .apply(Create.of(1, 2, 3, 4, 5, 6))
+ .apply(
+ External.of(TEST_URN_MULTI, new byte[] {},
localExpansionAddr).withMultiOutputs());
+
+ PAssert.that(pTuple.get(new TupleTag<Integer>("even")
{})).containsInAnyOrder(2, 4, 6);
+ PAssert.that(pTuple.get(new TupleTag<Integer>("odd")
{})).containsInAnyOrder(1, 3, 5);
+ testPipeline.run();
+ }
+
+ private Process runCommandline(String command) {
+ ProcessBuilder builder = new ProcessBuilder("sh", "-c", command);
+ try {
+ return builder.start();
+ } catch (IOException e) {
+ throw new AssertionError("process launch failed.");
+ }
+ }
+
+ @Test
+ @Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class})
+ public void expandPythonTest() {
+ String target = String.format("localhost:%s", expansionPort);
+ Process p = runCommandline(String.format("%s -p %s", pythonServerCommand,
expansionPort));
+ try {
+ ManagedChannel channel = ManagedChannelBuilder.forTarget(target).build();
+ ConnectivityState state = channel.getState(true);
+ for (int retry = 0; retry < 30 && state != ConnectivityState.READY;
retry++) {
+ Thread.sleep(500);
+ state = channel.getState(true);
+ }
+ channel.shutdownNow();
+
+ PCollection<String> pCol =
+ testPipeline
+ .apply(Create.of("1", "2", "2", "3", "3", "3"))
+ .apply(
+ "toBytes",
+ MapElements.into(new TypeDescriptor<byte[]>()
{}).via(String::getBytes))
+ .apply(External.<byte[]>of("count_per_element_bytes", new byte[]
{}, target))
+ .apply("toString",
MapElements.into(TypeDescriptors.strings()).via(String::new));
+
+ PAssert.that(pCol).containsInAnyOrder("1->1", "2->2", "3->3");
+ testPipeline.run();
+ } catch (InterruptedException e) {
+ throw new RuntimeException("interrupted.");
+ } finally {
+ p.destroyForcibly();
+ }
+ }
+
+ /** Test TransformProvider. */
+ @AutoService(ExpansionService.ExpansionServiceRegistrar.class)
+ public static class TestTransforms
+ implements ExpansionService.ExpansionServiceRegistrar, Serializable {
+ private final TupleTag<Integer> even = new TupleTag<Integer>("even") {};
+ private final TupleTag<Integer> odd = new TupleTag<Integer>("odd") {};
+
+ @Override
+ public Map<String, ExpansionService.TransformProvider> knownTransforms() {
+ return ImmutableMap.of(
+ TEST_URN_SIMPLE,
+ spec ->
MapElements.into(TypeDescriptors.integers()).via((Integer x) -> x + 1),
+ TEST_URN_LE,
+ spec ->
Filter.lessThanEq(Integer.parseInt(spec.getPayload().toStringUtf8())),
+ TEST_URN_MULTI,
+ spec ->
+ ParDo.of(
+ new DoFn<Integer, Integer>() {
+ @ProcessElement
+ public void processElement(ProcessContext c) {
+ if (c.element() % 2 == 0) {
+ c.output(c.element());
+ } else {
+ c.output(odd, c.element());
+ }
+ }
+ })
+ .withOutputTags(even, TupleTagList.of(odd)));
+ }
+ }
+}
diff --git a/runners/direct-java/build.gradle b/runners/direct-java/build.gradle
index b6fbbff..3b90565 100644
--- a/runners/direct-java/build.gradle
+++ b/runners/direct-java/build.gradle
@@ -117,6 +117,7 @@ task needsRunnerTests(type: Test) {
excludeCategories "org.apache.beam.sdk.testing.LargeKeys\$Above100MB"
// MetricsPusher isn't implemented in direct runner
excludeCategories "org.apache.beam.sdk.testing.UsesMetricsPusher"
+ excludeCategories "org.apache.beam.sdk.testing.UsesCrossLanguageTransforms"
}
}
@@ -140,6 +141,7 @@ task validatesRunner(type: Test) {
includeCategories "org.apache.beam.sdk.testing.ValidatesRunner"
excludeCategories "org.apache.beam.sdk.testing.LargeKeys\$Above100MB"
excludeCategories 'org.apache.beam.sdk.testing.UsesMetricsPusher'
+ excludeCategories "org.apache.beam.sdk.testing.UsesCrossLanguageTransforms"
}
}
@@ -167,6 +169,7 @@ createPortableValidatesRunnerTask(
includeCategories 'org.apache.beam.sdk.testing.ValidatesRunner'
excludeCategories
'org.apache.beam.sdk.testing.FlattenWithHeterogeneousCoders'
excludeCategories 'org.apache.beam.sdk.testing.LargeKeys$Above100MB'
+ excludeCategories
'org.apache.beam.sdk.testing.UsesCrossLanguageTransforms'
excludeCategories
'org.apache.beam.sdk.testing.UsesCustomWindowMerging'
excludeCategories
'org.apache.beam.sdk.testing.UsesDistributionMetrics'
excludeCategories 'org.apache.beam.sdk.testing.UsesFailureMessage'
diff --git a/runners/flink/job-server/flink_job_server.gradle
b/runners/flink/job-server/flink_job_server.gradle
index ebed8e4..969c957 100644
--- a/runners/flink/job-server/flink_job_server.gradle
+++ b/runners/flink/job-server/flink_job_server.gradle
@@ -132,6 +132,7 @@ def portableValidatesRunnerTask(String name, Boolean
streaming) {
excludeCategories 'org.apache.beam.sdk.testing.UsesAttemptedMetrics'
excludeCategories 'org.apache.beam.sdk.testing.UsesCommittedMetrics'
excludeCategories 'org.apache.beam.sdk.testing.UsesCounterMetrics'
+ excludeCategories
'org.apache.beam.sdk.testing.UsesCrossLanguageTransforms'
excludeCategories 'org.apache.beam.sdk.testing.UsesCustomWindowMerging'
excludeCategories 'org.apache.beam.sdk.testing.UsesDistributionMetrics'
excludeCategories 'org.apache.beam.sdk.testing.UsesFailureMessage'
@@ -155,3 +156,24 @@ task validatesPortableRunner() {
dependsOn validatesPortableRunnerBatch
dependsOn validatesPortableRunnerStreaming
}
+
+project.ext.validatesCrossLanguageTransforms =
+ createPortableValidatesRunnerTask(
+ name: "validatesCrossLanguageTransforms",
+ jobServerDriver: "org.apache.beam.runners.flink.FlinkJobServerDriver",
+ jobServerConfig:
"--clean-artifacts-per-job,--job-host=localhost,--job-port=0,--artifact-port=0,--expansion-port=0",
+ testClasspathConfiguration: configurations.validatesPortableRunner,
+ numParallelTests: 1,
+ pipelineOpts: [
+ // Limit resource consumption via parallelism
+ "--parallelism=2",
+ "--shutdownSourcesOnFinalWatermark",
+ ],
+ testCategories: {
+ // Only include cross-language transform tests. Avoid to retest
everything on Docker environment.
+ includeCategories
'org.apache.beam.sdk.testing.UsesCrossLanguageTransforms'
+ },
+ )
+project.evaluationDependsOn ':beam-sdks-python'
+validatesCrossLanguageTransforms.dependsOn ':beam-sdks-python:setupVirtualenv'
+validatesCrossLanguageTransforms.systemProperty "pythonTestExpansionCommand",
". ${project(':beam-sdks-python').envdir}/bin/activate && pip install -e
${project(':beam-sdks-python').projectDir}[test] && python -m
apache_beam.runners.portability.expansion_service_test"
diff --git a/runners/google-cloud-dataflow-java/build.gradle
b/runners/google-cloud-dataflow-java/build.gradle
index 8280011..86d6f06 100644
--- a/runners/google-cloud-dataflow-java/build.gradle
+++ b/runners/google-cloud-dataflow-java/build.gradle
@@ -131,6 +131,7 @@ def fnApiPipelineOptions = [
def commonExcludeCategories = [
'org.apache.beam.sdk.testing.LargeKeys$Above10MB',
'org.apache.beam.sdk.testing.UsesAttemptedMetrics',
+ 'org.apache.beam.sdk.testing.UsesCrossLanguageTransforms',
'org.apache.beam.sdk.testing.UsesDistributionMetrics',
'org.apache.beam.sdk.testing.UsesGaugeMetrics',
'org.apache.beam.sdk.testing.UsesSetState',
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesCrossLanguageTransforms.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesCrossLanguageTransforms.java
new file mode 100644
index 0000000..b249eed
--- /dev/null
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesCrossLanguageTransforms.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.testing;
+
+/**
+ * Category tag for validation tests which use cross-language transforms.
Tests tagged with {@link
+ * UsesCrossLanguageTransforms} should be run for runners which support
cross-language transforms.
+ */
+public interface UsesCrossLanguageTransforms {}
diff --git a/sdks/python/apache_beam/runners/portability/expansion_service.py
b/sdks/python/apache_beam/runners/portability/expansion_service.py
index e407892..55a526d 100644
--- a/sdks/python/apache_beam/runners/portability/expansion_service.py
+++ b/sdks/python/apache_beam/runners/portability/expansion_service.py
@@ -20,10 +20,6 @@
from __future__ import absolute_import
from __future__ import print_function
-import argparse
-import logging
-import sys
-import time
import traceback
from apache_beam import pipeline as beam_pipeline
@@ -43,7 +39,7 @@ class ExpansionServiceServicer(
self._options = options or beam_pipeline.PipelineOptions(
environment_type=python_urns.EMBEDDED_PYTHON)
- def Expand(self, request):
+ def Expand(self, request, context):
try:
pipeline = beam_pipeline.Pipeline(options=self._options)
@@ -98,21 +94,3 @@ class ExpansionServiceServicer(
except Exception: # pylint: disable=broad-except
return beam_expansion_api_pb2.ExpansionResponse(
error=traceback.format_exc())
-
-
-def main(unused_argv):
- parser = argparse.ArgumentParser()
- parser.add_argument('-p', '--port',
- type=int,
- help='port on which to serve the job api')
- options = parser.parse_args()
- expansion_servicer = ExpansionServiceServicer()
- port = expansion_servicer.start_grpc_server(options.port)
- while True:
- logging.info('Listening for expansion requests at %d', port)
- time.sleep(300)
-
-
-if __name__ == '__main__':
- logging.getLogger().setLevel(logging.INFO)
- main(sys.argv)
diff --git
a/sdks/python/apache_beam/runners/portability/expansion_service_test.py
b/sdks/python/apache_beam/runners/portability/expansion_service_test.py
new file mode 100644
index 0000000..9e93a64
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/expansion_service_test.py
@@ -0,0 +1,166 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import absolute_import
+
+import argparse
+import concurrent.futures as futures
+import logging
+import signal
+import sys
+
+import grpc
+
+import apache_beam as beam
+import apache_beam.transforms.combiners as combine
+from apache_beam.pipeline import PipelineOptions
+from apache_beam.portability.api import beam_expansion_api_pb2_grpc
+from apache_beam.runners.portability import expansion_service
+from apache_beam.transforms import ptransform
+
+# This script provides an expansion service and example ptransforms for running
+# external transform test cases. See external_test.py for details.
+
+
[email protected]_urn('count_per_element_bytes', None)
+class KV2BytesTransform(ptransform.PTransform):
+ def expand(self, pcoll):
+ return (
+ pcoll
+ | combine.Count.PerElement()
+ | beam.Map(
+ lambda x: '{}->{}'.format(x[0], x[1])).with_output_types(bytes)
+ )
+
+ def to_runner_api_parameter(self, unused_context):
+ return 'kv_to_bytes', None
+
+ @staticmethod
+ def from_runner_api_parameter(unused_parameter, unused_context):
+ return KV2BytesTransform()
+
+
[email protected]_urn('simple', None)
+class SimpleTransform(ptransform.PTransform):
+ def expand(self, pcoll):
+ return pcoll | 'TestLabel' >> beam.Map(lambda x: 'Simple(%s)' % x)
+
+ def to_runner_api_parameter(self, unused_context):
+ return 'simple', None
+
+ @staticmethod
+ def from_runner_api_parameter(unused_parameter, unused_context):
+ return SimpleTransform()
+
+
[email protected]_urn('multi', None)
+class MutltiTransform(ptransform.PTransform):
+ def expand(self, pcolls):
+ return {
+ 'main':
+ (pcolls['main1'], pcolls['main2'])
+ | beam.Flatten()
+ | beam.Map(lambda x, s: x + s,
+ beam.pvalue.AsSingleton(pcolls['side'])),
+ 'side': pcolls['side'] | beam.Map(lambda x: x + x),
+ }
+
+ def to_runner_api_parameter(self, unused_context):
+ return 'multi', None
+
+ @staticmethod
+ def from_runner_api_parameter(unused_parameter, unused_context):
+ return MutltiTransform()
+
+
[email protected]_urn('payload', bytes)
+class PayloadTransform(ptransform.PTransform):
+ def __init__(self, payload):
+ self._payload = payload
+
+ def expand(self, pcoll):
+ return pcoll | beam.Map(lambda x, s: x + s, self._payload)
+
+ def to_runner_api_parameter(self, unused_context):
+ return b'payload', self._payload.encode('ascii')
+
+ @staticmethod
+ def from_runner_api_parameter(payload, unused_context):
+ return PayloadTransform(payload.decode('ascii'))
+
+
[email protected]_urn('fib', bytes)
+class FibTransform(ptransform.PTransform):
+ def __init__(self, level):
+ self._level = level
+
+ def expand(self, p):
+ if self._level <= 2:
+ return p | beam.Create([1])
+ else:
+ a = p | 'A' >> beam.ExternalTransform(
+ 'fib', str(self._level - 1).encode('ascii'),
+ expansion_service.ExpansionServiceServicer())
+ b = p | 'B' >> beam.ExternalTransform(
+ 'fib', str(self._level - 2).encode('ascii'),
+ expansion_service.ExpansionServiceServicer())
+ return (
+ (a, b)
+ | beam.Flatten()
+ | beam.CombineGlobally(sum).without_defaults())
+
+ def to_runner_api_parameter(self, unused_context):
+ return 'fib', str(self._level).encode('ascii')
+
+ @staticmethod
+ def from_runner_api_parameter(level, unused_context):
+ return FibTransform(int(level.decode('ascii')))
+
+
+server = None
+
+
+def main(unused_argv):
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-p', '--port',
+ type=int,
+ help='port on which to serve the job api')
+ options = parser.parse_args()
+ global server
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
+ beam_expansion_api_pb2_grpc.add_ExpansionServiceServicer_to_server(
+ expansion_service.ExpansionServiceServicer(PipelineOptions()), server
+ )
+ server.add_insecure_port('localhost:{}'.format(options.port))
+ server.start()
+ logging.info('Listening for expansion requests at %d', options.port)
+
+ # blocking main thread forever.
+ signal.pause()
+
+
+def cleanup(unused_signum, unused_frame):
+ logging.info('Shutting down expansion service.')
+ server.stop(None)
+
+
+signal.signal(signal.SIGTERM, cleanup)
+signal.signal(signal.SIGINT, cleanup)
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ main(sys.argv)
diff --git a/sdks/python/apache_beam/transforms/external.py
b/sdks/python/apache_beam/transforms/external.py
index cd670b6..9077ebe 100644
--- a/sdks/python/apache_beam/transforms/external.py
+++ b/sdks/python/apache_beam/transforms/external.py
@@ -45,7 +45,6 @@ class ExternalTransform(ptransform.PTransform):
_namespace_counter = 0
_namespace = threading.local()
- _namespace.value = 'external'
_EXPANDED_TRANSFORM_UNIQUE_NAME = 'root'
_IMPULSE_PREFIX = 'impulse'
@@ -63,9 +62,13 @@ class ExternalTransform(ptransform.PTransform):
return '%s(%s)' % (self.__class__.__name__, self._urn)
@classmethod
+ def get_local_namespace(cls):
+ return getattr(cls._namespace, 'value', 'external')
+
+ @classmethod
@contextlib.contextmanager
def outer_namespace(cls, namespace):
- prev = cls._namespace.value
+ prev = cls.get_local_namespace()
cls._namespace.value = namespace
yield
cls._namespace.value = prev
@@ -73,7 +76,7 @@ class ExternalTransform(ptransform.PTransform):
@classmethod
def _fresh_namespace(cls):
ExternalTransform._namespace_counter += 1
- return '%s_%d' % (cls._namespace.value, cls._namespace_counter)
+ return '%s_%d' % (cls.get_local_namespace(), cls._namespace_counter)
def expand(self, pvalueish):
if isinstance(pvalueish, pvalue.PBegin):
@@ -115,7 +118,7 @@ class ExternalTransform(ptransform.PTransform):
response = beam_expansion_api_pb2_grpc.ExpansionServiceStub(
channel).Expand(request)
else:
- response = self._endpoint.Expand(request)
+ response = self._endpoint.Expand(request, None)
if response.error:
raise RuntimeError(response.error)
diff --git a/sdks/python/apache_beam/transforms/external_test.py
b/sdks/python/apache_beam/transforms/external_test.py
index c3448c6..481d673 100644
--- a/sdks/python/apache_beam/transforms/external_test.py
+++ b/sdks/python/apache_beam/transforms/external_test.py
@@ -32,9 +32,9 @@ from apache_beam import Pipeline
from apache_beam.io.external.generate_sequence import GenerateSequence
from apache_beam.portability import python_urns
from apache_beam.runners.portability import expansion_service
+from apache_beam.runners.portability.expansion_service_test import FibTransform
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
-from apache_beam.transforms import ptransform
class ExternalTransformTest(unittest.TestCase):
@@ -43,19 +43,6 @@ class ExternalTransformTest(unittest.TestCase):
expansion_service_jar = None
def test_pipeline_generation(self):
-
- @ptransform.PTransform.register_urn('simple', None)
- class SimpleTransform(ptransform.PTransform):
- def expand(self, pcoll):
- return pcoll | 'TestLabel' >> beam.Map(lambda x: 'Simple(%s)' % x)
-
- def to_runner_api_parameter(self, unused_context):
- return 'simple', None
-
- @staticmethod
- def from_runner_api_parameter(unused_parameter, unused_context):
- return SimpleTransform()
-
pipeline = beam.Pipeline()
res = (pipeline
| beam.Create(['a', 'b'])
@@ -81,19 +68,6 @@ class ExternalTransformTest(unittest.TestCase):
pipeline_from_proto.transforms_stack[0].parts[1].parts[0].full_label)
def test_simple(self):
-
- @ptransform.PTransform.register_urn('simple', None)
- class SimpleTransform(ptransform.PTransform):
- def expand(self, pcoll):
- return pcoll | beam.Map(lambda x: 'Simple(%s)' % x)
-
- def to_runner_api_parameter(self, unused_context):
- return 'simple', None
-
- @staticmethod
- def from_runner_api_parameter(unused_parameter, unused_context):
- return SimpleTransform()
-
with beam.Pipeline() as p:
res = (
p
@@ -105,26 +79,6 @@ class ExternalTransformTest(unittest.TestCase):
assert_that(res, equal_to(['Simple(a)', 'Simple(b)']))
def test_multi(self):
-
- @ptransform.PTransform.register_urn('multi', None)
- class MutltiTransform(ptransform.PTransform):
- def expand(self, pcolls):
- return {
- 'main':
- (pcolls['main1'], pcolls['main2'])
- | beam.Flatten()
- | beam.Map(lambda x, s: x + s,
- beam.pvalue.AsSingleton(pcolls['side'])),
- 'side': pcolls['side'] | beam.Map(lambda x: x + x),
- }
-
- def to_runner_api_parameter(self, unused_context):
- return 'multi', None
-
- @staticmethod
- def from_runner_api_parameter(unused_parameter, unused_context):
- return MutltiTransform()
-
with beam.Pipeline() as p:
main1 = p | 'Main1' >> beam.Create(['a', 'bb'], reshuffle=False)
main2 = p | 'Main2' >> beam.Create(['x', 'yy', 'zzz'], reshuffle=False)
@@ -135,22 +89,6 @@ class ExternalTransformTest(unittest.TestCase):
assert_that(res['side'], equal_to(['ss']), label='CheckSide')
def test_payload(self):
-
- @ptransform.PTransform.register_urn('payload', bytes)
- class PayloadTransform(ptransform.PTransform):
- def __init__(self, payload):
- self._payload = payload
-
- def expand(self, pcoll):
- return pcoll | beam.Map(lambda x, s: x + s, self._payload)
-
- def to_runner_api_parameter(self, unused_context):
- return b'payload', self._payload.encode('ascii')
-
- @staticmethod
- def from_runner_api_parameter(payload, unused_context):
- return PayloadTransform(payload.decode('ascii'))
-
with beam.Pipeline() as p:
res = (
p
@@ -161,33 +99,6 @@ class ExternalTransformTest(unittest.TestCase):
assert_that(res, equal_to(['as', 'bbs']))
def test_nested(self):
- @ptransform.PTransform.register_urn('fib', bytes)
- class FibTransform(ptransform.PTransform):
- def __init__(self, level):
- self._level = level
-
- def expand(self, p):
- if self._level <= 2:
- return p | beam.Create([1])
- else:
- a = p | 'A' >> beam.ExternalTransform(
- 'fib', str(self._level - 1).encode('ascii'),
- expansion_service.ExpansionServiceServicer())
- b = p | 'B' >> beam.ExternalTransform(
- 'fib', str(self._level - 2).encode('ascii'),
- expansion_service.ExpansionServiceServicer())
- return (
- (a, b)
- | beam.Flatten()
- | beam.CombineGlobally(sum).without_defaults())
-
- def to_runner_api_parameter(self, unused_context):
- return 'fib', str(self._level).encode('ascii')
-
- @staticmethod
- def from_runner_api_parameter(level, unused_context):
- return FibTransform(int(level.decode('ascii')))
-
with beam.Pipeline() as p:
assert_that(p | FibTransform(6), equal_to([8]))