Add custom rehydration for WriteFiles
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/187beae4 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/187beae4 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/187beae4 Branch: refs/heads/master Commit: 187beae4d20576d0e0ea1ca80d03252d1f2507e5 Parents: 7fb3e79 Author: Kenneth Knowles <[email protected]> Authored: Tue Oct 3 21:17:38 2017 -0700 Committer: Kenneth Knowles <[email protected]> Committed: Tue Oct 17 12:45:11 2017 -0700 ---------------------------------------------------------------------- .../construction/WriteFilesTranslation.java | 166 ++++++++++++++++--- .../construction/WriteFilesTranslationTest.java | 3 +- 2 files changed, 148 insertions(+), 21 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/187beae4/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java ---------------------------------------------------------------------- diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java index 645b562..d0b2182 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java @@ -24,12 +24,13 @@ import static org.apache.beam.runners.core.construction.PTransformTranslation.WR import com.google.auto.service.AutoService; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.MoreObjects; import com.google.common.collect.Lists; -import com.google.common.collect.Maps; import com.google.protobuf.ByteString; import java.io.IOException; import java.io.Serializable; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import org.apache.beam.model.pipeline.v1.RunnerApi; @@ -46,6 +47,9 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; /** @@ -59,18 +63,37 @@ public class WriteFilesTranslation { "urn:beam:file_based_sink:javasdk:0.1"; @VisibleForTesting - static WriteFilesPayload toProto(WriteFiles<?, ?, ?> transform) { - Map<String, SideInput> sideInputs = Maps.newHashMap(); - for (PCollectionView<?> view : transform.getSink().getDynamicDestinations().getSideInputs()) { - sideInputs.put(view.getTagInternal().getId(), ParDoTranslation.toProto(view)); - } - return WriteFilesPayload.newBuilder() - .setSink(toProto(transform.getSink())) - .setWindowedWrites(transform.isWindowedWrites()) - .setRunnerDeterminedSharding( - transform.getNumShards() == null && transform.getSharding() == null) - .putAllSideInputs(sideInputs) - .build(); + static WriteFilesPayload payloadForWriteFiles( + final WriteFiles<?, ?, ?> transform, SdkComponents components) throws IOException { + return payloadForWriteFilesLike( + new WriteFilesLike() { + @Override + public SdkFunctionSpec translateSink(SdkComponents newComponents) { + // TODO: register the environment + return toProto(transform.getSink()); + } + + @Override + public Map<String, SideInput> translateSideInputs(SdkComponents components) { + Map<String, SideInput> sideInputs = new HashMap<>(); + for (PCollectionView<?> view : + transform.getSink().getDynamicDestinations().getSideInputs()) { + sideInputs.put(view.getTagInternal().getId(), ParDoTranslation.toProto(view)); + } + return sideInputs; + } + + @Override + public boolean isWindowedWrites() { + return transform.isWindowedWrites(); + } + + @Override + public boolean isRunnerDeterminedSharding() { + return transform.getNumShards() == null && transform.getSharding() == null; + } + }, + components); } private static SdkFunctionSpec toProto(FileBasedSink<?, ?, ?> sink) { @@ -174,8 +197,82 @@ public class WriteFilesTranslation { .getPayload()); } - static class WriteFilesTranslator - extends TransformPayloadTranslator.WithDefaultRehydration<WriteFiles<?, ?, ?>> { + static class RawWriteFiles extends PTransformTranslation.RawPTransform<PInput, POutput> + implements WriteFilesLike { + + private final RunnerApi.PTransform protoTransform; + private final transient RehydratedComponents rehydratedComponents; + + // Parsed from protoTransform and cached + private final FunctionSpec spec; + private final RunnerApi.WriteFilesPayload payload; + + public RawWriteFiles( + RunnerApi.PTransform protoTransform, RehydratedComponents rehydratedComponents) + throws IOException { + this.rehydratedComponents = rehydratedComponents; + this.protoTransform = protoTransform; + this.spec = protoTransform.getSpec(); + this.payload = RunnerApi.WriteFilesPayload.parseFrom(spec.getPayload()); + } + + @Override + public FunctionSpec getSpec() { + return spec; + } + + @Override + public FunctionSpec migrate(SdkComponents components) throws IOException { + return FunctionSpec.newBuilder() + .setUrn(WRITE_FILES_TRANSFORM_URN) + .setPayload(payloadForWriteFilesLike(this, components).toByteString()) + .build(); + } + + @Override + public Map<TupleTag<?>, PValue> getAdditionalInputs() { + Map<TupleTag<?>, PValue> additionalInputs = new HashMap<>(); + for (Map.Entry<String, SideInput> sideInputEntry : payload.getSideInputsMap().entrySet()) { + try { + additionalInputs.put( + new TupleTag<>(sideInputEntry.getKey()), + rehydratedComponents.getPCollection( + protoTransform.getInputsOrThrow(sideInputEntry.getKey()))); + } catch (IOException exc) { + throw new IllegalStateException( + String.format( + "Could not find input with name %s for %s transform", + sideInputEntry.getKey(), WriteFiles.class.getSimpleName())); + } + } + return additionalInputs; + } + + @Override + public SdkFunctionSpec translateSink(SdkComponents newComponents) { + // TODO: re-register the environment with the new components + return payload.getSink(); + } + + @Override + public Map<String, SideInput> translateSideInputs(SdkComponents components) { + // TODO: re-register the PCollections and UDF environments + return MoreObjects.firstNonNull( + payload.getSideInputsMap(), Collections.<String, SideInput>emptyMap()); + } + + @Override + public boolean isWindowedWrites() { + return payload.getWindowedWrites(); + } + + @Override + public boolean isRunnerDeterminedSharding() { + return payload.getRunnerDeterminedSharding(); + } + } + + static class WriteFilesTranslator implements TransformPayloadTranslator<WriteFiles<?, ?, ?>> { @Override public String getUrn(WriteFiles<?, ?, ?> transform) { return WRITE_FILES_TRANSFORM_URN; @@ -183,14 +280,21 @@ public class WriteFilesTranslation { @Override public FunctionSpec translate( - AppliedPTransform<?, ?, WriteFiles<?, ?, ?>> transform, SdkComponents components) { + AppliedPTransform<?, ?, WriteFiles<?, ?, ?>> transform, SdkComponents components) + throws IOException { return FunctionSpec.newBuilder() .setUrn(getUrn(transform.getTransform())) - .setPayload(toProto(transform.getTransform()).toByteString()) + .setPayload(payloadForWriteFiles(transform.getTransform(), components).toByteString()) .build(); } - } + @Override + public PTransformTranslation.RawPTransform<?, ?> rehydrate( + RunnerApi.PTransform protoTransform, RehydratedComponents rehydratedComponents) + throws IOException { + return new RawWriteFiles(protoTransform, rehydratedComponents); + } + } /** Registers {@link WriteFilesTranslator}. */ @AutoService(TransformPayloadTranslatorRegistrar.class) public static class Registrar implements TransformPayloadTranslatorRegistrar { @@ -202,8 +306,30 @@ public class WriteFilesTranslation { } @Override - public Map<String, TransformPayloadTranslator> getTransformRehydrators() { - return Collections.emptyMap(); + public Map<String, ? extends TransformPayloadTranslator> getTransformRehydrators() { + return Collections.singletonMap(WRITE_FILES_TRANSFORM_URN, new WriteFilesTranslator()); } } + + /** These methods drive to-proto translation from Java and from rehydrated WriteFiles. */ + private interface WriteFilesLike { + SdkFunctionSpec translateSink(SdkComponents newComponents); + + Map<String, RunnerApi.SideInput> translateSideInputs(SdkComponents components); + + boolean isWindowedWrites(); + + boolean isRunnerDeterminedSharding(); + } + + public static WriteFilesPayload payloadForWriteFilesLike( + WriteFilesLike writeFiles, SdkComponents components) throws IOException { + + return WriteFilesPayload.newBuilder() + .setSink(writeFiles.translateSink(components)) + .putAllSideInputs(writeFiles.translateSideInputs(components)) + .setWindowedWrites(writeFiles.isWindowedWrites()) + .setRunnerDeterminedSharding(writeFiles.isRunnerDeterminedSharding()) + .build(); + } } http://git-wip-us.apache.org/repos/asf/beam/blob/187beae4/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java ---------------------------------------------------------------------- diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java index c874828..4bc61d4 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java @@ -76,7 +76,8 @@ public class WriteFilesTranslationTest { @Test public void testEncodedProto() throws Exception { - RunnerApi.WriteFilesPayload payload = WriteFilesTranslation.toProto(writeFiles); + RunnerApi.WriteFilesPayload payload = + WriteFilesTranslation.payloadForWriteFiles(writeFiles, SdkComponents.create()); assertThat( payload.getRunnerDeterminedSharding(),
