[BEAM-1347] Rename DoFnRunnerFactory to FnApiDoFnRunner.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/2295b905 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/2295b905 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/2295b905 Branch: refs/heads/master Commit: 2295b905ecb055d9348170948e25f89665dd647d Parents: f62586a Author: Luke Cwik <[email protected]> Authored: Fri Jun 23 14:31:58 2017 -0700 Committer: Luke Cwik <[email protected]> Committed: Fri Jul 7 13:44:22 2017 -0700 ---------------------------------------------------------------------- .../beam/runners/core/DoFnRunnerFactory.java | 182 ---------------- .../beam/runners/core/FnApiDoFnRunner.java | 182 ++++++++++++++++ .../runners/core/DoFnRunnerFactoryTest.java | 209 ------------------- .../beam/runners/core/FnApiDoFnRunnerTest.java | 209 +++++++++++++++++++ 4 files changed, 391 insertions(+), 391 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/2295b905/sdks/java/harness/src/main/java/org/apache/beam/runners/core/DoFnRunnerFactory.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/DoFnRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/DoFnRunnerFactory.java deleted file mode 100644 index 3c0b6eb..0000000 --- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/DoFnRunnerFactory.java +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.core; - -import static com.google.common.base.Preconditions.checkArgument; - -import com.google.auto.service.AutoService; -import com.google.common.collect.Collections2; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableMultimap; -import com.google.common.collect.Multimap; -import com.google.protobuf.ByteString; -import com.google.protobuf.BytesValue; -import com.google.protobuf.InvalidProtocolBufferException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.Map; -import java.util.Objects; -import java.util.function.Consumer; -import java.util.function.Supplier; -import org.apache.beam.fn.harness.data.BeamFnDataClient; -import org.apache.beam.fn.harness.fake.FakeStepContext; -import org.apache.beam.fn.harness.fn.ThrowingConsumer; -import org.apache.beam.fn.harness.fn.ThrowingRunnable; -import org.apache.beam.runners.core.DoFnRunners.OutputManager; -import org.apache.beam.runners.dataflow.util.DoFnInfo; -import org.apache.beam.sdk.common.runner.v1.RunnerApi; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.util.SerializableUtils; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.WindowingStrategy; - -/** - * Classes associated with converting {@link RunnerApi.PTransform}s to {@link DoFnRunner}s. - * - * <p>TODO: Move DoFnRunners into SDK harness and merge the methods below into it removing this - * class. - */ -public class DoFnRunnerFactory { - - private static final String URN = "urn:org.apache.beam:dofn:java:0.1"; - - /** A registrar which provides a factory to handle Java {@link DoFn}s. */ - @AutoService(PTransformRunnerFactory.Registrar.class) - public static class Registrar implements - PTransformRunnerFactory.Registrar { - - @Override - public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() { - return ImmutableMap.of(URN, new Factory()); - } - } - - /** A factory for {@link DoFnRunner}s. */ - static class Factory<InputT, OutputT> - implements PTransformRunnerFactory<DoFnRunner<InputT, OutputT>> { - - @Override - public DoFnRunner<InputT, OutputT> createRunnerForPTransform( - PipelineOptions pipelineOptions, - BeamFnDataClient beamFnDataClient, - String pTransformId, - RunnerApi.PTransform pTransform, - Supplier<String> processBundleInstructionId, - Map<String, RunnerApi.PCollection> pCollections, - Map<String, RunnerApi.Coder> coders, - Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers, - Consumer<ThrowingRunnable> addStartFunction, - Consumer<ThrowingRunnable> addFinishFunction) { - - // For every output PCollection, create a map from output name to Consumer - ImmutableMap.Builder<String, Collection<ThrowingConsumer<WindowedValue<?>>>> - outputMapBuilder = ImmutableMap.builder(); - for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) { - outputMapBuilder.put( - entry.getKey(), - pCollectionIdsToConsumers.get(entry.getValue())); - } - ImmutableMap<String, Collection<ThrowingConsumer<WindowedValue<?>>>> outputMap = - outputMapBuilder.build(); - - // Get the DoFnInfo from the serialized blob. - ByteString serializedFn; - try { - serializedFn = pTransform.getSpec().getParameter().unpack(BytesValue.class).getValue(); - } catch (InvalidProtocolBufferException e) { - throw new IllegalArgumentException( - String.format("Unable to unwrap DoFn %s", pTransform.getSpec()), e); - } - DoFnInfo<?, ?> doFnInfo = - (DoFnInfo<?, ?>) - SerializableUtils.deserializeFromByteArray(serializedFn.toByteArray(), "DoFnInfo"); - - // Verify that the DoFnInfo tag to output map matches the output map on the PTransform. - checkArgument( - Objects.equals( - new HashSet<>(Collections2.transform(outputMap.keySet(), Long::parseLong)), - doFnInfo.getOutputMap().keySet()), - "Unexpected mismatch between transform output map %s and DoFnInfo output map %s.", - outputMap.keySet(), - doFnInfo.getOutputMap()); - - ImmutableMultimap.Builder<TupleTag<?>, - ThrowingConsumer<WindowedValue<OutputT>>> tagToOutput = - ImmutableMultimap.builder(); - for (Map.Entry<Long, TupleTag<?>> entry : doFnInfo.getOutputMap().entrySet()) { - @SuppressWarnings({"unchecked", "rawtypes"}) - Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers = - (Collection) outputMap.get(Long.toString(entry.getKey())); - tagToOutput.putAll(entry.getValue(), consumers); - } - - @SuppressWarnings({"unchecked", "rawtypes"}) - Map<TupleTag<?>, Collection<ThrowingConsumer<WindowedValue<?>>>> tagBasedOutputMap = - (Map) tagToOutput.build().asMap(); - - OutputManager outputManager = - new OutputManager() { - Map<TupleTag<?>, Collection<ThrowingConsumer<WindowedValue<?>>>> tupleTagToOutput = - tagBasedOutputMap; - - @Override - public <T> void output(TupleTag<T> tag, WindowedValue<T> output) { - try { - Collection<ThrowingConsumer<WindowedValue<?>>> consumers = - tupleTagToOutput.get(tag); - if (consumers == null) { - /* This is a normal case, e.g., if a DoFn has output but that output is not - * consumed. Drop the output. */ - return; - } - for (ThrowingConsumer<WindowedValue<?>> consumer : consumers) { - consumer.accept(output); - } - } catch (Throwable t) { - throw new RuntimeException(t); - } - } - }; - - @SuppressWarnings({"unchecked", "rawtypes", "deprecation"}) - DoFnRunner<InputT, OutputT> runner = - DoFnRunners.simpleRunner( - pipelineOptions, - (DoFn) doFnInfo.getDoFn(), - NullSideInputReader.empty(), /* TODO */ - outputManager, - (TupleTag) doFnInfo.getOutputMap().get(doFnInfo.getMainOutput()), - new ArrayList<>(doFnInfo.getOutputMap().values()), - new FakeStepContext(), - (WindowingStrategy) doFnInfo.getWindowingStrategy()); - - // Register the appropriate handlers. - addStartFunction.accept(runner::startBundle); - for (String pcollectionId : pTransform.getInputsMap().values()) { - pCollectionIdsToConsumers.put( - pcollectionId, - (ThrowingConsumer) (ThrowingConsumer<WindowedValue<InputT>>) runner::processElement); - } - addFinishFunction.accept(runner::finishBundle); - return runner; - } - } -} http://git-wip-us.apache.org/repos/asf/beam/blob/2295b905/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java new file mode 100644 index 0000000..adf735a --- /dev/null +++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.core; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.auto.service.AutoService; +import com.google.common.collect.Collections2; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.Multimap; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; +import com.google.protobuf.InvalidProtocolBufferException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.function.Consumer; +import java.util.function.Supplier; +import org.apache.beam.fn.harness.data.BeamFnDataClient; +import org.apache.beam.fn.harness.fake.FakeStepContext; +import org.apache.beam.fn.harness.fn.ThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingRunnable; +import org.apache.beam.runners.core.DoFnRunners.OutputManager; +import org.apache.beam.runners.dataflow.util.DoFnInfo; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; + +/** + * Classes associated with converting {@link RunnerApi.PTransform}s to {@link DoFnRunner}s. + * + * <p>TODO: Move DoFnRunners into SDK harness and merge the methods below into it removing this + * class. + */ +public class FnApiDoFnRunner { + + private static final String URN = "urn:org.apache.beam:dofn:java:0.1"; + + /** A registrar which provides a factory to handle Java {@link DoFn}s. */ + @AutoService(PTransformRunnerFactory.Registrar.class) + public static class Registrar implements + PTransformRunnerFactory.Registrar { + + @Override + public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() { + return ImmutableMap.of(URN, new Factory()); + } + } + + /** A factory for {@link DoFnRunner}s. */ + static class Factory<InputT, OutputT> + implements PTransformRunnerFactory<DoFnRunner<InputT, OutputT>> { + + @Override + public DoFnRunner<InputT, OutputT> createRunnerForPTransform( + PipelineOptions pipelineOptions, + BeamFnDataClient beamFnDataClient, + String pTransformId, + RunnerApi.PTransform pTransform, + Supplier<String> processBundleInstructionId, + Map<String, RunnerApi.PCollection> pCollections, + Map<String, RunnerApi.Coder> coders, + Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers, + Consumer<ThrowingRunnable> addStartFunction, + Consumer<ThrowingRunnable> addFinishFunction) { + + // For every output PCollection, create a map from output name to Consumer + ImmutableMap.Builder<String, Collection<ThrowingConsumer<WindowedValue<?>>>> + outputMapBuilder = ImmutableMap.builder(); + for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) { + outputMapBuilder.put( + entry.getKey(), + pCollectionIdsToConsumers.get(entry.getValue())); + } + ImmutableMap<String, Collection<ThrowingConsumer<WindowedValue<?>>>> outputMap = + outputMapBuilder.build(); + + // Get the DoFnInfo from the serialized blob. + ByteString serializedFn; + try { + serializedFn = pTransform.getSpec().getParameter().unpack(BytesValue.class).getValue(); + } catch (InvalidProtocolBufferException e) { + throw new IllegalArgumentException( + String.format("Unable to unwrap DoFn %s", pTransform.getSpec()), e); + } + DoFnInfo<?, ?> doFnInfo = + (DoFnInfo<?, ?>) + SerializableUtils.deserializeFromByteArray(serializedFn.toByteArray(), "DoFnInfo"); + + // Verify that the DoFnInfo tag to output map matches the output map on the PTransform. + checkArgument( + Objects.equals( + new HashSet<>(Collections2.transform(outputMap.keySet(), Long::parseLong)), + doFnInfo.getOutputMap().keySet()), + "Unexpected mismatch between transform output map %s and DoFnInfo output map %s.", + outputMap.keySet(), + doFnInfo.getOutputMap()); + + ImmutableMultimap.Builder<TupleTag<?>, + ThrowingConsumer<WindowedValue<OutputT>>> tagToOutput = + ImmutableMultimap.builder(); + for (Map.Entry<Long, TupleTag<?>> entry : doFnInfo.getOutputMap().entrySet()) { + @SuppressWarnings({"unchecked", "rawtypes"}) + Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers = + (Collection) outputMap.get(Long.toString(entry.getKey())); + tagToOutput.putAll(entry.getValue(), consumers); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + Map<TupleTag<?>, Collection<ThrowingConsumer<WindowedValue<?>>>> tagBasedOutputMap = + (Map) tagToOutput.build().asMap(); + + OutputManager outputManager = + new OutputManager() { + Map<TupleTag<?>, Collection<ThrowingConsumer<WindowedValue<?>>>> tupleTagToOutput = + tagBasedOutputMap; + + @Override + public <T> void output(TupleTag<T> tag, WindowedValue<T> output) { + try { + Collection<ThrowingConsumer<WindowedValue<?>>> consumers = + tupleTagToOutput.get(tag); + if (consumers == null) { + /* This is a normal case, e.g., if a DoFn has output but that output is not + * consumed. Drop the output. */ + return; + } + for (ThrowingConsumer<WindowedValue<?>> consumer : consumers) { + consumer.accept(output); + } + } catch (Throwable t) { + throw new RuntimeException(t); + } + } + }; + + @SuppressWarnings({"unchecked", "rawtypes", "deprecation"}) + DoFnRunner<InputT, OutputT> runner = + DoFnRunners.simpleRunner( + pipelineOptions, + (DoFn) doFnInfo.getDoFn(), + NullSideInputReader.empty(), /* TODO */ + outputManager, + (TupleTag) doFnInfo.getOutputMap().get(doFnInfo.getMainOutput()), + new ArrayList<>(doFnInfo.getOutputMap().values()), + new FakeStepContext(), + (WindowingStrategy) doFnInfo.getWindowingStrategy()); + + // Register the appropriate handlers. + addStartFunction.accept(runner::startBundle); + for (String pcollectionId : pTransform.getInputsMap().values()) { + pCollectionIdsToConsumers.put( + pcollectionId, + (ThrowingConsumer) (ThrowingConsumer<WindowedValue<InputT>>) runner::processElement); + } + addFinishFunction.accept(runner::finishBundle); + return runner; + } + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/2295b905/sdks/java/harness/src/test/java/org/apache/beam/runners/core/DoFnRunnerFactoryTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/DoFnRunnerFactoryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/DoFnRunnerFactoryTest.java deleted file mode 100644 index 62646ff..0000000 --- a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/DoFnRunnerFactoryTest.java +++ /dev/null @@ -1,209 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.beam.runners.core; - -import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow; -import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; -import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.fail; - -import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.base.Suppliers; -import com.google.common.collect.HashMultimap; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Iterables; -import com.google.common.collect.Multimap; -import com.google.protobuf.Any; -import com.google.protobuf.ByteString; -import com.google.protobuf.BytesValue; -import com.google.protobuf.Message; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.ServiceLoader; -import org.apache.beam.fn.harness.fn.ThrowingConsumer; -import org.apache.beam.fn.harness.fn.ThrowingRunnable; -import org.apache.beam.runners.core.PTransformRunnerFactory.Registrar; -import org.apache.beam.runners.dataflow.util.CloudObjects; -import org.apache.beam.runners.dataflow.util.DoFnInfo; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.common.runner.v1.RunnerApi; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.GlobalWindow; -import org.apache.beam.sdk.util.SerializableUtils; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.WindowingStrategy; -import org.hamcrest.collection.IsMapContaining; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests for {@link DoFnRunnerFactory}. */ -@RunWith(JUnit4.class) -public class DoFnRunnerFactoryTest { - - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - private static final Coder<WindowedValue<String>> STRING_CODER = - WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); - private static final String STRING_CODER_SPEC_ID = "999L"; - private static final RunnerApi.Coder STRING_CODER_SPEC; - private static final String URN = "urn:org.apache.beam:dofn:java:0.1"; - - static { - try { - STRING_CODER_SPEC = RunnerApi.Coder.newBuilder() - .setSpec(RunnerApi.SdkFunctionSpec.newBuilder() - .setSpec(RunnerApi.FunctionSpec.newBuilder() - .setParameter(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( - OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(STRING_CODER)))) - .build()))) - .build()) - .build(); - } catch (IOException e) { - throw new ExceptionInInitializerError(e); - } - } - - private static class TestDoFn extends DoFn<String, String> { - private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput"); - private static final TupleTag<String> additionalOutput = new TupleTag<>("output"); - - private BoundedWindow window; - - @ProcessElement - public void processElement(ProcessContext context, BoundedWindow window) { - context.output("MainOutput" + context.element()); - context.output(additionalOutput, "AdditionalOutput" + context.element()); - this.window = window; - } - - @FinishBundle - public void finishBundle(FinishBundleContext context) { - if (window != null) { - context.output("FinishBundle", window.maxTimestamp(), window); - window = null; - } - } - } - - /** - * Create a DoFn that has 3 inputs (inputATarget1, inputATarget2, inputBTarget) and 2 outputs - * (mainOutput, output). Validate that inputs are fed to the {@link DoFn} and that outputs - * are directed to the correct consumers. - */ - @Test - public void testCreatingAndProcessingDoFn() throws Exception { - Map<String, Message> fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC); - String pTransformId = "pTransformId"; - String mainOutputId = "101"; - String additionalOutputId = "102"; - - DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn( - new TestDoFn(), - WindowingStrategy.globalDefault(), - ImmutableList.of(), - StringUtf8Coder.of(), - Long.parseLong(mainOutputId), - ImmutableMap.of( - Long.parseLong(mainOutputId), TestDoFn.mainOutput, - Long.parseLong(additionalOutputId), TestDoFn.additionalOutput)); - RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() - .setUrn("urn:org.apache.beam:dofn:java:0.1") - .setParameter(Any.pack(BytesValue.newBuilder() - .setValue(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo))) - .build())) - .build(); - RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() - .setSpec(functionSpec) - .putInputs("inputA", "inputATarget") - .putInputs("inputB", "inputBTarget") - .putOutputs(mainOutputId, "mainOutputTarget") - .putOutputs(additionalOutputId, "additionalOutputTarget") - .build(); - - List<WindowedValue<String>> mainOutputValues = new ArrayList<>(); - List<WindowedValue<String>> additionalOutputValues = new ArrayList<>(); - Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create(); - consumers.put("mainOutputTarget", - (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) mainOutputValues::add); - consumers.put("additionalOutputTarget", - (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) additionalOutputValues::add); - List<ThrowingRunnable> startFunctions = new ArrayList<>(); - List<ThrowingRunnable> finishFunctions = new ArrayList<>(); - - new DoFnRunnerFactory.Factory<>().createRunnerForPTransform( - PipelineOptionsFactory.create(), - null /* beamFnDataClient */, - pTransformId, - pTransform, - Suppliers.ofInstance("57L")::get, - ImmutableMap.of(), - ImmutableMap.of(), - consumers, - startFunctions::add, - finishFunctions::add); - - Iterables.getOnlyElement(startFunctions).run(); - mainOutputValues.clear(); - - assertThat(consumers.keySet(), containsInAnyOrder( - "inputATarget", "inputBTarget", "mainOutputTarget", "additionalOutputTarget")); - - Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A1")); - Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A2")); - Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("B")); - assertThat(mainOutputValues, contains( - valueInGlobalWindow("MainOutputA1"), - valueInGlobalWindow("MainOutputA2"), - valueInGlobalWindow("MainOutputB"))); - assertThat(additionalOutputValues, contains( - valueInGlobalWindow("AdditionalOutputA1"), - valueInGlobalWindow("AdditionalOutputA2"), - valueInGlobalWindow("AdditionalOutputB"))); - mainOutputValues.clear(); - additionalOutputValues.clear(); - - Iterables.getOnlyElement(finishFunctions).run(); - assertThat( - mainOutputValues, - contains( - timestampedValueInGlobalWindow("FinishBundle", GlobalWindow.INSTANCE.maxTimestamp()))); - mainOutputValues.clear(); - } - - @Test - public void testRegistration() { - for (Registrar registrar : - ServiceLoader.load(Registrar.class)) { - if (registrar instanceof DoFnRunnerFactory.Registrar) { - assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN)); - return; - } - } - fail("Expected registrar not found."); - } -} http://git-wip-us.apache.org/repos/asf/beam/blob/2295b905/sdks/java/harness/src/test/java/org/apache/beam/runners/core/FnApiDoFnRunnerTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/FnApiDoFnRunnerTest.java new file mode 100644 index 0000000..ae5cbac --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/FnApiDoFnRunnerTest.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.core; + +import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow; +import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Suppliers; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.common.collect.Multimap; +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; +import com.google.protobuf.Message; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.ServiceLoader; +import org.apache.beam.fn.harness.fn.ThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingRunnable; +import org.apache.beam.runners.core.PTransformRunnerFactory.Registrar; +import org.apache.beam.runners.dataflow.util.CloudObjects; +import org.apache.beam.runners.dataflow.util.DoFnInfo; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.hamcrest.collection.IsMapContaining; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link FnApiDoFnRunner}. */ +@RunWith(JUnit4.class) +public class FnApiDoFnRunnerTest { + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final Coder<WindowedValue<String>> STRING_CODER = + WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); + private static final String STRING_CODER_SPEC_ID = "999L"; + private static final RunnerApi.Coder STRING_CODER_SPEC; + private static final String URN = "urn:org.apache.beam:dofn:java:0.1"; + + static { + try { + STRING_CODER_SPEC = RunnerApi.Coder.newBuilder() + .setSpec(RunnerApi.SdkFunctionSpec.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder() + .setParameter(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( + OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(STRING_CODER)))) + .build()))) + .build()) + .build(); + } catch (IOException e) { + throw new ExceptionInInitializerError(e); + } + } + + private static class TestDoFn extends DoFn<String, String> { + private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput"); + private static final TupleTag<String> additionalOutput = new TupleTag<>("output"); + + private BoundedWindow window; + + @ProcessElement + public void processElement(ProcessContext context, BoundedWindow window) { + context.output("MainOutput" + context.element()); + context.output(additionalOutput, "AdditionalOutput" + context.element()); + this.window = window; + } + + @FinishBundle + public void finishBundle(FinishBundleContext context) { + if (window != null) { + context.output("FinishBundle", window.maxTimestamp(), window); + window = null; + } + } + } + + /** + * Create a DoFn that has 3 inputs (inputATarget1, inputATarget2, inputBTarget) and 2 outputs + * (mainOutput, output). Validate that inputs are fed to the {@link DoFn} and that outputs + * are directed to the correct consumers. + */ + @Test + public void testCreatingAndProcessingDoFn() throws Exception { + Map<String, Message> fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC); + String pTransformId = "pTransformId"; + String mainOutputId = "101"; + String additionalOutputId = "102"; + + DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn( + new TestDoFn(), + WindowingStrategy.globalDefault(), + ImmutableList.of(), + StringUtf8Coder.of(), + Long.parseLong(mainOutputId), + ImmutableMap.of( + Long.parseLong(mainOutputId), TestDoFn.mainOutput, + Long.parseLong(additionalOutputId), TestDoFn.additionalOutput)); + RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() + .setUrn("urn:org.apache.beam:dofn:java:0.1") + .setParameter(Any.pack(BytesValue.newBuilder() + .setValue(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo))) + .build())) + .build(); + RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .putInputs("inputA", "inputATarget") + .putInputs("inputB", "inputBTarget") + .putOutputs(mainOutputId, "mainOutputTarget") + .putOutputs(additionalOutputId, "additionalOutputTarget") + .build(); + + List<WindowedValue<String>> mainOutputValues = new ArrayList<>(); + List<WindowedValue<String>> additionalOutputValues = new ArrayList<>(); + Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create(); + consumers.put("mainOutputTarget", + (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) mainOutputValues::add); + consumers.put("additionalOutputTarget", + (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) additionalOutputValues::add); + List<ThrowingRunnable> startFunctions = new ArrayList<>(); + List<ThrowingRunnable> finishFunctions = new ArrayList<>(); + + new FnApiDoFnRunner.Factory<>().createRunnerForPTransform( + PipelineOptionsFactory.create(), + null /* beamFnDataClient */, + pTransformId, + pTransform, + Suppliers.ofInstance("57L")::get, + ImmutableMap.of(), + ImmutableMap.of(), + consumers, + startFunctions::add, + finishFunctions::add); + + Iterables.getOnlyElement(startFunctions).run(); + mainOutputValues.clear(); + + assertThat(consumers.keySet(), containsInAnyOrder( + "inputATarget", "inputBTarget", "mainOutputTarget", "additionalOutputTarget")); + + Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A1")); + Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A2")); + Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("B")); + assertThat(mainOutputValues, contains( + valueInGlobalWindow("MainOutputA1"), + valueInGlobalWindow("MainOutputA2"), + valueInGlobalWindow("MainOutputB"))); + assertThat(additionalOutputValues, contains( + valueInGlobalWindow("AdditionalOutputA1"), + valueInGlobalWindow("AdditionalOutputA2"), + valueInGlobalWindow("AdditionalOutputB"))); + mainOutputValues.clear(); + additionalOutputValues.clear(); + + Iterables.getOnlyElement(finishFunctions).run(); + assertThat( + mainOutputValues, + contains( + timestampedValueInGlobalWindow("FinishBundle", GlobalWindow.INSTANCE.maxTimestamp()))); + mainOutputValues.clear(); + } + + @Test + public void testRegistration() { + for (Registrar registrar : + ServiceLoader.load(Registrar.class)) { + if (registrar instanceof FnApiDoFnRunner.Registrar) { + assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN)); + return; + } + } + fail("Expected registrar not found."); + } +}
