Add custom rehydration for WriteFiles

Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/187beae4
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/187beae4
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/187beae4

Branch: refs/heads/master
Commit: 187beae4d20576d0e0ea1ca80d03252d1f2507e5
Parents: 7fb3e79
Author: Kenneth Knowles <[email protected]>
Authored: Tue Oct 3 21:17:38 2017 -0700
Committer: Kenneth Knowles <[email protected]>
Committed: Tue Oct 17 12:45:11 2017 -0700

----------------------------------------------------------------------
 .../construction/WriteFilesTranslation.java     | 166 ++++++++++++++++---
 .../construction/WriteFilesTranslationTest.java |   3 +-
 2 files changed, 148 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/187beae4/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java
index 645b562..d0b2182 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java
@@ -24,12 +24,13 @@ import static org.apache.beam.runners.core.construction.PTransformTranslation.WR
 
 import com.google.auto.service.AutoService;
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.MoreObjects;
 import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
 import com.google.protobuf.ByteString;
 import java.io.IOException;
 import java.io.Serializable;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
@@ -46,6 +47,9 @@ import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.util.SerializableUtils;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PInput;
+import org.apache.beam.sdk.values.POutput;
+import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.TupleTag;
 
 /**
@@ -59,18 +63,37 @@ public class WriteFilesTranslation {
       "urn:beam:file_based_sink:javasdk:0.1";
 
   @VisibleForTesting
-  static WriteFilesPayload toProto(WriteFiles<?, ?, ?> transform) {
-    Map<String, SideInput> sideInputs = Maps.newHashMap();
-    for (PCollectionView<?> view : 
transform.getSink().getDynamicDestinations().getSideInputs()) {
-      sideInputs.put(view.getTagInternal().getId(), 
ParDoTranslation.toProto(view));
-    }
-    return WriteFilesPayload.newBuilder()
-        .setSink(toProto(transform.getSink()))
-        .setWindowedWrites(transform.isWindowedWrites())
-        .setRunnerDeterminedSharding(
-            transform.getNumShards() == null && transform.getSharding() == 
null)
-        .putAllSideInputs(sideInputs)
-        .build();
+  static WriteFilesPayload payloadForWriteFiles(
+      final WriteFiles<?, ?, ?> transform, SdkComponents components) throws 
IOException {
+    return payloadForWriteFilesLike(
+        new WriteFilesLike() {
+          @Override
+          public SdkFunctionSpec translateSink(SdkComponents newComponents) {
+            // TODO: register the environment
+            return toProto(transform.getSink());
+          }
+
+          @Override
+          public Map<String, SideInput> translateSideInputs(SdkComponents 
components) {
+            Map<String, SideInput> sideInputs = new HashMap<>();
+            for (PCollectionView<?> view :
+                transform.getSink().getDynamicDestinations().getSideInputs()) {
+              sideInputs.put(view.getTagInternal().getId(), 
ParDoTranslation.toProto(view));
+            }
+            return sideInputs;
+          }
+
+          @Override
+          public boolean isWindowedWrites() {
+            return transform.isWindowedWrites();
+          }
+
+          @Override
+          public boolean isRunnerDeterminedSharding() {
+            return transform.getNumShards() == null && transform.getSharding() 
== null;
+          }
+        },
+        components);
   }
 
   private static SdkFunctionSpec toProto(FileBasedSink<?, ?, ?> sink) {
@@ -174,8 +197,82 @@ public class WriteFilesTranslation {
             .getPayload());
   }
 
-  static class WriteFilesTranslator
-      extends TransformPayloadTranslator.WithDefaultRehydration<WriteFiles<?, 
?, ?>> {
+  static class RawWriteFiles extends 
PTransformTranslation.RawPTransform<PInput, POutput>
+      implements WriteFilesLike {
+
+    private final RunnerApi.PTransform protoTransform;
+    private final transient RehydratedComponents rehydratedComponents;
+
+    // Parsed from protoTransform and cached
+    private final FunctionSpec spec;
+    private final RunnerApi.WriteFilesPayload payload;
+
+    public RawWriteFiles(
+        RunnerApi.PTransform protoTransform, RehydratedComponents 
rehydratedComponents)
+        throws IOException {
+      this.rehydratedComponents = rehydratedComponents;
+      this.protoTransform = protoTransform;
+      this.spec = protoTransform.getSpec();
+      this.payload = RunnerApi.WriteFilesPayload.parseFrom(spec.getPayload());
+    }
+
+    @Override
+    public FunctionSpec getSpec() {
+      return spec;
+    }
+
+    @Override
+    public FunctionSpec migrate(SdkComponents components) throws IOException {
+      return FunctionSpec.newBuilder()
+          .setUrn(WRITE_FILES_TRANSFORM_URN)
+          .setPayload(payloadForWriteFilesLike(this, 
components).toByteString())
+          .build();
+    }
+
+    @Override
+    public Map<TupleTag<?>, PValue> getAdditionalInputs() {
+      Map<TupleTag<?>, PValue> additionalInputs = new HashMap<>();
+      for (Map.Entry<String, SideInput> sideInputEntry : 
payload.getSideInputsMap().entrySet()) {
+        try {
+          additionalInputs.put(
+              new TupleTag<>(sideInputEntry.getKey()),
+              rehydratedComponents.getPCollection(
+                  protoTransform.getInputsOrThrow(sideInputEntry.getKey())));
+        } catch (IOException exc) {
+          throw new IllegalStateException(
+              String.format(
+                  "Could not find input with name %s for %s transform",
+                  sideInputEntry.getKey(), WriteFiles.class.getSimpleName()));
+        }
+      }
+      return additionalInputs;
+    }
+
+    @Override
+    public SdkFunctionSpec translateSink(SdkComponents newComponents) {
+      // TODO: re-register the environment with the new components
+      return payload.getSink();
+    }
+
+    @Override
+    public Map<String, SideInput> translateSideInputs(SdkComponents 
components) {
+      // TODO: re-register the PCollections and UDF environments
+      return MoreObjects.firstNonNull(
+          payload.getSideInputsMap(), Collections.<String, 
SideInput>emptyMap());
+    }
+
+    @Override
+    public boolean isWindowedWrites() {
+      return payload.getWindowedWrites();
+    }
+
+    @Override
+    public boolean isRunnerDeterminedSharding() {
+      return payload.getRunnerDeterminedSharding();
+    }
+  }
+
+  static class WriteFilesTranslator implements 
TransformPayloadTranslator<WriteFiles<?, ?, ?>> {
     @Override
     public String getUrn(WriteFiles<?, ?, ?> transform) {
       return WRITE_FILES_TRANSFORM_URN;
@@ -183,14 +280,21 @@ public class WriteFilesTranslation {
 
     @Override
     public FunctionSpec translate(
-        AppliedPTransform<?, ?, WriteFiles<?, ?, ?>> transform, SdkComponents 
components) {
+        AppliedPTransform<?, ?, WriteFiles<?, ?, ?>> transform, SdkComponents 
components)
+        throws IOException {
       return FunctionSpec.newBuilder()
           .setUrn(getUrn(transform.getTransform()))
-          .setPayload(toProto(transform.getTransform()).toByteString())
+          .setPayload(payloadForWriteFiles(transform.getTransform(), 
components).toByteString())
           .build();
     }
-  }
 
+    @Override
+    public PTransformTranslation.RawPTransform<?, ?> rehydrate(
+        RunnerApi.PTransform protoTransform, RehydratedComponents 
rehydratedComponents)
+        throws IOException {
+      return new RawWriteFiles(protoTransform, rehydratedComponents);
+    }
+  }
   /** Registers {@link WriteFilesTranslator}. */
   @AutoService(TransformPayloadTranslatorRegistrar.class)
   public static class Registrar implements TransformPayloadTranslatorRegistrar 
{
@@ -202,8 +306,30 @@ public class WriteFilesTranslation {
     }
 
     @Override
-    public Map<String, TransformPayloadTranslator> getTransformRehydrators() {
-      return Collections.emptyMap();
+    public Map<String, ? extends TransformPayloadTranslator> 
getTransformRehydrators() {
+      return Collections.singletonMap(WRITE_FILES_TRANSFORM_URN, new 
WriteFilesTranslator());
     }
   }
+
+  /** These methods drive to-proto translation from Java and from rehydrated WriteFiles. */
+  private interface WriteFilesLike {
+    SdkFunctionSpec translateSink(SdkComponents newComponents);
+
+    Map<String, RunnerApi.SideInput> translateSideInputs(SdkComponents 
components);
+
+    boolean isWindowedWrites();
+
+    boolean isRunnerDeterminedSharding();
+  }
+
+  public static WriteFilesPayload payloadForWriteFilesLike(
+      WriteFilesLike writeFiles, SdkComponents components) throws IOException {
+
+    return WriteFilesPayload.newBuilder()
+        .setSink(writeFiles.translateSink(components))
+        .putAllSideInputs(writeFiles.translateSideInputs(components))
+        .setWindowedWrites(writeFiles.isWindowedWrites())
+        .setRunnerDeterminedSharding(writeFiles.isRunnerDeterminedSharding())
+        .build();
+  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/187beae4/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java
index c874828..4bc61d4 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java
@@ -76,7 +76,8 @@ public class WriteFilesTranslationTest {
 
     @Test
     public void testEncodedProto() throws Exception {
-      RunnerApi.WriteFilesPayload payload = 
WriteFilesTranslation.toProto(writeFiles);
+      RunnerApi.WriteFilesPayload payload =
+          WriteFilesTranslation.payloadForWriteFiles(writeFiles, 
SdkComponents.create());
 
       assertThat(
           payload.getRunnerDeterminedSharding(),

Reply via email to