Repository: beam Updated Branches: refs/heads/master 1be6f67aa -> 2040e2bd4
Add CombineTranslation This translates Combines to CombinePayloads and back Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/5b899a85 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/5b899a85 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/5b899a85 Branch: refs/heads/master Commit: 5b899a8518cc8910f0a855303c14088d72b332e5 Parents: 4ec3366 Author: Thomas Groh <[email protected]> Authored: Thu May 18 10:23:16 2017 -0700 Committer: Thomas Groh <[email protected]> Committed: Wed May 24 13:04:55 2017 -0700 ---------------------------------------------------------------------- .../core/construction/CombineTranslation.java | 125 ++++++++++++++++++ .../construction/CombineTranslationTest.java | 130 +++++++++++++++++++ .../org/apache/beam/sdk/transforms/Count.java | 10 ++ .../org/apache/beam/sdk/transforms/Sum.java | 30 +++++ 4 files changed, 295 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/5b899a85/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java ---------------------------------------------------------------------- diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java new file mode 100644 index 0000000..e0b6d5c --- /dev/null +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.core.construction; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.common.collect.Iterables; +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; +import org.apache.beam.sdk.common.runner.v1.RunnerApi.CombinePayload; +import org.apache.beam.sdk.common.runner.v1.RunnerApi.FunctionSpec; +import org.apache.beam.sdk.common.runner.v1.RunnerApi.SdkFunctionSpec; +import org.apache.beam.sdk.common.runner.v1.RunnerApi.SideInput; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.CombineFnBase.GlobalCombineFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.util.AppliedCombineFn; +import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; + +/** + * Methods for translating between {@link Combine.PerKey} {@link PTransform PTransforms} and {@link + * RunnerApi.CombinePayload} protos. + */ +public class CombineTranslation { + private static final String JAVA_SERIALIZED_COMBINE_FN_URN = "urn:beam:java:combinefn:v1"; + + public static CombinePayload toProto( + AppliedPTransform<?, ?, Combine.PerKey<?, ?, ?>> combine, SdkComponents sdkComponents) + throws IOException { + GlobalCombineFn<?, ?, ?> combineFn = combine.getTransform().getFn(); + try { + Coder<?> accumulatorCoder = extractAccumulatorCoder(combineFn, (AppliedPTransform) combine); + Map<String, SideInput> sideInputs = new HashMap<>(); + return RunnerApi.CombinePayload.newBuilder() + .setAccumulatorCoderId(sdkComponents.registerCoder(accumulatorCoder)) + .putAllSideInputs(sideInputs) + .setCombineFn(toProto(combineFn)) + .build(); + } catch (CannotProvideCoderException e) { + throw new IllegalStateException(e); + } + } + + private static <K, InputT, AccumT> Coder<AccumT> extractAccumulatorCoder( + GlobalCombineFn<InputT, AccumT, ?> combineFn, + AppliedPTransform<PCollection<KV<K, InputT>>, ?, Combine.PerKey<K, InputT, ?>> transform) + throws CannotProvideCoderException { + KvCoder<K, InputT> inputCoder = + (KvCoder<K, InputT>) + ((PCollection<KV<K, InputT>>) Iterables.getOnlyElement(transform.getInputs().values())) + .getCoder(); + return AppliedCombineFn.withInputCoder( + combineFn, + transform.getPipeline().getCoderRegistry(), + inputCoder, + transform.getTransform().getSideInputs(), + ((PCollection<?>) Iterables.getOnlyElement(transform.getOutputs().values())) + .getWindowingStrategy()) + .getAccumulatorCoder(); + } + + private static SdkFunctionSpec toProto(GlobalCombineFn<?, ?, ?> combineFn) { + return SdkFunctionSpec.newBuilder() + // TODO: Set Java SDK Environment URN + .setSpec( + FunctionSpec.newBuilder() + .setUrn(JAVA_SERIALIZED_COMBINE_FN_URN) + .setParameter( + Any.pack( + BytesValue.newBuilder() + .setValue( + ByteString.copyFrom( + SerializableUtils.serializeToByteArray(combineFn))) + .build()))) + .build(); + } + + public static Coder<?> getAccumulatorCoder( + CombinePayload payload, RunnerApi.Components components) throws IOException { + String id = payload.getAccumulatorCoderId(); + return Coders.fromProto(components.getCodersOrThrow(id), components); + } + + public static GlobalCombineFn<?, ?, ?> getCombineFn(CombinePayload payload) + throws IOException { + checkArgument(payload.getCombineFn().getSpec().getUrn().equals(JAVA_SERIALIZED_COMBINE_FN_URN)); + return (GlobalCombineFn<?, ?, ?>) + SerializableUtils.deserializeFromByteArray( + payload + .getCombineFn() + .getSpec() + .getParameter() + .unpack(BytesValue.class) + .getValue() + .toByteArray(), + "CombineFn"); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/5b899a85/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java ---------------------------------------------------------------------- diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java new file mode 100644 index 0000000..6251545 --- /dev/null +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.core.construction; + +import static com.google.common.base.Preconditions.checkState; +import static org.junit.Assert.assertEquals; + +import com.google.common.collect.ImmutableList; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.beam.sdk.Pipeline.PipelineVisitor; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; +import org.apache.beam.sdk.common.runner.v1.RunnerApi.CombinePayload; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.runners.TransformHierarchy.Node; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Combine.BinaryCombineIntegerFn; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.values.PCollection; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +/** + * Tests for {@link CombineTranslation}. + */ +@RunWith(Parameterized.class) +public class CombineTranslationTest { + @Parameters(name = "{index}: {0}") + public static Iterable<Combine.CombineFn<Integer, ?, ?>> params() { + BinaryCombineIntegerFn sum = Sum.ofIntegers(); + CombineFn<Integer, ?, Long> count = Count.combineFn(); + TestCombineFn test = new TestCombineFn(); + return ImmutableList.<CombineFn<Integer, ?, ?>>builder().add(sum).add(count).add(test).build(); + } + + @Rule public TestPipeline pipeline = TestPipeline.create(); + @Parameter(0) + public Combine.CombineFn<Integer, ?, ?> combineFn; + + @Test + public void testToFromProto() throws Exception { + PCollection<Integer> input = pipeline.apply(Create.of(1, 2, 3)); + input.apply(Combine.globally(combineFn)); + final AtomicReference<AppliedPTransform<?, ?, Combine.PerKey<?, ?, ?>>> combine = + new AtomicReference<>(); + pipeline.traverseTopologically( + new PipelineVisitor.Defaults() { + @Override + public void leaveCompositeTransform(Node node) { + if (node.getTransform() instanceof Combine.PerKey) { + checkState(combine.get() == null); + combine.set((AppliedPTransform) node.toAppliedPTransform(getPipeline())); + } + } + }); + checkState(combine.get() != null); + + SdkComponents sdkComponents = SdkComponents.create(); + CombinePayload combineProto = CombineTranslation.toProto(combine.get(), sdkComponents); + RunnerApi.Components componentsProto = sdkComponents.toComponents(); + + assertEquals( + combineFn.getAccumulatorCoder(pipeline.getCoderRegistry(), input.getCoder()), + CombineTranslation.getAccumulatorCoder(combineProto, componentsProto)); + assertEquals(combineFn, CombineTranslation.getCombineFn(combineProto)); + } + + private static class TestCombineFn extends Combine.CombineFn<Integer, Void, Void> { + @Override + public Void createAccumulator() { + return null; + } + + @Override + public Coder<Void> getAccumulatorCoder(CoderRegistry registry, Coder<Integer> inputCoder) { + return (Coder) VoidCoder.of(); + } + + @Override + public Void extractOutput(Void accumulator) { + return accumulator; + } + + @Override + public Void mergeAccumulators(Iterable<Void> accumulators) { + return null; + } + + @Override + public Void addInput(Void accumulator, Integer input) { + return accumulator; + } + + @Override + public boolean equals(Object other) { + return other != null && other.getClass().equals(TestCombineFn.class); + } + + @Override + public int hashCode() { + return TestCombineFn.class.hashCode(); + } + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/5b899a85/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Count.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Count.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Count.java index b405dd1..ee24b3f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Count.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Count.java @@ -195,5 +195,15 @@ public class Count { } }; } + + @Override + public boolean equals(Object other) { + return other != null && getClass().equals(other.getClass()); + } + + @Override + public int hashCode() { + return getClass().hashCode(); + } } } http://git-wip-us.apache.org/repos/asf/beam/blob/5b899a85/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sum.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sum.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sum.java index ccade4d..6b65416 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sum.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sum.java @@ -151,6 +151,16 @@ public class Sum { public int identity() { return 0; } + + @Override + public boolean equals(Object other) { + return other != null && other.getClass().equals(this.getClass()); + } + + @Override + public int hashCode() { + return getClass().hashCode(); + } } private static class SumLongFn extends Combine.BinaryCombineLongFn { @@ -164,6 +174,16 @@ public class Sum { public long identity() { return 0; } + + @Override + public boolean equals(Object other) { + return other != null && other.getClass().equals(this.getClass()); + } + + @Override + public int hashCode() { + return getClass().hashCode(); + } } private static class SumDoubleFn extends Combine.BinaryCombineDoubleFn { @@ -177,5 +197,15 @@ public class Sum { public double identity() { return 0; } + + @Override + public boolean equals(Object other) { + return other != null && other.getClass().equals(this.getClass()); + } + + @Override + public int hashCode() { + return getClass().hashCode(); + } } }
