[ https://issues.apache.org/jira/browse/BEAM-3515?focusedWorklogId=101409&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-101409 ]

ASF GitHub Bot logged work on BEAM-3515:
----------------------------------------

                Author: ASF GitHub Bot
            Created on: 12/May/18 07:21
            Start Date: 12/May/18 07:21
    Worklog Time Spent: 10m 
      Work Description: jkff closed pull request #5277: [BEAM-3515] Portable 
translation of SplittableProcessKeyed
URL: https://github.com/apache/beam/pull/5277
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), GitHub does not retain the
original diff after merge, so it is reproduced below in full:

diff --git a/model/pipeline/src/main/proto/beam_runner_api.proto 
b/model/pipeline/src/main/proto/beam_runner_api.proto
index 33448316921..11646239b93 100644
--- a/model/pipeline/src/main/proto/beam_runner_api.proto
+++ b/model/pipeline/src/main/proto/beam_runner_api.proto
@@ -240,15 +240,18 @@ message StandardPTransforms {
     // Less well-known. Payload: WriteFilesPayload.
     WRITE_FILES = 4 [(beam_urn) = "beam:transform:write_files:v1"];
   }
+  // Payload for all of these: CombinePayload
   enum CombineComponents {
     COMBINE_PGBKCV = 0 [(beam_urn) = "beam:transform:combine_pgbkcv:v1"];
     COMBINE_MERGE_ACCUMULATORS = 1 [(beam_urn) = 
"beam:transform:combine_merge_accumulators:v1"];
     COMBINE_EXTRACT_OUTPUTS = 2 [(beam_urn) = 
"beam:transform:combine_extract_outputs:v1"];
   }
-
-  // This field is needed only as a work-around for a proto compiler bug.
-  // See https://github.com/google/protobuf/issues/4514
-  int32 ignored = 1;
+  // Payload for all of these: ParDoPayload containing the user's SDF
+  enum SplittableParDoComponents {
+    PAIR_WITH_RESTRICTION = 0 [(beam_urn) = 
"beam:transform:sdf_pair_with_restriction:v1"];
+    SPLIT_RESTRICTION = 1 [(beam_urn) = 
"beam:transform:sdf_split_restriction:v1"];
+    PROCESS_KEYED_ELEMENTS = 2 [(beam_urn) = 
"beam:transform:sdf_process_keyed_elements:v1"];
+  }
 }
 
 message StandardSideInputTypes {
@@ -350,6 +353,9 @@ message ParDoPayload {
 
   // Whether the DoFn is splittable
   bool splittable = 6;
+
+  // (Required if splittable == true) Id of the restriction coder.
+  string restriction_coder_id = 7;
 }
 
 // Parameters that a UDF might require.
diff --git 
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
index 2eb27873105..37886605296 100644
--- 
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
+++ 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
@@ -37,6 +37,7 @@
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
 import org.apache.beam.model.pipeline.v1.RunnerApi.StandardPTransforms;
+import 
org.apache.beam.model.pipeline.v1.RunnerApi.StandardPTransforms.SplittableParDoComponents;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -82,6 +83,8 @@
       StandardPTransforms.Composites.RESHUFFLE);
   public static final String WRITE_FILES_TRANSFORM_URN =
       getUrn(StandardPTransforms.Composites.WRITE_FILES);
+  public static final String SPLITTABLE_PROCESS_KEYED_URN =
+      getUrn(SplittableParDoComponents.PROCESS_KEYED_ELEMENTS);
 
   private static final Map<Class<? extends PTransform>, 
TransformPayloadTranslator>
       KNOWN_PAYLOAD_TRANSLATORS = loadTransformPayloadTranslators();
diff --git 
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
index 6365d77dccc..d5478498798 100644
--- 
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
+++ 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
@@ -48,6 +48,7 @@
 import org.apache.beam.model.pipeline.v1.RunnerApi.SideInput;
 import org.apache.beam.model.pipeline.v1.RunnerApi.SideInput.Builder;
 import 
org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator;
+import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.state.StateSpec;
@@ -60,6 +61,7 @@
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.ParDo.MultiOutput;
 import org.apache.beam.sdk.transforms.ViewFn;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.Cases;
@@ -105,7 +107,8 @@ public String getUrn(ParDo.MultiOutput<?, ?> transform) {
     public FunctionSpec translate(
         AppliedPTransform<?, ?, MultiOutput<?, ?>> transform, SdkComponents 
components)
         throws IOException {
-      ParDoPayload payload = translateParDo(transform.getTransform(), 
components);
+      ParDoPayload payload =
+          translateParDo(transform.getTransform(), transform.getPipeline(), 
components);
       return RunnerApi.FunctionSpec.newBuilder()
           .setUrn(PAR_DO_TRANSFORM_URN)
           .setPayload(payload.toByteString())
@@ -136,10 +139,19 @@ public FunctionSpec translate(
   }
 
   public static ParDoPayload translateParDo(
-      final ParDo.MultiOutput<?, ?> parDo, SdkComponents components) throws 
IOException {
+      final ParDo.MultiOutput<?, ?> parDo, Pipeline pipeline, SdkComponents 
components)
+      throws IOException {
 
     final DoFn<?, ?> doFn = parDo.getFn();
     final DoFnSignature signature = 
DoFnSignatures.getSignature(doFn.getClass());
+    final String restrictionCoderId;
+    if (signature.processElement().isSplittable()) {
+      final Coder<?> restrictionCoder =
+          
DoFnInvokers.invokerFor(doFn).invokeGetRestrictionCoder(pipeline.getCoderRegistry());
+      restrictionCoderId = components.registerCoder(restrictionCoder);
+    } else {
+      restrictionCoderId = "";
+    }
 
     return payloadForParDoLike(
         new ParDoLike() {
@@ -151,14 +163,8 @@ public SdkFunctionSpec translateDoFn(SdkComponents 
newComponents) {
 
           @Override
           public List<RunnerApi.Parameter> translateParameters() {
-            List<RunnerApi.Parameter> parameters = new ArrayList<>();
-            for (Parameter parameter : 
signature.processElement().extraParameters()) {
-              RunnerApi.Parameter protoParameter = 
translateParameter(parameter);
-              if (protoParameter != null) {
-                parameters.add(protoParameter);
-              }
-            }
-            return parameters;
+            return ParDoTranslation.translateParameters(
+                signature.processElement().extraParameters());
           }
 
           @Override
@@ -200,10 +206,26 @@ public SdkFunctionSpec translateDoFn(SdkComponents 
newComponents) {
           public boolean isSplittable() {
             return signature.processElement().isSplittable();
           }
+
+          @Override
+          public String translateRestrictionCoderId(SdkComponents 
newComponents) {
+            return restrictionCoderId;
+          }
         },
         components);
   }
 
+  public static List<RunnerApi.Parameter> translateParameters(List<Parameter> 
params) {
+    List<RunnerApi.Parameter> parameters = new ArrayList<>();
+    for (Parameter parameter : params) {
+      RunnerApi.Parameter protoParameter = translateParameter(parameter);
+      if (protoParameter != null) {
+        parameters.add(protoParameter);
+      }
+    }
+    return parameters;
+  }
+
   public static DoFn<?, ?> getDoFn(ParDoPayload payload) throws 
InvalidProtocolBufferException {
     return doFnAndMainOutputTagFromProto(payload.getDoFn()).getDoFn();
   }
@@ -442,8 +464,7 @@ public static SdkFunctionSpec translateDoFn(
         .build();
   }
 
-  private static DoFnAndMainOutput 
doFnAndMainOutputTagFromProto(SdkFunctionSpec fnSpec)
-      throws InvalidProtocolBufferException {
+  public static DoFnAndMainOutput 
doFnAndMainOutputTagFromProto(SdkFunctionSpec fnSpec) {
     checkArgument(
         fnSpec.getSpec().getUrn().equals(CUSTOM_JAVA_DO_FN_URN),
         "Expected %s to be %s with URN %s, but URN was %s",
@@ -489,6 +510,17 @@ private static DoFnAndMainOutput 
doFnAndMainOutputTagFromProto(SdkFunctionSpec f
         });
   }
 
+  public static Map<String, SideInput> translateSideInputs(
+      List<PCollectionView<?>> views, SdkComponents components) {
+    Map<String, SideInput> sideInputs = new HashMap<>();
+    for (PCollectionView<?> sideInput : views) {
+      sideInputs.put(
+          sideInput.getTagInternal().getId(),
+          ParDoTranslation.translateView(sideInput, components));
+    }
+    return sideInputs;
+  }
+
   public static SideInput translateView(PCollectionView<?> view, SdkComponents 
components) {
     Builder builder = SideInput.newBuilder();
     builder.setAccessPattern(
@@ -631,6 +663,11 @@ public SdkFunctionSpec translateDoFn(SdkComponents 
newComponents) {
     public boolean isSplittable() {
       return payload.getSplittable();
     }
+
+    @Override
+    public String translateRestrictionCoderId(SdkComponents newComponents) {
+      return payload.getRestrictionCoderId();
+    }
   }
 
   /** These methods drive to-proto translation from Java and from rehydrated 
ParDos. */
@@ -647,6 +684,8 @@ public boolean isSplittable() {
     Map<String, RunnerApi.TimerSpec> translateTimerSpecs(SdkComponents 
newComponents);
 
     boolean isSplittable();
+
+    String translateRestrictionCoderId(SdkComponents newComponents);
   }
 
   public static ParDoPayload payloadForParDoLike(ParDoLike parDo, 
SdkComponents components)
@@ -659,6 +698,7 @@ public static ParDoPayload payloadForParDoLike(ParDoLike 
parDo, SdkComponents co
         .putAllTimerSpecs(parDo.translateTimerSpecs(components))
         .putAllSideInputs(parDo.translateSideInputs(components))
         .setSplittable(parDo.isSplittable())
+        .setRestrictionCoderId(parDo.translateRestrictionCoderId(components))
         .build();
   }
 }
diff --git 
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java
 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java
index 8254620f265..8bd39602363 100644
--- 
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java
+++ 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java
@@ -19,14 +19,27 @@
 
 import static com.google.common.base.Preconditions.checkArgument;
 
+import com.google.auto.service.AutoService;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Maps;
 import java.io.IOException;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.UUID;
 import javax.annotation.Nullable;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
-import 
org.apache.beam.runners.core.construction.PTransformTranslation.RawPTransform;
+import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
+import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
+import org.apache.beam.model.pipeline.v1.RunnerApi.Parameter;
+import org.apache.beam.model.pipeline.v1.RunnerApi.SdkFunctionSpec;
+import org.apache.beam.model.pipeline.v1.RunnerApi.SideInput;
+import org.apache.beam.model.pipeline.v1.RunnerApi.StateSpec;
+import org.apache.beam.model.pipeline.v1.RunnerApi.TimerSpec;
+import 
org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator;
+import org.apache.beam.runners.core.construction.ParDoTranslation.ParDoLike;
+import 
org.apache.beam.runners.core.construction.ReadTranslation.BoundedReadPayloadTranslator;
+import 
org.apache.beam.runners.core.construction.ReadTranslation.UnboundedReadPayloadTranslator;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
@@ -82,9 +95,6 @@
   public static final String SPLITTABLE_PROCESS_URN =
       "urn:beam:runners_core:transforms:splittable_process:v1";
 
-  public static final String SPLITTABLE_PROCESS_KEYED_ELEMENTS_URN =
-      "urn:beam:runners_core:transforms:splittable_process_keyed_elements:v1";
-
   public static final String SPLITTABLE_GBKIKWI_URN =
       "urn:beam:runners_core:transforms:splittable_gbkikwi:v1";
 
@@ -188,7 +198,7 @@ public void process(ProcessContext c, BoundedWindow window) 
{
    * {@link KV KVs} keyed with arbitrary but globally unique keys.
    */
   public static class ProcessKeyedElements<InputT, OutputT, RestrictionT>
-      extends RawPTransform<PCollection<KV<String, KV<InputT, RestrictionT>>>, 
PCollectionTuple> {
+      extends PTransform<PCollection<KV<String, KV<InputT, RestrictionT>>>, 
PCollectionTuple> {
     private final DoFn<InputT, OutputT> fn;
     private final Coder<InputT> elementCoder;
     private final Coder<RestrictionT> restrictionCoder;
@@ -290,16 +300,93 @@ public PCollectionTuple expand(PCollection<KV<String, 
KV<InputT, RestrictionT>>>
     public Map<TupleTag<?>, PValue> getAdditionalInputs() {
       return PCollectionViews.toAdditionalInputs(sideInputs);
     }
+  }
+
+  /** Registers {@link UnboundedReadPayloadTranslator} and {@link 
BoundedReadPayloadTranslator}. */
+  @AutoService(TransformPayloadTranslatorRegistrar.class)
+  public static class Registrar implements TransformPayloadTranslatorRegistrar 
{
+    @Override
+    public Map<? extends Class<? extends PTransform>, ? extends 
TransformPayloadTranslator>
+    getTransformPayloadTranslators() {
+      return ImmutableMap.<Class<? extends PTransform>, 
TransformPayloadTranslator>builder()
+          .put(ProcessKeyedElements.class, new 
ProcessKeyedElementsTranslator())
+          .build();
+    }
+
+    @Override
+    public Map<String, TransformPayloadTranslator> getTransformRehydrators() {
+      return Collections.emptyMap();
+    }
+  }
+
+  /** A translator for {@link ProcessKeyedElements}. */
+  public static class ProcessKeyedElementsTranslator extends
+      PTransformTranslation.TransformPayloadTranslator.WithDefaultRehydration<
+          ProcessKeyedElements<?, ?, ?>> {
+
+    public static TransformPayloadTranslator create() {
+      return new ProcessKeyedElementsTranslator();
+    }
+
+    private ProcessKeyedElementsTranslator() {
+    }
 
     @Override
-    public String getUrn() {
-      return SPLITTABLE_PROCESS_KEYED_ELEMENTS_URN;
+    public String getUrn(ProcessKeyedElements<?, ?, ?> transform) {
+      return PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN;
     }
 
-    @Nullable
     @Override
-    public RunnerApi.FunctionSpec getSpec() {
-      return null;
+    public FunctionSpec translate(
+        AppliedPTransform<?, ?, ProcessKeyedElements<?, ?, ?>> transform,
+        SdkComponents components) throws IOException {
+      ProcessKeyedElements<?, ?, ?> pke = transform.getTransform();
+      final DoFn<?, ?> fn = pke.getFn();
+      final DoFnSignature signature = 
DoFnSignatures.getSignature(fn.getClass());
+      final String restrictionCoderId = 
components.registerCoder(pke.getRestrictionCoder());
+
+      ParDoPayload payload = ParDoTranslation.payloadForParDoLike(new 
ParDoLike() {
+        @Override
+        public SdkFunctionSpec translateDoFn(SdkComponents newComponents) {
+          return ParDoTranslation.translateDoFn(fn, pke.getMainOutputTag(), 
newComponents);
+        }
+
+        @Override
+        public List<Parameter> translateParameters() {
+          return 
ParDoTranslation.translateParameters(signature.processElement().extraParameters());
+        }
+
+        @Override
+        public Map<String, SideInput> translateSideInputs(SdkComponents 
components) {
+          return ParDoTranslation.translateSideInputs(pke.getSideInputs(), 
components);
+        }
+
+        @Override
+        public Map<String, StateSpec> translateStateSpecs(SdkComponents 
components) {
+          // SDFs don't have state.
+          return ImmutableMap.of();
+        }
+
+        @Override
+        public Map<String, TimerSpec> translateTimerSpecs(SdkComponents 
components) {
+          // SDFs don't have timers.
+          return ImmutableMap.of();
+        }
+
+        @Override
+        public boolean isSplittable() {
+          return true;
+        }
+
+        @Override
+        public String translateRestrictionCoderId(SdkComponents newComponents) 
{
+          return restrictionCoderId;
+        }
+      }, components);
+      return RunnerApi.FunctionSpec.newBuilder()
+          .setUrn(getUrn(pke))
+          .setPayload(payload.toByteString())
+          .build();
     }
   }
 
diff --git 
a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/EnvironmentsTest.java
 
b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/EnvironmentsTest.java
index 34751935987..bb596920992 100644
--- 
a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/EnvironmentsTest.java
+++ 
b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/EnvironmentsTest.java
@@ -32,6 +32,7 @@
 import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
 import org.apache.beam.model.pipeline.v1.RunnerApi.ReadPayload;
 import org.apache.beam.model.pipeline.v1.RunnerApi.WindowIntoPayload;
+import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.io.CountingSource;
 import org.apache.beam.sdk.io.Read;
@@ -80,6 +81,7 @@ public void getEnvironmentParDo() throws IOException {
                       public void process(ProcessContext ctxt) {}
                     })
                 .withOutputTags(new TupleTag<>(), TupleTagList.empty()),
+            Pipeline.create(),
             components);
     RehydratedComponents rehydratedComponents =
         RehydratedComponents.forComponents(components.toComponents());
diff --git 
a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java
 
b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java
index e45e4db8b50..aa5a1d5e116 100644
--- 
a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java
+++ 
b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java
@@ -118,7 +118,7 @@
     @Test
     public void testToAndFromProto() throws Exception {
       SdkComponents components = SdkComponents.create();
-      ParDoPayload payload = ParDoTranslation.translateParDo(parDo, 
components);
+      ParDoPayload payload = ParDoTranslation.translateParDo(parDo, p, 
components);
 
       assertThat(ParDoTranslation.getDoFn(payload), 
Matchers.equalTo(parDo.getFn()));
       assertThat(
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
index 6329c55d813..b1f4d804420 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
@@ -267,7 +267,7 @@ public DirectPipelineResult run(Pipeline originalPipeline) {
             .add(
                 PTransformOverride.of(
                     PTransformMatchers.urnEqualTo(
-                        SplittableParDo.SPLITTABLE_PROCESS_KEYED_ELEMENTS_URN),
+                        PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN),
                     new SplittableParDoViaKeyedWorkItems.OverrideFactory()))
             .add(
                 PTransformOverride.of(
diff --git 
a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkTransformOverrides.java
 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkTransformOverrides.java
index 9baef8f492a..09f3f96d8f2 100644
--- 
a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkTransformOverrides.java
+++ 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkTransformOverrides.java
@@ -22,7 +22,6 @@
 import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems;
 import org.apache.beam.runners.core.construction.PTransformMatchers;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
-import org.apache.beam.runners.core.construction.SplittableParDo;
 import org.apache.beam.sdk.runners.PTransformOverride;
 import org.apache.beam.sdk.transforms.PTransform;
 
@@ -40,7 +39,7 @@
           .add(
               PTransformOverride.of(
                   PTransformMatchers.urnEqualTo(
-                      SplittableParDo.SPLITTABLE_PROCESS_KEYED_ELEMENTS_URN),
+                      PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN),
                   new SplittableParDoViaKeyedWorkItems.OverrideFactory()))
           .add(
               PTransformOverride.of(
diff --git 
a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java
 
b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java
index 4b5b5cdcfae..6613173233e 100644
--- 
a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java
+++ 
b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java
@@ -97,7 +97,7 @@ public void sdkErrorsSurfaceOnClose() throws Exception {
 
     @SuppressWarnings("unchecked")
     RemoteBundle<Integer> bundle = Mockito.mock(RemoteBundle.class);
-    when(stageBundleFactory.<Integer>getBundle(any(), 
any())).thenReturn(bundle);
+    when(stageBundleFactory.getBundle(any(), any())).thenReturn(bundle);
 
     @SuppressWarnings("unchecked")
     FnDataReceiver<WindowedValue<Integer>> receiver = 
Mockito.mock(FnDataReceiver.class);
@@ -126,7 +126,7 @@ public void expectedInputsAreSent() throws Exception {
 
     @SuppressWarnings("unchecked")
     RemoteBundle<Integer> bundle = Mockito.mock(RemoteBundle.class);
-    when(stageBundleFactory.<Integer>getBundle(any(), 
any())).thenReturn(bundle);
+    when(stageBundleFactory.getBundle(any(), any())).thenReturn(bundle);
 
     @SuppressWarnings("unchecked")
     FnDataReceiver<WindowedValue<Integer>> receiver = 
Mockito.mock(FnDataReceiver.class);
@@ -155,26 +155,20 @@ public void outputsAreTaggedCorrectly() throws Exception {
             "three", 3);
 
     // We use a real StageBundleFactory here in order to exercise the output 
receiver factory.
-    StageBundleFactory stageBundleFactory =
-        new StageBundleFactory() {
+    StageBundleFactory<Void> stageBundleFactory =
+        new StageBundleFactory<Void>() {
           @Override
-          public <InputT> RemoteBundle<InputT> getBundle(
-              OutputReceiverFactory receiverFactory, StateRequestHandler 
stateRequestHandler)
-              throws Exception {
-            return new RemoteBundle<InputT>() {
+          public RemoteBundle<Void> getBundle(
+              OutputReceiverFactory receiverFactory, StateRequestHandler 
stateRequestHandler) {
+            return new RemoteBundle<Void>() {
               @Override
               public String getId() {
                 return "bundle-id";
               }
 
               @Override
-              public FnDataReceiver<WindowedValue<InputT>> getInputReceiver() {
-                return new FnDataReceiver<WindowedValue<InputT>>() {
-                  @Override
-                  public void accept(WindowedValue<InputT> input) throws 
Exception {
-                    // Ignore input
-                  }
-                };
+              public FnDataReceiver<WindowedValue<Void>> getInputReceiver() {
+                return input -> {/* Ignore input*/};
               }
 
               @Override
diff --git 
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
 
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
index 9d7068f6058..fd6c6411788 100644
--- 
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
+++ 
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
@@ -967,17 +967,19 @@ public void translate(
 
             translateInputs(
                 stepContext, context.getInput(transform), 
transform.getSideInputs(), context);
-                translateOutputs(context.getOutputs(transform), stepContext);
-            stepContext.addInput(
-                PropertyNames.SERIALIZED_FN,
-                byteArrayToJsonString(
-                    serializeToByteArray(
-                        DoFnInfo.forFn(
-                            transform.getFn(),
-                            transform.getInputWindowingStrategy(),
-                            transform.getSideInputs(),
-                            transform.getElementCoder(),
-                            transform.getMainOutputTag()))));
+            translateOutputs(context.getOutputs(transform), stepContext);
+            String ptransformId =
+                
context.getSdkComponents().getPTransformIdOrThrow(context.getCurrentTransform());
+            translateFn(
+                stepContext,
+                ptransformId,
+                transform.getFn(),
+                transform.getInputWindowingStrategy(),
+                transform.getSideInputs(),
+                transform.getElementCoder(),
+                context,
+                transform.getMainOutputTag());
+
             stepContext.addInput(
                 PropertyNames.RESTRICTION_CODER,
                 CloudObjects.asCloudObject(transform.getRestrictionCoder()));
@@ -1021,13 +1023,6 @@ private static void translateFn(
       TupleTag<?> mainOutput) {
 
     DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass());
-    if (signature.processElement().isSplittable()) {
-      throw new UnsupportedOperationException(
-          String.format(
-              "%s does not currently support splittable DoFn: %s",
-              DataflowRunner.class.getSimpleName(),
-              fn));
-    }
 
     if (signature.usesState() || signature.usesTimers()) {
       DataflowRunner.verifyStateSupported(fn);
@@ -1036,12 +1031,9 @@ private static void translateFn(
 
     stepContext.addInput(PropertyNames.USER_FN, fn.getClass().getName());
 
-    List<String> experiments = context.getPipelineOptions().getExperiments();
-    boolean isFnApi = experiments != null && 
experiments.contains("beam_fn_api");
-
     // Fn API does not need the additional metadata in the wrapper, and it is 
Java-only serializable
     // hence not suitable for portable execution
-    if (isFnApi) {
+    if (context.isFnApi()) {
       stepContext.addInput(PropertyNames.SERIALIZED_FN, ptransformId);
     } else {
       stepContext.addInput(
diff --git 
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java
 
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java
index e7f6e6d4c12..89beafe79c2 100644
--- 
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java
+++ 
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java
@@ -18,6 +18,7 @@
 
 package org.apache.beam.runners.dataflow;
 
+import static com.google.common.base.Preconditions.checkArgument;
 import static 
org.apache.beam.runners.core.construction.PTransformTranslation.PAR_DO_TRANSFORM_URN;
 import static 
org.apache.beam.runners.core.construction.ParDoTranslation.translateTimerSpec;
 import static 
org.apache.beam.sdk.transforms.reflect.DoFnSignatures.getStateSpecOrThrow;
@@ -26,7 +27,6 @@
 import com.google.auto.service.AutoService;
 import com.google.common.collect.Iterables;
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -160,10 +160,15 @@ public String getUrn(ParDoSingle<?, ?> transform) {
     }
 
     private static RunnerApi.ParDoPayload payloadForParDoSingle(
-        final ParDoSingle<?, ?> parDo, SdkComponents components) throws 
IOException {
-
+        final ParDoSingle<?, ?> parDo, SdkComponents components)
+        throws IOException {
       final DoFn<?, ?> doFn = parDo.getFn();
       final DoFnSignature signature = 
DoFnSignatures.getSignature(doFn.getClass());
+      checkArgument(
+          !signature.processElement().isSplittable(),
+          String.format(
+              "Not expecting a splittable %s: should have been overridden",
+              ParDoSingle.class.getSimpleName()));
 
       return ParDoTranslation.payloadForParDoLike(
           new ParDoTranslation.ParDoLike() {
@@ -175,26 +180,13 @@ public String getUrn(ParDoSingle<?, ?> transform) {
 
             @Override
             public List<RunnerApi.Parameter> translateParameters() {
-              List<RunnerApi.Parameter> parameters = new ArrayList<>();
-              for (DoFnSignature.Parameter parameter :
-                  signature.processElement().extraParameters()) {
-                RunnerApi.Parameter protoParameter = 
ParDoTranslation.translateParameter(parameter);
-                if (protoParameter != null) {
-                  parameters.add(protoParameter);
-                }
-              }
-              return parameters;
+              return ParDoTranslation.translateParameters(
+                  signature.processElement().extraParameters());
             }
 
             @Override
             public Map<String, RunnerApi.SideInput> 
translateSideInputs(SdkComponents components) {
-              Map<String, RunnerApi.SideInput> sideInputs = new HashMap<>();
-              for (PCollectionView<?> sideInput : parDo.getSideInputs()) {
-                sideInputs.put(
-                    sideInput.getTagInternal().getId(),
-                    ParDoTranslation.translateView(sideInput, components));
-              }
-              return sideInputs;
+              return 
ParDoTranslation.translateSideInputs(parDo.getSideInputs(), components);
             }
 
             @Override
@@ -226,7 +218,12 @@ public String getUrn(ParDoSingle<?, ?> transform) {
 
             @Override
             public boolean isSplittable() {
-              return signature.processElement().isSplittable();
+              return false;
+            }
+
+            @Override
+            public String translateRestrictionCoderId(SdkComponents 
newComponents) {
+              return "";
             }
           },
           components);
diff --git 
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java
 
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java
index 3ea97b535f5..3c6789e9b26 100644
--- 
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java
+++ 
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java
@@ -47,6 +47,11 @@
    * including reading and writing the values of {@link PCollection}s and side 
inputs.
    */
   interface TranslationContext {
+    default boolean isFnApi() {
+      List<String> experiments = getPipelineOptions().getExperiments();
+      return (experiments != null && experiments.contains("beam_fn_api"));
+    }
+
     /** Returns the configured pipeline options. */
     DataflowPipelineOptions getPipelineOptions();
 
diff --git 
a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
 
b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
index ed6eb7c6118..19f81954bcf 100644
--- 
a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
+++ 
b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
@@ -45,6 +45,7 @@
 import java.io.IOException;
 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
@@ -52,17 +53,23 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
+import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.ParDoTranslation;
+import org.apache.beam.runners.core.construction.RehydratedComponents;
 import 
org.apache.beam.runners.dataflow.DataflowPipelineTranslator.JobSpecification;
 import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions;
 import 
org.apache.beam.runners.dataflow.options.DataflowPipelineWorkerPoolOptions;
 import org.apache.beam.runners.dataflow.util.CloudObject;
 import org.apache.beam.runners.dataflow.util.CloudObjects;
-import org.apache.beam.runners.dataflow.util.OutputReference;
 import org.apache.beam.runners.dataflow.util.PropertyNames;
 import org.apache.beam.runners.dataflow.util.Structs;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.SerializableCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.extensions.gcp.auth.TestCredential;
@@ -137,9 +144,6 @@ private Pipeline buildPipeline(DataflowPipelineOptions 
options) {
     options.setRunner(DataflowRunner.class);
     Pipeline p = Pipeline.create(options);
 
-    // Enable the FileSystems API to know about gs:// URIs in this test.
-    FileSystems.setDefaultPipelineOptions(options);
-
     p.apply("ReadMyFile", TextIO.read().from("gs://bucket/object"))
         .apply("WriteMyFile", TextIO.write().to("gs://bucket/object"));
     DataflowRunner runner = DataflowRunner.fromOptions(options);
@@ -183,6 +187,10 @@ private static DataflowPipelineOptions 
buildPipelineOptions() throws IOException
     options.setFilesToStage(new LinkedList<>());
     options.setDataflowClient(buildMockDataflow(new IsValidCreateRequest()));
     options.setGcsUtil(mockGcsUtil);
+
+    // Enable the FileSystems API to know about gs:// URIs in this test.
+    FileSystems.setDefaultPipelineOptions(options);
+
     return options;
   }
 
@@ -411,47 +419,6 @@ public void testDiskSizeGbConfig() throws IOException {
         job.getEnvironment().getWorkerPools().get(0).getDiskSizeGb());
   }
 
-  /**
-   * Construct a OutputReference for the output of the step.
-   */
-  private static OutputReference getOutputPortReference(Step step) throws 
Exception {
-    // TODO: This should be done via a Structs accessor.
-    @SuppressWarnings("unchecked")
-    List<Map<String, Object>> output =
-        (List<Map<String, Object>>) 
step.getProperties().get(PropertyNames.OUTPUT_INFO);
-    String outputTagId = getString(Iterables.getOnlyElement(output), 
PropertyNames.OUTPUT_NAME);
-    return new OutputReference(step.getName(), outputTagId);
-  }
-
-  /**
-   * Returns a Step for a {@link DoFn} by creating and translating a pipeline.
-   */
-  private static Step createPredefinedStep() throws Exception {
-    DataflowPipelineOptions options = buildPipelineOptions();
-    DataflowPipelineTranslator translator = 
DataflowPipelineTranslator.fromOptions(options);
-    Pipeline pipeline = Pipeline.create(options);
-    String stepName = "DoFn1";
-    pipeline.apply("ReadMyFile", TextIO.read().from("gs://bucket/in"))
-        .apply(stepName, ParDo.of(new NoOpFn()))
-        .apply("WriteMyFile", TextIO.write().to("gs://bucket/out"));
-    DataflowRunner runner = DataflowRunner.fromOptions(options);
-    runner.replaceTransforms(pipeline);
-    Job job = translator.translate(pipeline, runner, 
Collections.emptyList()).getJob();
-
-    assertEquals(8, job.getSteps().size());
-    Step step = job.getSteps().get(1);
-    assertEquals(stepName, getString(step.getProperties(), 
PropertyNames.USER_NAME));
-    assertAllStepOutputsHaveUniqueIds(job);
-    return step;
-  }
-
-  private static class NoOpFn extends DoFn<String, String> {
-    @ProcessElement
-    public void processElement(ProcessContext c) throws Exception {
-      c.output(c.element());
-    }
-  }
-
   /**
    * A composite transform that returns an output that is unrelated to
    * the input.
@@ -799,6 +766,7 @@ public void testStreamingSplittableParDoTranslation() 
throws Exception {
     assertThat(
         fnInfo.getWindowingStrategy().getWindowFn(),
         
Matchers.<WindowFn>equalTo(FixedWindows.of(Duration.standardMinutes(1))));
+    assertThat(fnInfo.getInputCoder(), instanceOf(StringUtf8Coder.class));
     Coder<?> restrictionCoder =
         CloudObjects.coderFromCloudObject(
             (CloudObject)
@@ -808,6 +776,72 @@ public void testStreamingSplittableParDoTranslation() 
throws Exception {
     assertEquals(SerializableCoder.of(OffsetRange.class), restrictionCoder);
   }
 
+  /**
+   * Smoke test to fail fast if translation of a splittable ParDo
+   * in streaming breaks.
+   */
+  @Test
+  public void testStreamingSplittableParDoTranslationFnApi() throws Exception {
+    DataflowPipelineOptions options = buildPipelineOptions();
+    DataflowRunner runner = DataflowRunner.fromOptions(options);
+    options.setStreaming(true);
+    options.setExperiments(Arrays.asList("beam_fn_api"));
+    DataflowPipelineTranslator translator = 
DataflowPipelineTranslator.fromOptions(options);
+
+    Pipeline pipeline = Pipeline.create(options);
+
+    PCollection<String> windowedInput =
+        pipeline
+            .apply(Create.of("a"))
+            .apply(Window.into(FixedWindows.of(Duration.standardMinutes(1))));
+    windowedInput.apply(ParDo.of(new TestSplittableFn()));
+
+    runner.replaceTransforms(pipeline);
+
+    JobSpecification result = translator.translate(pipeline, runner, 
Collections.emptyList());
+
+    Job job = result.getJob();
+
+    // The job should contain a SplittableParDo.ProcessKeyedElements step, 
translated as
+    // "SplittableProcessKeyed".
+
+    List<Step> steps = job.getSteps();
+    Step processKeyedStep = null;
+    for (Step step : steps) {
+      if ("SplittableProcessKeyed".equals(step.getKind())) {
+        assertNull(processKeyedStep);
+        processKeyedStep = step;
+      }
+    }
+    assertNotNull(processKeyedStep);
+
+    String fn = Structs.getString(processKeyedStep.getProperties(), 
PropertyNames.SERIALIZED_FN);
+
+    Components componentsProto = result.getPipelineProto().getComponents();
+    RehydratedComponents components = RehydratedComponents
+        .forComponents(componentsProto);
+    RunnerApi.PTransform spkTransform = componentsProto
+        .getTransformsOrThrow(fn);
+    assertEquals(PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN,
+        spkTransform.getSpec().getUrn());
+    ParDoPayload payload = 
ParDoPayload.parseFrom(spkTransform.getSpec().getPayload());
+    assertThat(
+        
ParDoTranslation.doFnAndMainOutputTagFromProto(payload.getDoFn()).getDoFn(),
+        instanceOf(TestSplittableFn.class));
+    assertThat(
+        components.getCoder(payload.getRestrictionCoderId()), 
instanceOf(SerializableCoder.class));
+
+    // In the Fn API case, we still translate the restriction coder into the 
RESTRICTION_CODER
+    // property as a CloudObject, and it gets passed through the Dataflow 
backend, but in the end
+    // the Dataflow worker will end up fetching it from the SPK transform 
payload instead.
+    Coder<?> restrictionCoder =
+        CloudObjects.coderFromCloudObject(
+            (CloudObject)
+                Structs.getObject(
+                    processKeyedStep.getProperties(), 
PropertyNames.RESTRICTION_CODER));
+    assertEquals(SerializableCoder.of(OffsetRange.class), restrictionCoder);
+  }
+
   @Test
   public void testToSingletonTranslationWithIsmSideInput() throws Exception {
     // A "change detector" test that makes sure the translation
diff --git 
a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java
 
b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java
index d07250ca8e4..56ba5dcf641 100644
--- 
a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java
+++ 
b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java
@@ -50,6 +50,7 @@
 import org.apache.beam.sdk.fn.data.LogicalEndpoint;
 import org.apache.beam.sdk.fn.test.TestStreams;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -62,6 +63,7 @@
   private static final Coder<WindowedValue<String>> CODER =
       
LengthPrefixCoder.of(WindowedValue.getValueOnlyCoder(StringUtf8Coder.of()));
 
+  @Ignore("https://issues.apache.org/jira/browse/BEAM-4281")
   @Test
   public void testMessageReceivedBySingleClientWhenThereAreMultipleClients()
       throws Exception {


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


Issue Time Tracking
-------------------

    Worklog Id:     (was: 101409)
    Time Spent: 5h 10m  (was: 5h)

> Use portable ParDoPayload for SDF in DataflowRunner
> ---------------------------------------------------
>
>                 Key: BEAM-3515
>                 URL: https://issues.apache.org/jira/browse/BEAM-3515
>             Project: Beam
>          Issue Type: Sub-task
>          Components: runner-dataflow
>            Reporter: Kenneth Knowles
>            Assignee: Eugene Kirpichov
>            Priority: Major
>              Labels: portability
>          Time Spent: 5h 10m
>  Remaining Estimate: 0h
>
> The Java-specific blobs transmitted to Dataflow need more context, in the 
> form of portability framework protos.



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to