Add custom rehydration for Combine
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/92209c32 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/92209c32 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/92209c32 Branch: refs/heads/master Commit: 92209c323eb54e8a57b496eb2035da44fec00714 Parents: 6abf6f5 Author: Kenneth Knowles <[email protected]> Authored: Tue Oct 3 11:40:54 2017 -0700 Committer: Kenneth Knowles <[email protected]> Committed: Tue Oct 17 12:45:11 2017 -0700 ---------------------------------------------------------------------- .../core/construction/CombineTranslation.java | 165 ++++++++++++++++++- .../construction/CombineTranslationTest.java | 16 +- 2 files changed, 161 insertions(+), 20 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/92209c32/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java ---------------------------------------------------------------------- diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java index 69591ee..21796aa 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java @@ -22,12 +22,15 @@ import static com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.runners.core.construction.PTransformTranslation.COMBINE_TRANSFORM_URN; import com.google.auto.service.AutoService; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.protobuf.ByteString; import java.io.IOException; import java.util.Collections; import java.util.HashMap; import java.util.Map; +import javax.annotation.Nonnull; import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.model.pipeline.v1.RunnerApi.CombinePayload; import org.apache.beam.model.pipeline.v1.RunnerApi.Components; @@ -52,12 +55,12 @@ import org.apache.beam.sdk.values.PCollection; * RunnerApi.CombinePayload} protos. */ public class CombineTranslation { + public static final String JAVA_SERIALIZED_COMBINE_FN_URN = "urn:beam:combinefn:javasdk:v1"; /** A {@link TransformPayloadTranslator} for {@link Combine.PerKey}. */ public static class CombinePayloadTranslator - extends PTransformTranslation.TransformPayloadTranslator.WithDefaultRehydration< - Combine.PerKey<?, ?, ?>> { + implements PTransformTranslation.TransformPayloadTranslator<Combine.PerKey<?, ?, ?>> { public static TransformPayloadTranslator create() { return new CombinePayloadTranslator(); } @@ -73,13 +76,25 @@ public class CombineTranslation { public FunctionSpec translate( AppliedPTransform<?, ?, Combine.PerKey<?, ?, ?>> transform, SdkComponents components) throws IOException { - CombinePayload payload = toProto(transform, components); - return RunnerApi.FunctionSpec.newBuilder() + return FunctionSpec.newBuilder() .setUrn(COMBINE_TRANSFORM_URN) - .setPayload(payload.toByteString()) + .setPayload(payloadForCombine((AppliedPTransform) transform, components).toByteString()) .build(); } + @Override + public PTransformTranslation.RawPTransform<?, ?> rehydrate( + RunnerApi.PTransform protoTransform, RehydratedComponents rehydratedComponents) + throws IOException { + checkArgument( + protoTransform.getSpec() != null, + "%s received transform with null spec", + getClass().getSimpleName()); + checkArgument(protoTransform.getSpec().getUrn().equals(COMBINE_TRANSFORM_URN)); + return new RawCombine<>( + CombinePayload.parseFrom(protoTransform.getSpec().getPayload()), rehydratedComponents); + } + /** Registers {@link CombinePayloadTranslator}. */ @AutoService(TransformPayloadTranslatorRegistrar.class) public static class Registrar implements TransformPayloadTranslatorRegistrar { @@ -90,13 +105,147 @@ public class CombineTranslation { } @Override - public Map<String, TransformPayloadTranslator> getTransformRehydrators() { - return Collections.emptyMap(); + public Map<String, ? extends TransformPayloadTranslator> getTransformRehydrators() { + return Collections.singletonMap(COMBINE_TRANSFORM_URN, new CombinePayloadTranslator()); + } + } + } + + /** + * These methods drive to-proto translation for both Java SDK transforms and rehydrated + * transforms. + */ + interface CombineLike { + RunnerApi.SdkFunctionSpec getCombineFn(); + + Coder<?> getAccumulatorCoder(); + + Map<String, RunnerApi.SideInput> getSideInputs(); + } + + /** Produces a {@link RunnerApi.CombinePayload} from a portable {@link CombineLike}. */ + static RunnerApi.CombinePayload payloadForCombineLike( + CombineLike combine, SdkComponents components) throws IOException { + return RunnerApi.CombinePayload.newBuilder() + .setAccumulatorCoderId(components.registerCoder(combine.getAccumulatorCoder())) + .putAllSideInputs(combine.getSideInputs()) + .setCombineFn(combine.getCombineFn()) + .build(); + } + + static <K, InputT, OutputT> CombinePayload payloadForCombine( + final AppliedPTransform< + PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>, + Combine.PerKey<K, InputT, OutputT>> + combine, + SdkComponents components) + throws IOException { + + return payloadForCombineLike( + new CombineLike() { + @Override + public SdkFunctionSpec getCombineFn() { + return SdkFunctionSpec.newBuilder() + // TODO: Set Java SDK Environment + .setSpec( + FunctionSpec.newBuilder() + .setUrn(JAVA_SERIALIZED_COMBINE_FN_URN) + .setPayload( + ByteString.copyFrom( + SerializableUtils.serializeToByteArray( + combine.getTransform().getFn()))) + .build()) + .build(); + } + + @Override + public Coder<?> getAccumulatorCoder() { + GlobalCombineFn<?, ?, ?> combineFn = combine.getTransform().getFn(); + try { + return extractAccumulatorCoder(combineFn, (AppliedPTransform) combine); + } catch (CannotProvideCoderException e) { + throw new IllegalStateException(e); + } + } + + @Override + public Map<String, SideInput> getSideInputs() { + // TODO: support side inputs + return ImmutableMap.of(); + } + }, + components); + } + + private static class RawCombine<K, InputT, AccumT, OutputT> + extends PTransformTranslation.RawPTransform< + PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>> + implements CombineLike { + + private final transient RehydratedComponents rehydratedComponents; + private final FunctionSpec spec; + private final CombinePayload payload; + private final Coder<AccumT> accumulatorCoder; + + private RawCombine(CombinePayload payload, RehydratedComponents rehydratedComponents) { + this.rehydratedComponents = rehydratedComponents; + this.payload = payload; + this.spec = + FunctionSpec.newBuilder() + .setUrn(COMBINE_TRANSFORM_URN) + .setPayload(payload.toByteString()) + .build(); + + // Eagerly extract the coder to throw a good exception here + try { + this.accumulatorCoder = + (Coder<AccumT>) rehydratedComponents.getCoder(payload.getAccumulatorCoderId()); + } catch (IOException exc) { + throw new IllegalArgumentException( + String.format( + "Failure extracting accumulator coder with id '%s' for %s", + payload.getAccumulatorCoderId(), Combine.class.getSimpleName()), + exc); } } + + @Override + public String getUrn() { + return COMBINE_TRANSFORM_URN; + } + + @Nonnull + @Override + public FunctionSpec getSpec() { + return spec; + } + + @Override + public RunnerApi.FunctionSpec migrate(SdkComponents sdkComponents) throws IOException { + return RunnerApi.FunctionSpec.newBuilder() + .setUrn(COMBINE_TRANSFORM_URN) + .setPayload(payloadForCombineLike(this, sdkComponents).toByteString()) + .build(); + } + + @Override + public SdkFunctionSpec getCombineFn() { + return payload.getCombineFn(); + } + + @Override + public Coder<?> getAccumulatorCoder() { + return accumulatorCoder; + } + + @Override + public Map<String, SideInput> getSideInputs() { + return payload.getSideInputsMap(); + } } - public static CombinePayload toProto( + @VisibleForTesting + static CombinePayload toProto( AppliedPTransform<?, ?, Combine.PerKey<?, ?, ?>> combine, SdkComponents sdkComponents) throws IOException { GlobalCombineFn<?, ?, ?> combineFn = combine.getTransform().getFn(); http://git-wip-us.apache.org/repos/asf/beam/blob/92209c32/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java ---------------------------------------------------------------------- diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java index 8740d7f..af162d3 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java @@ -52,15 +52,11 @@ import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameter; import org.junit.runners.Parameterized.Parameters; -/** - * Tests for {@link CombineTranslation}. - */ +/** Tests for {@link CombineTranslation}. */ @RunWith(Enclosed.class) public class CombineTranslationTest { - /** - * Tests that simple {@link CombineFn CombineFns} can be translated to and from proto. - */ + /** Tests that simple {@link CombineFn CombineFns} can be translated to and from proto. */ @RunWith(Parameterized.class) public static class TranslateSimpleCombinesTest { @Parameters(name = "{index}: {0}") @@ -111,14 +107,10 @@ public class CombineTranslationTest { } } - - /** - * Tests that a {@link CombineFnWithContext} can be translated. - */ + /** Tests that a {@link CombineFnWithContext} can be translated. */ @RunWith(JUnit4.class) public static class ValidateCombineWithContextTest { - @Rule - public TestPipeline pipeline = TestPipeline.create(); + @Rule public TestPipeline pipeline = TestPipeline.create(); @Test public void testToFromProtoWithSideInputs() throws Exception {
