This is an automated email from the ASF dual-hosted git repository. mxm pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 60131b8 [BEAM-6747] Adding ExternalTransform in JavaSDK new 4c32210 Merge pull request #7954 from ihji/BEAM-6747 60131b8 is described below commit 60131b84d38f1c223b3e1715c8964fb5cfdea54f Author: Heejong Lee <heej...@gmail.com> AuthorDate: Wed Feb 20 18:22:35 2019 -0800 [BEAM-6747] Adding ExternalTransform in JavaSDK --- .../org/apache/beam/gradle/BeamModulePlugin.groovy | 7 +- .../DefaultExpansionServiceClientFactory.java | 67 ++++++ .../core/construction/ExpansionServiceClient.java | 25 +++ .../ExpansionServiceClientFactory.java | 28 +++ .../beam/runners/core/construction/External.java | 250 +++++++++++++++++++++ .../core/construction/ExternalTranslation.java | 158 +++++++++++++ .../core/construction/PTransformTranslation.java | 2 + .../runners/core/construction/SdkComponents.java | 30 ++- .../construction/expansion/ExpansionService.java | 8 +- .../runners/core/construction/ExternalTest.java | 197 ++++++++++++++++ runners/direct-java/build.gradle | 3 + runners/flink/job-server/flink_job_server.gradle | 22 ++ runners/google-cloud-dataflow-java/build.gradle | 1 + .../sdk/testing/UsesCrossLanguageTransforms.java | 24 ++ .../runners/portability/expansion_service.py | 24 +- .../runners/portability/expansion_service_test.py | 166 ++++++++++++++ sdks/python/apache_beam/transforms/external.py | 11 +- .../python/apache_beam/transforms/external_test.py | 91 +------- 18 files changed, 981 insertions(+), 133 deletions(-) diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index 918df02..4852e5c 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -21,6 +21,7 @@ package org.apache.beam.gradle import com.github.jengelman.gradle.plugins.shadow.tasks.ShadowJar import groovy.json.JsonOutput import groovy.json.JsonSlurper +import java.util.concurrent.atomic.AtomicInteger import org.gradle.api.GradleException import org.gradle.api.Plugin import org.gradle.api.Project @@ -71,6 +72,7 @@ class BeamModulePlugin implements Plugin<Project> { * limitations under the License. */ """ + static AtomicInteger startingExpansionPortNumber = new AtomicInteger(18091) /** A class defining the set of configurable properties accepted by applyJavaNature. */ class JavaNatureConfiguration { @@ -1545,6 +1547,7 @@ class BeamModulePlugin implements Plugin<Project> { */ project.evaluationDependsOn(":beam-sdks-java-core") project.evaluationDependsOn(":beam-runners-core-java") + project.evaluationDependsOn(":beam-runners-core-construction-java") def config = it ? it as PortableValidatesRunnerConfiguration : new PortableValidatesRunnerConfiguration() def name = config.name def beamTestPipelineOptions = [ @@ -1552,6 +1555,7 @@ class BeamModulePlugin implements Plugin<Project> { "--jobServerDriver=${config.jobServerDriver}", "--environmentCacheMillis=10000" ] + def expansionPort = startingExpansionPortNumber.getAndDecrement() beamTestPipelineOptions.addAll(config.pipelineOpts) if (config.environment == PortableValidatesRunnerConfiguration.Environment.EMBEDDED) { beamTestPipelineOptions += "--defaultEnvironmentType=EMBEDDED" @@ -1563,8 +1567,9 @@ class BeamModulePlugin implements Plugin<Project> { group = "Verification" description = "Validates the PortableRunner with JobServer ${config.jobServerDriver}" systemProperty "beamTestPipelineOptions", JsonOutput.toJson(beamTestPipelineOptions) + systemProperty "expansionPort", expansionPort classpath = config.testClasspathConfiguration - testClassesDirs = project.files(project.project(":beam-sdks-java-core").sourceSets.test.output.classesDirs, project.project(":beam-runners-core-java").sourceSets.test.output.classesDirs) + testClassesDirs = project.files(project.project(":beam-sdks-java-core").sourceSets.test.output.classesDirs, project.project(":beam-runners-core-java").sourceSets.test.output.classesDirs, project.project(":beam-runners-core-construction-java").sourceSets.test.output.classesDirs) maxParallelForks config.numParallelTests useJUnit(config.testCategories) // increase maxHeapSize as this is directly correlated to direct memory, diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/DefaultExpansionServiceClientFactory.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/DefaultExpansionServiceClientFactory.java new file mode 100644 index 0000000..a1425cb --- /dev/null +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/DefaultExpansionServiceClientFactory.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.core.construction; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; +import org.apache.beam.model.expansion.v1.ExpansionApi; +import org.apache.beam.model.expansion.v1.ExpansionServiceGrpc; +import org.apache.beam.model.pipeline.v1.Endpoints; +import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ManagedChannel; + +/** Default factory for ExpansionServiceClient used by External transform. */ +public class DefaultExpansionServiceClientFactory implements ExpansionServiceClientFactory { + private Map<Endpoints.ApiServiceDescriptor, ExpansionServiceClient> expansionServiceMap; + private Function<Endpoints.ApiServiceDescriptor, ManagedChannel> channelFactory; + + DefaultExpansionServiceClientFactory( + Function<Endpoints.ApiServiceDescriptor, ManagedChannel> channelFactory) { + this.expansionServiceMap = new ConcurrentHashMap<>(); + this.channelFactory = channelFactory; + } + + @Override + public void close() throws Exception { + for (ExpansionServiceClient client : expansionServiceMap.values()) { + try (AutoCloseable closer = client) {} + } + } + + @Override + public ExpansionServiceClient getExpansionServiceClient(Endpoints.ApiServiceDescriptor endpoint) { + return expansionServiceMap.computeIfAbsent( + endpoint, + e -> + new ExpansionServiceClient() { + private final ManagedChannel channel = channelFactory.apply(endpoint); + private final ExpansionServiceGrpc.ExpansionServiceBlockingStub service = + ExpansionServiceGrpc.newBlockingStub(channel); + + @Override + public ExpansionApi.ExpansionResponse expand(ExpansionApi.ExpansionRequest request) { + return service.expand(request); + } + + @Override + public void close() throws Exception { + channel.shutdown(); + } + }); + } +} diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExpansionServiceClient.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExpansionServiceClient.java new file mode 100644 index 0000000..e678766 --- /dev/null +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExpansionServiceClient.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.core.construction; + +import org.apache.beam.model.expansion.v1.ExpansionApi; + +/** A high-level client for a cross-language expansion service. */ +interface ExpansionServiceClient extends AutoCloseable { + ExpansionApi.ExpansionResponse expand(ExpansionApi.ExpansionRequest request); +} diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExpansionServiceClientFactory.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExpansionServiceClientFactory.java new file mode 100644 index 0000000..a43cb26 --- /dev/null +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExpansionServiceClientFactory.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.core.construction; + +import org.apache.beam.model.pipeline.v1.Endpoints; + +/** + * A factory for generating {@link ExpansionServiceClient} from {@link + * org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor}. + */ +interface ExpansionServiceClientFactory extends AutoCloseable { + ExpansionServiceClient getExpansionServiceClient(Endpoints.ApiServiceDescriptor endpoint); +} diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/External.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/External.java new file mode 100644 index 0000000..6fd5c00 --- /dev/null +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/External.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.core.construction; + +import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nullable; +import org.apache.beam.model.expansion.v1.ExpansionApi; +import org.apache.beam.model.pipeline.v1.Endpoints; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.transforms.Impulse; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.vendor.grpc.v1p13p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ManagedChannelBuilder; +import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables; + +/** + * Cross-language external transform. + * + * <p>{@link External} provides a cross-language transform via expansion services in foreign SDKs. + * In order to use {@link External} transform, a user should know 1) URN of the target transform 2) + * bytes encoding schema for configuration parameters 3) connection endpoint of the expansion + * service. Note that this is a low-level API and mainly for internal use. A user may want to use + * high-level wrapper classes rather than this one. + */ +public class External { + private static final String EXPANDED_TRANSFORM_BASE_NAME = "external"; + private static final String IMPULSE_PREFIX = "IMPULSE"; + private static AtomicInteger namespaceCounter = new AtomicInteger(0); + + private static final ExpansionServiceClientFactory DEFAULT = + new DefaultExpansionServiceClientFactory( + endPoint -> ManagedChannelBuilder.forTarget(endPoint.getUrl()).usePlaintext().build()); + + private static int getFreshNamespaceIndex() { + return namespaceCounter.getAndIncrement(); + } + + public static <OutputT> SingleOutputExpandableTransform<OutputT> of( + String urn, byte[] payload, String endpoint) { + Endpoints.ApiServiceDescriptor apiDesc = + Endpoints.ApiServiceDescriptor.newBuilder().setUrl(endpoint).build(); + return new SingleOutputExpandableTransform<>(urn, payload, apiDesc, getFreshNamespaceIndex()); + } + + /** Expandable transform for output type of PCollection. */ + public static class SingleOutputExpandableTransform<OutputT> + extends ExpandableTransform<PCollection<OutputT>> { + SingleOutputExpandableTransform( + String urn, + byte[] payload, + Endpoints.ApiServiceDescriptor endpoint, + Integer namespaceIndex) { + super(urn, payload, endpoint, namespaceIndex); + } + + @Override + PCollection<OutputT> toOutputCollection(Map<TupleTag<?>, PCollection> output) { + checkArgument(output.size() > 0, "output shouldn't be empty."); + return Iterables.getOnlyElement(output.values()); + } + + public MultiOutputExpandableTransform withMultiOutputs() { + return new MultiOutputExpandableTransform( + getUrn(), getPayload(), getEndpoint(), getNamespaceIndex()); + } + } + + /** Expandable transform for output type of PCollectionTuple. */ + public static class MultiOutputExpandableTransform extends ExpandableTransform<PCollectionTuple> { + MultiOutputExpandableTransform( + String urn, + byte[] payload, + Endpoints.ApiServiceDescriptor endpoint, + Integer namespaceIndex) { + super(urn, payload, endpoint, namespaceIndex); + } + + @Override + PCollectionTuple toOutputCollection(Map<TupleTag<?>, PCollection> output) { + checkArgument(output.size() > 0, "output shouldn't be empty."); + PCollection firstElem = Iterables.getFirst(output.values(), null); + PCollectionTuple pCollectionTuple = PCollectionTuple.empty(firstElem.getPipeline()); + for (Map.Entry<TupleTag<?>, PCollection> entry : output.entrySet()) { + pCollectionTuple = pCollectionTuple.and(entry.getKey(), entry.getValue()); + } + return pCollectionTuple; + } + } + + /** Base Expandable Transform which calls ExpansionService to expand itself. */ + public abstract static class ExpandableTransform<OutputT extends POutput> + extends PTransform<PInput, OutputT> { + private final String urn; + private final byte[] payload; + private final Endpoints.ApiServiceDescriptor endpoint; + private final Integer namespaceIndex; + + @Nullable private transient RunnerApi.Components expandedComponents; + @Nullable private transient RunnerApi.PTransform expandedTransform; + @Nullable private transient Map<PCollection, String> externalPCollectionIdMap; + + ExpandableTransform( + String urn, + byte[] payload, + Endpoints.ApiServiceDescriptor endpoint, + Integer namespaceIndex) { + this.urn = urn; + this.payload = payload; + this.endpoint = endpoint; + this.namespaceIndex = namespaceIndex; + } + + @Override + public OutputT expand(PInput input) { + Pipeline p = input.getPipeline(); + SdkComponents components = SdkComponents.create(p.getOptions()); + RunnerApi.PTransform.Builder ptransformBuilder = + RunnerApi.PTransform.newBuilder() + .setUniqueName(EXPANDED_TRANSFORM_BASE_NAME + namespaceIndex) + .setSpec( + RunnerApi.FunctionSpec.newBuilder() + .setUrn(urn) + .setPayload(ByteString.copyFrom(payload)) + .build()); + ImmutableMap.Builder<PCollection, String> externalPCollectionIdMapBuilder = + ImmutableMap.builder(); + for (Map.Entry<TupleTag<?>, PValue> entry : input.expand().entrySet()) { + if (entry.getValue() instanceof PCollection<?>) { + try { + String id = components.registerPCollection((PCollection) entry.getValue()); + externalPCollectionIdMapBuilder.put((PCollection) entry.getValue(), id); + ptransformBuilder.putInputs(entry.getKey().getId(), id); + AppliedPTransform<?, ?, ?> fakeImpulse = + AppliedPTransform.of( + String.format("%s_%s", IMPULSE_PREFIX, entry.getKey().getId()), + PBegin.in(p).expand(), + ImmutableMap.of(entry.getKey(), entry.getValue()), + Impulse.create(), + p); + // using fake Impulses to provide inputs + components.registerPTransform(fakeImpulse, Collections.emptyList()); + } catch (IOException e) { + throw new RuntimeException( + String.format("cannot register component: %s", e.getMessage())); + } + } + } + + ExpansionApi.ExpansionRequest request = + ExpansionApi.ExpansionRequest.newBuilder() + .setComponents(components.toComponents()) + .setTransform(ptransformBuilder.build()) + .setNamespace(getNamespace()) + .build(); + + ExpansionApi.ExpansionResponse response = + DEFAULT.getExpansionServiceClient(endpoint).expand(request); + + expandedComponents = response.getComponents(); + expandedTransform = response.getTransform(); + + RehydratedComponents rehydratedComponents = + RehydratedComponents.forComponents(expandedComponents).withPipeline(p); + ImmutableMap.Builder<TupleTag<?>, PCollection> outputMapBuilder = ImmutableMap.builder(); + expandedTransform + .getOutputsMap() + .forEach( + (localId, pCollectionId) -> { + try { + PCollection col = rehydratedComponents.getPCollection(pCollectionId); + externalPCollectionIdMapBuilder.put(col, pCollectionId); + outputMapBuilder.put(new TupleTag<>(localId), col); + } catch (IOException e) { + throw new RuntimeException("cannot rehydrate PCollection."); + } + }); + externalPCollectionIdMap = externalPCollectionIdMapBuilder.build(); + + return toOutputCollection(outputMapBuilder.build()); + } + + abstract OutputT toOutputCollection(Map<TupleTag<?>, PCollection> output); + + String getNamespace() { + return String.format("External_%s", namespaceIndex); + } + + String getImpulsePrefix() { + return IMPULSE_PREFIX; + } + + RunnerApi.PTransform getExpandedTransform() { + return expandedTransform; + } + + RunnerApi.Components getExpandedComponents() { + return expandedComponents; + } + + Map<PCollection, String> getExternalPCollectionIdMap() { + return externalPCollectionIdMap; + } + + String getUrn() { + return urn; + } + + byte[] getPayload() { + return payload; + } + + Endpoints.ApiServiceDescriptor getEndpoint() { + return endpoint; + } + + Integer getNamespaceIndex() { + return namespaceIndex; + } + } +} diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExternalTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExternalTranslation.java new file mode 100644 index 0000000..9e1bafa --- /dev/null +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExternalTranslation.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.core.construction; + +import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkState; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap; + +/** Translating External transforms to proto. */ +public class ExternalTranslation { + public static final String EXTERNAL_TRANSFORM_URN = "urn:beam:transform:external:v1"; + + /** Translator for ExpandableTransform. */ + public static class ExternalTranslator + implements PTransformTranslation.TransformTranslator<External.ExpandableTransform<?>> { + public static PTransformTranslation.TransformTranslator create() { + return new ExternalTranslator(); + } + + @Nullable + @Override + public String getUrn(External.ExpandableTransform transform) { + return EXTERNAL_TRANSFORM_URN; + } + + @Override + public boolean canTranslate(PTransform<?, ?> pTransform) { + return pTransform instanceof External.ExpandableTransform; + } + + @Override + public RunnerApi.PTransform translate( + AppliedPTransform<?, ?, ?> appliedPTransform, + List<AppliedPTransform<?, ?, ?>> subtransforms, + SdkComponents components) + throws IOException { + checkArgument( + canTranslate(appliedPTransform.getTransform()), "can only translate ExpandableTransform"); + + External.ExpandableTransform expandableTransform = + (External.ExpandableTransform) appliedPTransform.getTransform(); + String nameSpace = expandableTransform.getNamespace(); + String impulsePrefix = expandableTransform.getImpulsePrefix(); + ImmutableMap.Builder<String, String> pColRenameMapBuilder = ImmutableMap.builder(); + RunnerApi.PTransform expandedTransform = expandableTransform.getExpandedTransform(); + RunnerApi.Components expandedComponents = expandableTransform.getExpandedComponents(); + Map<PCollection, String> externalPCollectionIdMap = + expandableTransform.getExternalPCollectionIdMap(); + + for (PValue pcol : appliedPTransform.getInputs().values()) { + if (!(pcol instanceof PCollection)) { + throw new RuntimeException("unknown input type."); + } + pColRenameMapBuilder.put( + externalPCollectionIdMap.get(pcol), components.registerPCollection((PCollection) pcol)); + } + for (PValue pcol : appliedPTransform.getOutputs().values()) { + if (!(pcol instanceof PCollection)) { + throw new RuntimeException("unknown input type."); + } + pColRenameMapBuilder.put( + externalPCollectionIdMap.get(pcol), components.registerPCollection((PCollection) pcol)); + } + + ImmutableMap<String, String> pColRenameMap = pColRenameMapBuilder.build(); + RunnerApi.Components.Builder mergingComponentsBuilder = RunnerApi.Components.newBuilder(); + for (Map.Entry<String, RunnerApi.Coder> entry : + expandedComponents.getCodersMap().entrySet()) { + if (entry.getKey().startsWith(nameSpace)) { + mergingComponentsBuilder.putCoders(entry.getKey(), entry.getValue()); + } + } + for (Map.Entry<String, RunnerApi.WindowingStrategy> entry : + expandedComponents.getWindowingStrategiesMap().entrySet()) { + if (entry.getKey().startsWith(nameSpace)) { + mergingComponentsBuilder.putWindowingStrategies(entry.getKey(), entry.getValue()); + } + } + for (Map.Entry<String, RunnerApi.Environment> entry : + expandedComponents.getEnvironmentsMap().entrySet()) { + if (entry.getKey().startsWith(nameSpace)) { + mergingComponentsBuilder.putEnvironments(entry.getKey(), entry.getValue()); + } + } + for (Map.Entry<String, RunnerApi.PCollection> entry : + expandedComponents.getPcollectionsMap().entrySet()) { + if (entry.getKey().startsWith(nameSpace)) { + mergingComponentsBuilder.putPcollections(entry.getKey(), entry.getValue()); + } + } + for (Map.Entry<String, RunnerApi.PTransform> entry : + expandedComponents.getTransformsMap().entrySet()) { + // ignore dummy Impulses we added for fake inputs + if (entry.getKey().startsWith(impulsePrefix)) { + continue; + } + checkState(entry.getKey().startsWith(nameSpace), "unknown transform found"); + RunnerApi.PTransform proto = entry.getValue(); + RunnerApi.PTransform.Builder transformBuilder = RunnerApi.PTransform.newBuilder(); + transformBuilder + .setUniqueName(proto.getUniqueName()) + .setSpec(proto.getSpec()) + .addAllSubtransforms(proto.getSubtransformsList()); + for (Map.Entry<String, String> inputEntry : proto.getInputsMap().entrySet()) { + transformBuilder.putInputs( + inputEntry.getKey(), + pColRenameMap.getOrDefault(inputEntry.getValue(), inputEntry.getValue())); + } + for (Map.Entry<String, String> outputEntry : proto.getOutputsMap().entrySet()) { + transformBuilder.putOutputs( + outputEntry.getKey(), + pColRenameMap.getOrDefault(outputEntry.getValue(), outputEntry.getValue())); + } + mergingComponentsBuilder.putTransforms(entry.getKey(), transformBuilder.build()); + } + + RunnerApi.PTransform.Builder rootTransformBuilder = RunnerApi.PTransform.newBuilder(); + rootTransformBuilder + .setUniqueName(expandedTransform.getUniqueName()) + .setSpec(expandedTransform.getSpec()) + .addAllSubtransforms(expandedTransform.getSubtransformsList()) + .putAllInputs(expandedTransform.getInputsMap()); + for (Map.Entry<String, String> outputEntry : expandedTransform.getOutputsMap().entrySet()) { + rootTransformBuilder.putOutputs( + outputEntry.getKey(), + pColRenameMap.getOrDefault(outputEntry.getValue(), outputEntry.getValue())); + } + components.mergeFrom(mergingComponentsBuilder.build()); + + return rootTransformBuilder.build(); + } + } +} diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java index a072ab3..b880283 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java @@ -35,6 +35,7 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec; import org.apache.beam.model.pipeline.v1.RunnerApi.StandardPTransforms; import org.apache.beam.model.pipeline.v1.RunnerApi.StandardPTransforms.CombineComponents; import org.apache.beam.model.pipeline.v1.RunnerApi.StandardPTransforms.SplittableParDoComponents; +import org.apache.beam.runners.core.construction.ExternalTranslation.ExternalTranslator; import org.apache.beam.runners.core.construction.ParDoTranslation.ParDoTranslator; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.runners.AppliedPTransform; @@ -125,6 +126,7 @@ public class PTransformTranslation { .add(new RawPTransformTranslator()) .add(new KnownTransformPayloadTranslator()) .add(ParDoTranslator.create()) + .add(ExternalTranslator.create()) .build(); } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java index 2a1b335..8362bd6 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java @@ -54,9 +54,11 @@ public class SdkComponents { private final Set<String> reservedIds = new HashSet<>(); + private String defaultEnvironmentId; + /** Create a new {@link SdkComponents} with no components. */ public static SdkComponents create() { - return new SdkComponents(""); + return new SdkComponents(RunnerApi.Components.getDefaultInstance(), ""); } /** @@ -85,19 +87,16 @@ public class SdkComponents { } public static SdkComponents create(PipelineOptions options) { - SdkComponents sdkComponents = new SdkComponents(""); + SdkComponents sdkComponents = new SdkComponents(RunnerApi.Components.getDefaultInstance(), ""); PortablePipelineOptions portablePipelineOptions = options.as(PortablePipelineOptions.class); - sdkComponents.registerEnvironment( - Environments.createOrGetDefaultEnvironment( - portablePipelineOptions.getDefaultEnvironmentType(), - portablePipelineOptions.getDefaultEnvironmentConfig())); + sdkComponents.defaultEnvironmentId = + sdkComponents.registerEnvironment( + Environments.createOrGetDefaultEnvironment( + portablePipelineOptions.getDefaultEnvironmentType(), + portablePipelineOptions.getDefaultEnvironmentConfig())); return sdkComponents; } - private SdkComponents(String newIdPrefix) { - this.newIdPrefix = newIdPrefix; - } - private SdkComponents(RunnerApi.Components components, String newIdPrefix) { this.newIdPrefix = newIdPrefix; @@ -105,6 +104,11 @@ public class SdkComponents { return; } + mergeFrom(components); + } + + /** Merge Components proto into this SdkComponents instance. */ + public void mergeFrom(RunnerApi.Components components) { reservedIds.addAll(components.getTransformsMap().keySet()); reservedIds.addAll(components.getPcollectionsMap().keySet()); reservedIds.addAll(components.getWindowingStrategiesMap().keySet()); @@ -270,7 +274,11 @@ public class SdkComponents { public String getOnlyEnvironmentId() { // TODO Support multiple environments. The environment should be decided by the translation. - return Iterables.getOnlyElement(componentsBuilder.getEnvironmentsMap().keySet()); + if (defaultEnvironmentId != null) { + return defaultEnvironmentId; + } else { + return Iterables.getOnlyElement(componentsBuilder.getEnvironmentsMap().keySet()); + } } private String uniqify(String baseName, Set<String> existing) { diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/expansion/ExpansionService.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/expansion/ExpansionService.java index 75c7aba..2380808 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/expansion/ExpansionService.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/expansion/ExpansionService.java @@ -49,7 +49,7 @@ import org.apache.beam.sdk.values.PCollectionList; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PDone; import org.apache.beam.sdk.values.PInput; -import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.Server; import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ServerBuilder; @@ -242,10 +242,10 @@ public class ExpansionService extends ExpansionServiceGrpc.ExpansionServiceImplB * Provides a mapping of {@link RunnerApi.FunctionSpec} to a {@link PTransform}, together with * mappings of its inputs and outputs to maps of PCollections. * - * @param <InputT> input {@link PValue} type of the transform - * @param <OutputT> output {@link PValue} type of the transform + * @param <InputT> input {@link PInput} type of the transform + * @param <OutputT> output {@link POutput} type of the transform */ - public interface TransformProvider<InputT extends PInput, OutputT extends PValue> { + public interface TransformProvider<InputT extends PInput, OutputT extends POutput> { default InputT createInput(Pipeline p, Map<String, PCollection<?>> inputs) { if (inputs.size() == 0) { diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ExternalTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ExternalTest.java new file mode 100644 index 0000000..1e35bd4 --- /dev/null +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ExternalTest.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.core.construction; + +import com.google.auto.service.AutoService; +import java.io.IOException; +import java.io.Serializable; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import org.apache.beam.runners.core.construction.expansion.ExpansionService; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.UsesCrossLanguageTransforms; +import org.apache.beam.sdk.testing.ValidatesRunner; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.Filter; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ConnectivityState; +import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ManagedChannelBuilder; +import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ServerBuilder; +import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Test External transforms. */ +@RunWith(JUnit4.class) +public class ExternalTest implements Serializable { + @Rule public transient TestPipeline testPipeline = TestPipeline.create(); + + private static final String TEST_URN_SIMPLE = "simple"; + private static final String TEST_URN_LE = "le"; + private static final String TEST_URN_MULTI = "multi"; + + private static String pythonServerCommand; + private static Integer expansionPort; + private static String localExpansionAddr; + private static Server localExpansionServer; + + @BeforeClass + public static void setUp() throws IOException { + pythonServerCommand = System.getProperty("pythonTestExpansionCommand"); + expansionPort = Integer.valueOf(System.getProperty("expansionPort")); + int localExpansionPort = expansionPort + 100; + localExpansionAddr = String.format("localhost:%s", localExpansionPort); + + localExpansionServer = + ServerBuilder.forPort(localExpansionPort).addService(new ExpansionService()).build(); + localExpansionServer.start(); + } + + @AfterClass + public static void tearDown() { + localExpansionServer.shutdownNow(); + } + + @Test + @Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class}) + public void expandSingleTest() { + PCollection<Integer> col = + testPipeline + .apply(Create.of(1, 2, 3)) + .apply(External.of(TEST_URN_SIMPLE, new byte[] {}, localExpansionAddr)); + PAssert.that(col).containsInAnyOrder(2, 3, 4); + testPipeline.run(); + } + + @Test + @Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class}) + public void expandMultipleTest() { + PCollection<Integer> pcol = + testPipeline + .apply(Create.of(1, 2, 3)) + .apply("add one", External.of(TEST_URN_SIMPLE, new byte[] {}, localExpansionAddr)) + .apply( + "filter <=3", + External.of(TEST_URN_LE, "3".getBytes(StandardCharsets.UTF_8), localExpansionAddr)); + + PAssert.that(pcol).containsInAnyOrder(2, 3); + testPipeline.run(); + } + + @Test + @Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class}) + public void expandMultiOutputTest() { + PCollectionTuple pTuple = + testPipeline + .apply(Create.of(1, 2, 3, 4, 5, 6)) + .apply( + External.of(TEST_URN_MULTI, new byte[] {}, localExpansionAddr).withMultiOutputs()); + + PAssert.that(pTuple.get(new TupleTag<Integer>("even") {})).containsInAnyOrder(2, 4, 6); + PAssert.that(pTuple.get(new TupleTag<Integer>("odd") {})).containsInAnyOrder(1, 3, 5); + testPipeline.run(); + } + + private Process runCommandline(String command) { + ProcessBuilder builder = new ProcessBuilder("sh", "-c", command); + try { + return builder.start(); + } catch (IOException e) { + throw new AssertionError("process launch failed."); + } + } + + @Test + @Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class}) + public void expandPythonTest() { + String target = String.format("localhost:%s", expansionPort); + Process p = runCommandline(String.format("%s -p %s", pythonServerCommand, expansionPort)); + try { + ManagedChannel channel = ManagedChannelBuilder.forTarget(target).build(); + ConnectivityState state = channel.getState(true); + for (int retry = 0; retry < 30 && state != ConnectivityState.READY; retry++) { + Thread.sleep(500); + state = channel.getState(true); + } + channel.shutdownNow(); + + PCollection<String> pCol = + testPipeline + .apply(Create.of("1", "2", "2", "3", "3", "3")) + .apply( + "toBytes", + MapElements.into(new TypeDescriptor<byte[]>() {}).via(String::getBytes)) + .apply(External.<byte[]>of("count_per_element_bytes", new byte[] {}, target)) + .apply("toString", MapElements.into(TypeDescriptors.strings()).via(String::new)); + + PAssert.that(pCol).containsInAnyOrder("1->1", "2->2", "3->3"); + testPipeline.run(); + } catch (InterruptedException e) { + throw new RuntimeException("interrupted."); + } finally { + p.destroyForcibly(); + } + } + + /** Test TransformProvider. */ + @AutoService(ExpansionService.ExpansionServiceRegistrar.class) + public static class TestTransforms + implements ExpansionService.ExpansionServiceRegistrar, Serializable { + private final TupleTag<Integer> even = new TupleTag<Integer>("even") {}; + private final TupleTag<Integer> odd = new TupleTag<Integer>("odd") {}; + + @Override + public Map<String, ExpansionService.TransformProvider> knownTransforms() { + return ImmutableMap.of( + TEST_URN_SIMPLE, + spec -> MapElements.into(TypeDescriptors.integers()).via((Integer x) -> x + 1), + TEST_URN_LE, + spec -> Filter.lessThanEq(Integer.parseInt(spec.getPayload().toStringUtf8())), + TEST_URN_MULTI, + spec -> + ParDo.of( + new DoFn<Integer, Integer>() { + @ProcessElement + public void processElement(ProcessContext c) { + if (c.element() % 2 == 0) { + c.output(c.element()); + } else { + c.output(odd, c.element()); + } + } + }) + .withOutputTags(even, TupleTagList.of(odd))); + } + } +} diff --git a/runners/direct-java/build.gradle b/runners/direct-java/build.gradle index b6fbbff..3b90565 100644 --- a/runners/direct-java/build.gradle +++ b/runners/direct-java/build.gradle @@ -117,6 +117,7 @@ task needsRunnerTests(type: Test) { excludeCategories "org.apache.beam.sdk.testing.LargeKeys\$Above100MB" // MetricsPusher isn't implemented in direct runner excludeCategories "org.apache.beam.sdk.testing.UsesMetricsPusher" + excludeCategories "org.apache.beam.sdk.testing.UsesCrossLanguageTransforms" } } @@ -140,6 +141,7 @@ task validatesRunner(type: Test) { includeCategories "org.apache.beam.sdk.testing.ValidatesRunner" excludeCategories "org.apache.beam.sdk.testing.LargeKeys\$Above100MB" excludeCategories 'org.apache.beam.sdk.testing.UsesMetricsPusher' + excludeCategories "org.apache.beam.sdk.testing.UsesCrossLanguageTransforms" } } @@ -167,6 +169,7 @@ createPortableValidatesRunnerTask( includeCategories 'org.apache.beam.sdk.testing.ValidatesRunner' excludeCategories 'org.apache.beam.sdk.testing.FlattenWithHeterogeneousCoders' excludeCategories 'org.apache.beam.sdk.testing.LargeKeys$Above100MB' + excludeCategories 'org.apache.beam.sdk.testing.UsesCrossLanguageTransforms' excludeCategories 'org.apache.beam.sdk.testing.UsesCustomWindowMerging' excludeCategories 'org.apache.beam.sdk.testing.UsesDistributionMetrics' excludeCategories 'org.apache.beam.sdk.testing.UsesFailureMessage' diff --git a/runners/flink/job-server/flink_job_server.gradle b/runners/flink/job-server/flink_job_server.gradle index ebed8e4..969c957 100644 --- a/runners/flink/job-server/flink_job_server.gradle +++ b/runners/flink/job-server/flink_job_server.gradle @@ -132,6 +132,7 @@ def portableValidatesRunnerTask(String name, Boolean streaming) { excludeCategories 'org.apache.beam.sdk.testing.UsesAttemptedMetrics' excludeCategories 'org.apache.beam.sdk.testing.UsesCommittedMetrics' excludeCategories 'org.apache.beam.sdk.testing.UsesCounterMetrics' + excludeCategories 'org.apache.beam.sdk.testing.UsesCrossLanguageTransforms' excludeCategories 'org.apache.beam.sdk.testing.UsesCustomWindowMerging' excludeCategories 'org.apache.beam.sdk.testing.UsesDistributionMetrics' excludeCategories 'org.apache.beam.sdk.testing.UsesFailureMessage' @@ -155,3 +156,24 @@ task validatesPortableRunner() { dependsOn validatesPortableRunnerBatch dependsOn validatesPortableRunnerStreaming } + +project.ext.validatesCrossLanguageTransforms = + createPortableValidatesRunnerTask( + name: "validatesCrossLanguageTransforms", + jobServerDriver: "org.apache.beam.runners.flink.FlinkJobServerDriver", + jobServerConfig: "--clean-artifacts-per-job,--job-host=localhost,--job-port=0,--artifact-port=0,--expansion-port=0", + testClasspathConfiguration: configurations.validatesPortableRunner, + numParallelTests: 1, + pipelineOpts: [ + // Limit resource consumption via parallelism + "--parallelism=2", + "--shutdownSourcesOnFinalWatermark", + ], + testCategories: { + // Only include cross-language transform tests. Avoid to retest everything on Docker environment. + includeCategories 'org.apache.beam.sdk.testing.UsesCrossLanguageTransforms' + }, + ) +project.evaluationDependsOn ':beam-sdks-python' +validatesCrossLanguageTransforms.dependsOn ':beam-sdks-python:setupVirtualenv' +validatesCrossLanguageTransforms.systemProperty "pythonTestExpansionCommand", ". ${project(':beam-sdks-python').envdir}/bin/activate && pip install -e ${project(':beam-sdks-python').projectDir}[test] && python -m apache_beam.runners.portability.expansion_service_test" diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle index 8280011..86d6f06 100644 --- a/runners/google-cloud-dataflow-java/build.gradle +++ b/runners/google-cloud-dataflow-java/build.gradle @@ -131,6 +131,7 @@ def fnApiPipelineOptions = [ def commonExcludeCategories = [ 'org.apache.beam.sdk.testing.LargeKeys$Above10MB', 'org.apache.beam.sdk.testing.UsesAttemptedMetrics', + 'org.apache.beam.sdk.testing.UsesCrossLanguageTransforms', 'org.apache.beam.sdk.testing.UsesDistributionMetrics', 'org.apache.beam.sdk.testing.UsesGaugeMetrics', 'org.apache.beam.sdk.testing.UsesSetState', diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesCrossLanguageTransforms.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesCrossLanguageTransforms.java new file mode 100644 index 0000000..b249eed --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesCrossLanguageTransforms.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.testing; + +/** + * Category tag for validation tests which use cross-language transforms. Tests tagged with {@link + * UsesCrossLanguageTransforms} should be run for runners which support cross-language transforms. + */ +public interface UsesCrossLanguageTransforms {} diff --git a/sdks/python/apache_beam/runners/portability/expansion_service.py b/sdks/python/apache_beam/runners/portability/expansion_service.py index e407892..55a526d 100644 --- a/sdks/python/apache_beam/runners/portability/expansion_service.py +++ b/sdks/python/apache_beam/runners/portability/expansion_service.py @@ -20,10 +20,6 @@ from __future__ import absolute_import from __future__ import print_function -import argparse -import logging -import sys -import time import traceback from apache_beam import pipeline as beam_pipeline @@ -43,7 +39,7 @@ class ExpansionServiceServicer( self._options = options or beam_pipeline.PipelineOptions( environment_type=python_urns.EMBEDDED_PYTHON) - def Expand(self, request): + def Expand(self, request, context): try: pipeline = beam_pipeline.Pipeline(options=self._options) @@ -98,21 +94,3 @@ class ExpansionServiceServicer( except Exception: # pylint: disable=broad-except return beam_expansion_api_pb2.ExpansionResponse( error=traceback.format_exc()) - - -def main(unused_argv): - parser = argparse.ArgumentParser() - parser.add_argument('-p', '--port', - type=int, - help='port on which to serve the job api') - options = parser.parse_args() - expansion_servicer = ExpansionServiceServicer() - port = expansion_servicer.start_grpc_server(options.port) - while True: - logging.info('Listening for expansion requests at %d', port) - time.sleep(300) - - -if __name__ == '__main__': - logging.getLogger().setLevel(logging.INFO) - main(sys.argv) diff --git a/sdks/python/apache_beam/runners/portability/expansion_service_test.py b/sdks/python/apache_beam/runners/portability/expansion_service_test.py new file mode 100644 index 0000000..9e93a64 --- /dev/null +++ b/sdks/python/apache_beam/runners/portability/expansion_service_test.py @@ -0,0 +1,166 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import absolute_import + +import argparse +import concurrent.futures as futures +import logging +import signal +import sys + +import grpc + +import apache_beam as beam +import apache_beam.transforms.combiners as combine +from apache_beam.pipeline import PipelineOptions +from apache_beam.portability.api import beam_expansion_api_pb2_grpc +from apache_beam.runners.portability import expansion_service +from apache_beam.transforms import ptransform + +# This script provides an expansion service and example ptransforms for running +# external transform test cases. See external_test.py for details. + + +@ptransform.PTransform.register_urn('count_per_element_bytes', None) +class KV2BytesTransform(ptransform.PTransform): + def expand(self, pcoll): + return ( + pcoll + | combine.Count.PerElement() + | beam.Map( + lambda x: '{}->{}'.format(x[0], x[1])).with_output_types(bytes) + ) + + def to_runner_api_parameter(self, unused_context): + return 'kv_to_bytes', None + + @staticmethod + def from_runner_api_parameter(unused_parameter, unused_context): + return KV2BytesTransform() + + +@ptransform.PTransform.register_urn('simple', None) +class SimpleTransform(ptransform.PTransform): + def expand(self, pcoll): + return pcoll | 'TestLabel' >> beam.Map(lambda x: 'Simple(%s)' % x) + + def to_runner_api_parameter(self, unused_context): + return 'simple', None + + @staticmethod + def from_runner_api_parameter(unused_parameter, unused_context): + return SimpleTransform() + + +@ptransform.PTransform.register_urn('multi', None) +class MutltiTransform(ptransform.PTransform): + def expand(self, pcolls): + return { + 'main': + (pcolls['main1'], pcolls['main2']) + | beam.Flatten() + | beam.Map(lambda x, s: x + s, + beam.pvalue.AsSingleton(pcolls['side'])), + 'side': pcolls['side'] | beam.Map(lambda x: x + x), + } + + def to_runner_api_parameter(self, unused_context): + return 'multi', None + + @staticmethod + def from_runner_api_parameter(unused_parameter, unused_context): + return MutltiTransform() + + +@ptransform.PTransform.register_urn('payload', bytes) +class PayloadTransform(ptransform.PTransform): + def __init__(self, payload): + self._payload = payload + + def expand(self, pcoll): + return pcoll | beam.Map(lambda x, s: x + s, self._payload) + + def to_runner_api_parameter(self, unused_context): + return b'payload', self._payload.encode('ascii') + + @staticmethod + def from_runner_api_parameter(payload, unused_context): + return PayloadTransform(payload.decode('ascii')) + + +@ptransform.PTransform.register_urn('fib', bytes) +class FibTransform(ptransform.PTransform): + def __init__(self, level): + self._level = level + + def expand(self, p): + if self._level <= 2: + return p | beam.Create([1]) + else: + a = p | 'A' >> beam.ExternalTransform( + 'fib', str(self._level - 1).encode('ascii'), + expansion_service.ExpansionServiceServicer()) + b = p | 'B' >> beam.ExternalTransform( + 'fib', str(self._level - 2).encode('ascii'), + expansion_service.ExpansionServiceServicer()) + return ( + (a, b) + | beam.Flatten() + | beam.CombineGlobally(sum).without_defaults()) + + def to_runner_api_parameter(self, unused_context): + return 'fib', str(self._level).encode('ascii') + + @staticmethod + def from_runner_api_parameter(level, unused_context): + return FibTransform(int(level.decode('ascii'))) + + +server = None + + +def main(unused_argv): + parser = argparse.ArgumentParser() + parser.add_argument('-p', '--port', + type=int, + help='port on which to serve the job api') + options = parser.parse_args() + global server + server = grpc.server(futures.ThreadPoolExecutor(max_workers=2)) + beam_expansion_api_pb2_grpc.add_ExpansionServiceServicer_to_server( + expansion_service.ExpansionServiceServicer(PipelineOptions()), server + ) + server.add_insecure_port('localhost:{}'.format(options.port)) + server.start() + logging.info('Listening for expansion requests at %d', options.port) + + # blocking main thread forever. + signal.pause() + + +def cleanup(unused_signum, unused_frame): + logging.info('Shutting down expansion service.') + server.stop(None) + + +signal.signal(signal.SIGTERM, cleanup) +signal.signal(signal.SIGINT, cleanup) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + main(sys.argv) diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index cd670b6..9077ebe 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -45,7 +45,6 @@ class ExternalTransform(ptransform.PTransform): _namespace_counter = 0 _namespace = threading.local() - _namespace.value = 'external' _EXPANDED_TRANSFORM_UNIQUE_NAME = 'root' _IMPULSE_PREFIX = 'impulse' @@ -63,9 +62,13 @@ class ExternalTransform(ptransform.PTransform): return '%s(%s)' % (self.__class__.__name__, self._urn) @classmethod + def get_local_namespace(cls): + return getattr(cls._namespace, 'value', 'external') + + @classmethod @contextlib.contextmanager def outer_namespace(cls, namespace): - prev = cls._namespace.value + prev = cls.get_local_namespace() cls._namespace.value = namespace yield cls._namespace.value = prev @@ -73,7 +76,7 @@ class ExternalTransform(ptransform.PTransform): @classmethod def _fresh_namespace(cls): ExternalTransform._namespace_counter += 1 - return '%s_%d' % (cls._namespace.value, cls._namespace_counter) + return '%s_%d' % (cls.get_local_namespace(), cls._namespace_counter) def expand(self, pvalueish): if isinstance(pvalueish, pvalue.PBegin): @@ -115,7 +118,7 @@ class ExternalTransform(ptransform.PTransform): response = beam_expansion_api_pb2_grpc.ExpansionServiceStub( channel).Expand(request) else: - response = self._endpoint.Expand(request) + response = self._endpoint.Expand(request, None) if response.error: raise RuntimeError(response.error) diff --git a/sdks/python/apache_beam/transforms/external_test.py b/sdks/python/apache_beam/transforms/external_test.py index c3448c6..481d673 100644 --- a/sdks/python/apache_beam/transforms/external_test.py +++ b/sdks/python/apache_beam/transforms/external_test.py @@ -32,9 +32,9 @@ from apache_beam import Pipeline from apache_beam.io.external.generate_sequence import GenerateSequence from apache_beam.portability import python_urns from apache_beam.runners.portability import expansion_service +from apache_beam.runners.portability.expansion_service_test import FibTransform from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to -from apache_beam.transforms import ptransform class ExternalTransformTest(unittest.TestCase): @@ -43,19 +43,6 @@ class ExternalTransformTest(unittest.TestCase): expansion_service_jar = None def test_pipeline_generation(self): - - @ptransform.PTransform.register_urn('simple', None) - class SimpleTransform(ptransform.PTransform): - def expand(self, pcoll): - return pcoll | 'TestLabel' >> beam.Map(lambda x: 'Simple(%s)' % x) - - def to_runner_api_parameter(self, unused_context): - return 'simple', None - - @staticmethod - def from_runner_api_parameter(unused_parameter, unused_context): - return SimpleTransform() - pipeline = beam.Pipeline() res = (pipeline | beam.Create(['a', 'b']) @@ -81,19 +68,6 @@ class ExternalTransformTest(unittest.TestCase): pipeline_from_proto.transforms_stack[0].parts[1].parts[0].full_label) def test_simple(self): - - @ptransform.PTransform.register_urn('simple', None) - class SimpleTransform(ptransform.PTransform): - def expand(self, pcoll): - return pcoll | beam.Map(lambda x: 'Simple(%s)' % x) - - def to_runner_api_parameter(self, unused_context): - return 'simple', None - - @staticmethod - def from_runner_api_parameter(unused_parameter, unused_context): - return SimpleTransform() - with beam.Pipeline() as p: res = ( p @@ -105,26 +79,6 @@ class ExternalTransformTest(unittest.TestCase): assert_that(res, equal_to(['Simple(a)', 'Simple(b)'])) def test_multi(self): - - @ptransform.PTransform.register_urn('multi', None) - class MutltiTransform(ptransform.PTransform): - def expand(self, pcolls): - return { - 'main': - (pcolls['main1'], pcolls['main2']) - | beam.Flatten() - | beam.Map(lambda x, s: x + s, - beam.pvalue.AsSingleton(pcolls['side'])), - 'side': pcolls['side'] | beam.Map(lambda x: x + x), - } - - def to_runner_api_parameter(self, unused_context): - return 'multi', None - - @staticmethod - def from_runner_api_parameter(unused_parameter, unused_context): - return MutltiTransform() - with beam.Pipeline() as p: main1 = p | 'Main1' >> beam.Create(['a', 'bb'], reshuffle=False) main2 = p | 'Main2' >> beam.Create(['x', 'yy', 'zzz'], reshuffle=False) @@ -135,22 +89,6 @@ class ExternalTransformTest(unittest.TestCase): assert_that(res['side'], equal_to(['ss']), label='CheckSide') def test_payload(self): - - @ptransform.PTransform.register_urn('payload', bytes) - class PayloadTransform(ptransform.PTransform): - def __init__(self, payload): - self._payload = payload - - def expand(self, pcoll): - return pcoll | beam.Map(lambda x, s: x + s, self._payload) - - def to_runner_api_parameter(self, unused_context): - return b'payload', self._payload.encode('ascii') - - @staticmethod - def from_runner_api_parameter(payload, unused_context): - return PayloadTransform(payload.decode('ascii')) - with beam.Pipeline() as p: res = ( p @@ -161,33 +99,6 @@ class ExternalTransformTest(unittest.TestCase): assert_that(res, equal_to(['as', 'bbs'])) def test_nested(self): - @ptransform.PTransform.register_urn('fib', bytes) - class FibTransform(ptransform.PTransform): - def __init__(self, level): - self._level = level - - def expand(self, p): - if self._level <= 2: - return p | beam.Create([1]) - else: - a = p | 'A' >> beam.ExternalTransform( - 'fib', str(self._level - 1).encode('ascii'), - expansion_service.ExpansionServiceServicer()) - b = p | 'B' >> beam.ExternalTransform( - 'fib', str(self._level - 2).encode('ascii'), - expansion_service.ExpansionServiceServicer()) - return ( - (a, b) - | beam.Flatten() - | beam.CombineGlobally(sum).without_defaults()) - - def to_runner_api_parameter(self, unused_context): - return 'fib', str(self._level).encode('ascii') - - @staticmethod - def from_runner_api_parameter(level, unused_context): - return FibTransform(int(level.decode('ascii'))) - with beam.Pipeline() as p: assert_that(p | FibTransform(6), equal_to([8]))