BEAM-261 Support multiple side inputs.

Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/1ec7cd91
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/1ec7cd91
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/1ec7cd91

Branch: refs/heads/apex-runner
Commit: 1ec7cd9129fc31ece7554e2ea18535ce15e46bcf
Parents: fd7f46c
Author: Thomas Weise <t...@apache.org>
Authored: Thu Oct 13 14:38:06 2016 -0700
Committer: Thomas Weise <t...@apache.org>
Committed: Mon Oct 17 09:22:49 2016 -0700

----------------------------------------------------------------------
 .../runners/apex/ApexPipelineTranslator.java    | 19 ++++++-
 .../apache/beam/runners/apex/ApexRunner.java    |  7 ++-
 .../beam/runners/apex/ApexRunnerResult.java     |  7 +++
 .../FlattenPCollectionTranslator.java           | 38 +++++++++++---
 .../translators/ParDoBoundMultiTranslator.java  | 55 +++++++++++++++++---
 .../apex/translators/ParDoBoundTranslator.java  | 14 +----
 .../functions/ApexFlattenOperator.java          | 11 ++++
 .../functions/ApexParDoOperator.java            | 13 +++--
 .../apex/translators/utils/ApexStreamTuple.java | 22 ++++++--
 9 files changed, 148 insertions(+), 38 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1ec7cd91/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexPipelineTranslator.java
----------------------------------------------------------------------
diff --git 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexPipelineTranslator.java
 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexPipelineTranslator.java
index 40edfb1..a16f551 100644
--- 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexPipelineTranslator.java
+++ 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexPipelineTranslator.java
@@ -37,6 +37,7 @@ import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.View.CreatePCollectionView;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.PValue;
 import org.slf4j.Logger;
@@ -74,7 +75,8 @@ public class ApexPipelineTranslator implements 
Pipeline.PipelineVisitor {
     registerTransformTranslator(Flatten.FlattenPCollectionList.class,
         new FlattenPCollectionTranslator());
     registerTransformTranslator(Create.Values.class, new 
CreateValuesTranslator());
-    registerTransformTranslator(CreateApexPCollectionView.class, new 
CreatePCollectionViewTranslator());
+    registerTransformTranslator(CreateApexPCollectionView.class, new 
CreateApexPCollectionViewTranslator());
+    registerTransformTranslator(CreatePCollectionView.class, new 
CreatePCollectionViewTranslator());
   }
 
   public ApexPipelineTranslator(TranslationContext translationContext) {
@@ -151,7 +153,7 @@ public class ApexPipelineTranslator implements 
Pipeline.PipelineVisitor {
 
   }
 
-  private static class CreatePCollectionViewTranslator<ElemT, ViewT> 
implements TransformTranslator<CreateApexPCollectionView<ElemT, ViewT>>
+  private static class CreateApexPCollectionViewTranslator<ElemT, ViewT> 
implements TransformTranslator<CreateApexPCollectionView<ElemT, ViewT>>
   {
     private static final long serialVersionUID = 1L;
 
@@ -164,4 +166,17 @@ public class ApexPipelineTranslator implements 
Pipeline.PipelineVisitor {
     }
   }
 
+  private static class CreatePCollectionViewTranslator<ElemT, ViewT> 
implements TransformTranslator<CreatePCollectionView<ElemT, ViewT>>
+  {
+    private static final long serialVersionUID = 1L;
+
+    @Override
+    public void translate(CreatePCollectionView<ElemT, ViewT> transform, 
TranslationContext context)
+    {
+      PCollectionView<ViewT> view = transform.getView();
+      context.addView(view);
+      LOG.debug("view {}", view.getName());
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1ec7cd91/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java
----------------------------------------------------------------------
diff --git 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java
index ad49f08..667f1c8 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java
@@ -74,7 +74,7 @@ public class ApexRunner extends 
PipelineRunner<ApexRunnerResult> {
   /**
    * TODO: this isn't thread safe and may cause issues when tests run in 
parallel
    * Holds any most resent assertion error that was raised while processing 
elements.
-   * Used in the unit test driver in embedded to propagate the exception.
+   * Used in the unit test driver in embedded mode to propagate the exception.
    */
   public static volatile AssertionError assertionError;
 
@@ -100,6 +100,8 @@ public class ApexRunner extends 
PipelineRunner<ApexRunnerResult> {
               WindowingStrategy.globalDefault(),
               PCollection.IsBounded.BOUNDED);
 // TODO: replace this with a mapping
+////
+
     } else if 
(Combine.GloballyAsSingletonView.class.equals(transform.getClass())) {
       PTransform<InputT, OutputT> customTransform = (PTransform)new 
StreamingCombineGloballyAsSingletonView<InputT, OutputT>(this,
           (Combine.GloballyAsSingletonView)transform);
@@ -109,6 +111,7 @@ public class ApexRunner extends 
PipelineRunner<ApexRunnerResult> {
       PTransform<InputT, OutputT> customTransform = (PTransform)new 
StreamingViewAsSingleton<InputT>(this,
           (View.AsSingleton)transform);
       return Pipeline.applyTransform(input, customTransform);
+/*
     } else if (View.AsIterable.class.equals(transform.getClass())) {
       PTransform<InputT, OutputT> customTransform = (PTransform)new 
StreamingViewAsIterable<InputT>(this,
           (View.AsIterable)transform);
@@ -125,6 +128,8 @@ public class ApexRunner extends 
PipelineRunner<ApexRunnerResult> {
       PTransform<InputT, OutputT> customTransform = new 
StreamingViewAsMultimap(this,
           (View.AsMultimap)transform);
       return Pipeline.applyTransform(input, customTransform);
+*/
+////
     } else {
       return super.apply(transform, input);
     }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1ec7cd91/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunnerResult.java
----------------------------------------------------------------------
diff --git 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunnerResult.java 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunnerResult.java
index f28c8dc..6817684 100644
--- 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunnerResult.java
+++ 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunnerResult.java
@@ -19,6 +19,7 @@ package org.apache.beam.runners.apex;
 
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.PipelineResult;
+import org.apache.beam.sdk.metrics.MetricResults;
 
 import java.io.IOException;
 
@@ -74,6 +75,12 @@ public class ApexRunnerResult implements PipelineResult {
     throw new UnsupportedOperationException();
   }
 
+  @Override
+  public MetricResults metrics()
+  {
+    throw new UnsupportedOperationException();
+  }
+
   /**
    * Return the DAG executed by the pipeline.
    * @return

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1ec7cd91/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/FlattenPCollectionTranslator.java
----------------------------------------------------------------------
diff --git 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/FlattenPCollectionTranslator.java
 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/FlattenPCollectionTranslator.java
index 90ab81f..6737767 100644
--- 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/FlattenPCollectionTranslator.java
+++ 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/FlattenPCollectionTranslator.java
@@ -20,6 +20,7 @@ package org.apache.beam.runners.apex.translators;
 
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 
 import org.apache.beam.runners.apex.translators.functions.ApexFlattenOperator;
 import 
org.apache.beam.runners.apex.translators.io.ApexReadUnboundedInputOperator;
@@ -53,9 +54,25 @@ public class FlattenPCollectionTranslator<T> implements
       ApexReadUnboundedInputOperator<T, ?> operator = new 
ApexReadUnboundedInputOperator<>(
           unboundedSource, context.getPipelineOptions());
       context.addOperator(operator, operator.output);
-      return;
+    } else {
+      PCollection<T> output = context.getOutput();
+      Map<PCollection<?>, Integer> unionTags = Collections.emptyMap();
+      flattenCollections(collections, unionTags, output, context);
     }
+  }
 
+  /**
+   * Flatten the given collections into the given result collection. Translates
+   * into a cascading merge with 2 input ports per operator. The optional union
+   * tags can be used to identify the source in the result stream, used to
+   * channel multiple side inputs to a single Apex operator port.
+   *
+   * @param collections
+   * @param unionTags
+   * @param finalCollection
+   * @param context
+   */
+  static <T> void flattenCollections(List<PCollection<T>> collections, 
Map<PCollection<?>, Integer> unionTags, PCollection<T> finalCollection, 
TranslationContext context) {
     List<PCollection<T>> remainingCollections = Lists.newArrayList();
     PCollection<T> firstCollection = null;
     while (!collections.isEmpty()) {
@@ -65,14 +82,23 @@ public class FlattenPCollectionTranslator<T> implements
         } else {
           ApexFlattenOperator<T> operator = new ApexFlattenOperator<>();
           context.addStream(firstCollection, operator.data1);
+          Integer unionTag = unionTags.get(firstCollection);
+          operator.data1Tag = (unionTag != null) ? unionTag : 0;
           context.addStream(collection, operator.data2);
+          unionTag = unionTags.get(collection);
+          operator.data2Tag = (unionTag != null) ? unionTag : 0;
+
+          if (!collection.getCoder().equals(firstCollection.getCoder())) {
+              throw new UnsupportedOperationException("coders don't match");
+          }
+
           if (collections.size() > 2) {
-            PCollection<T> resultCollection = 
intermediateCollection(collection, collection.getCoder());
-            context.addOperator(operator, operator.out, resultCollection);
-            remainingCollections.add(resultCollection);
+            PCollection<T> intermediateCollection = 
intermediateCollection(collection, collection.getCoder());
+            context.addOperator(operator, operator.out, 
intermediateCollection);
+            remainingCollections.add(intermediateCollection);
           } else {
             // final stream merge
-            context.addOperator(operator, operator.out);
+            context.addOperator(operator, operator.out, finalCollection);
           }
           firstCollection = null;
         }
@@ -91,7 +117,7 @@ public class FlattenPCollectionTranslator<T> implements
     }
   }
 
-  public static <T> PCollection<T> intermediateCollection(PCollection<T> 
input, Coder<T> outputCoder) {
+  static <T> PCollection<T> intermediateCollection(PCollection<T> input, 
Coder<T> outputCoder) {
     PCollection<T> output = 
PCollection.createPrimitiveOutputInternal(input.getPipeline(), 
input.getWindowingStrategy(), input.isBounded());
     output.setCoder(outputCoder);
     return output;

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1ec7cd91/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundMultiTranslator.java
----------------------------------------------------------------------
diff --git 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundMultiTranslator.java
 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundMultiTranslator.java
index 9c5f2b5..a229a81 100644
--- 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundMultiTranslator.java
+++ 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundMultiTranslator.java
@@ -18,6 +18,10 @@
 
 package org.apache.beam.runners.apex.translators;
 
+import static com.google.common.base.Preconditions.checkArgument;
+
+import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
@@ -63,20 +67,55 @@ public class ParDoBoundMultiTranslator<InputT, OutputT> 
implements TransformTran
       ports.put(outputEntry.getValue(), operator.sideOutputPorts[i++]);
     }
     context.addOperator(operator, ports);
-
     context.addStream(context.getInput(), operator.input);
     if (!sideInputs.isEmpty()) {
-      Operator.InputPort<?>[] sideInputPorts = {operator.sideInput1};
-      for (i=0; i<sideInputs.size(); i++) {
+      addSideInputs(operator, sideInputs, context);
+    }
+  }
+
+  static void addSideInputs(ApexParDoOperator<?, ?> operator, 
List<PCollectionView<?>> sideInputs, TranslationContext context) {
+    Operator.InputPort<?>[] sideInputPorts = {operator.sideInput1};
+    if (sideInputs.size() > sideInputPorts.length) {
+      //  String msg = String.format("Too many side inputs in %s (currently 
only supporting %s).",
+      //      transform.toString(), sideInputPorts.length);
+      //  throw new UnsupportedOperationException(msg);
+      PCollection<?> unionCollection = unionSideInputs(sideInputs, context);
+      context.addStream(unionCollection, sideInputPorts[0]);
+    } else {
+      for (int i=0; i<sideInputs.size(); i++) {
         // the number of input ports for side inputs are fixed and each port 
can only take one input.
         // more (optional) ports can be added to give reasonable capacity or 
an explicit union operation introduced.
-        if (i == sideInputPorts.length) {
-          String msg = String.format("Too many side inputs in %s (currently 
only supporting %s).",
-              transform.toString(), sideInputPorts.length);
-          throw new UnsupportedOperationException(msg);
-        }
         context.addStream(context.getViewInput(sideInputs.get(i)), 
sideInputPorts[i]);
       }
     }
   }
+
+  private static PCollection<?> unionSideInputs(List<PCollectionView<?>> 
sideInputs, TranslationContext context) {
+    checkArgument(sideInputs.size() > 1, "requires multiple side inputs");
+    // flatten and assign union tag
+    List<PCollection<Object>> sourceCollections = new ArrayList<>();
+    Map<PCollection<?>, Integer> unionTags = new HashMap<>();
+    PCollection<Object> firstSideInput = 
context.getViewInput(sideInputs.get(0));
+    for (int i=0; i < sideInputs.size(); i++) {
+      PCollectionView<?> sideInput = sideInputs.get(i);
+      PCollection<?> sideInputCollection = context.getViewInput(sideInput);
+      if 
(!sideInputCollection.getWindowingStrategy().equals(firstSideInput.getWindowingStrategy()))
 {
+        // TODO: check how to handle this in stream codec
+        //String msg = "Multiple side inputs with different window 
strategies.";
+        //throw new UnsupportedOperationException(msg);
+      }
+      if (!sideInputCollection.getCoder().equals(firstSideInput.getCoder())) {
+        String msg = "Multiple side inputs with different coders.";
+        throw new UnsupportedOperationException(msg);
+      }
+      
sourceCollections.add(context.<PCollection<Object>>getViewInput(sideInput));
+      unionTags.put(sideInputCollection, i);
+    }
+
+    PCollection<Object> resultCollection = 
FlattenPCollectionTranslator.intermediateCollection(firstSideInput, 
firstSideInput.getCoder());
+    FlattenPCollectionTranslator.flattenCollections(sourceCollections, 
unionTags, resultCollection, context);
+    return resultCollection;
+
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1ec7cd91/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java
----------------------------------------------------------------------
diff --git 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java
 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java
index 8a7dd4b..7749a06 100644
--- 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java
+++ 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java
@@ -32,8 +32,6 @@ import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TupleTagList;
 
-import com.datatorrent.api.Operator;
-
 /**
  * {@link ParDo.Bound} is translated to Apex operator that wraps the {@link 
DoFn}
  */
@@ -57,17 +55,7 @@ public class ParDoBoundTranslator<InputT, OutputT> implements
     context.addOperator(operator, operator.output);
     context.addStream(context.getInput(), operator.input);
     if (!sideInputs.isEmpty()) {
-      Operator.InputPort<?>[] sideInputPorts = {operator.sideInput1};
-      for (int i=0; i<sideInputs.size(); i++) {
-        // the number of input ports for side inputs are fixed and each port 
can only take one input.
-        // more (optional) ports can be added to give reasonable capacity or 
an explicit union operation introduced.
-        if (i == sideInputPorts.length) {
-          String msg = String.format("Too many side inputs in %s (currently 
only supporting %s).",
-              transform.toString(), sideInputPorts.length);
-          throw new UnsupportedOperationException(msg);
-        }
-        context.addStream(context.getViewInput(sideInputs.get(i)), 
sideInputPorts[i]);
-      }
+       ParDoBoundMultiTranslator.addSideInputs(operator, sideInputs, context);
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1ec7cd91/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexFlattenOperator.java
----------------------------------------------------------------------
diff --git 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexFlattenOperator.java
 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexFlattenOperator.java
index 4675244..202f2d3 100644
--- 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexFlattenOperator.java
+++ 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexFlattenOperator.java
@@ -41,6 +41,9 @@ public class ApexFlattenOperator<InputT> extends BaseOperator
   private long inputWM2;
   private long outputWM;
 
+  public int data1Tag;
+  public int data2Tag;
+
   /**
    * Data input port 1.
    */
@@ -70,6 +73,10 @@ public class ApexFlattenOperator<InputT> extends BaseOperator
       if (traceTuples) {
         LOG.debug("\nemitting {}\n", tuple);
       }
+
+      if (data1Tag > 0 && tuple instanceof ApexStreamTuple.DataTuple) {
+        ((ApexStreamTuple.DataTuple<?>)tuple).setUnionTag(data1Tag);
+      }
       out.emit(tuple);
     }
   };
@@ -103,6 +110,10 @@ public class ApexFlattenOperator<InputT> extends 
BaseOperator
       if (traceTuples) {
         LOG.debug("\nemitting {}\n", tuple);
       }
+
+      if (data2Tag > 0 && tuple instanceof ApexStreamTuple.DataTuple) {
+        ((ApexStreamTuple.DataTuple<?>)tuple).setUnionTag(data2Tag);
+      }
       out.emit(tuple);
     }
   };

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1ec7cd91/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java
----------------------------------------------------------------------
diff --git 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java
 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java
index a951ca7..96be11d 100644
--- 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java
+++ 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java
@@ -34,7 +34,6 @@ import org.apache.beam.runners.core.SideInputHandler;
 import org.apache.beam.runners.core.DoFnRunners.OutputManager;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.ListCoder;
-import org.apache.beam.sdk.repackaged.com.google.common.base.Throwables;
 import org.apache.beam.sdk.transforms.Aggregator;
 import org.apache.beam.sdk.transforms.Aggregator.AggregatorFactory;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
@@ -60,6 +59,7 @@ import 
com.datatorrent.api.annotation.InputPortFieldAnnotation;
 import com.datatorrent.api.annotation.OutputPortFieldAnnotation;
 import com.datatorrent.common.util.BaseOperator;
 import com.esotericsoftware.kryo.serializers.FieldSerializer.Bind;
+import com.google.common.base.Throwables;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Maps;
 import com.esotericsoftware.kryo.serializers.JavaSerializer;
@@ -158,8 +158,6 @@ private transient StateInternals<Void> 
sideInputStateInternals = InMemoryStateIn
   @InputPortFieldAnnotation(optional=true)
   public final transient 
DefaultInputPort<ApexStreamTuple<WindowedValue<Iterable<?>>>> sideInput1 = new 
DefaultInputPort<ApexStreamTuple<WindowedValue<Iterable<?>>>>()
   {
-    private final int sideInputIndex = 0;
-
     @Override
     public void process(ApexStreamTuple<WindowedValue<Iterable<?>>> t)
     {
@@ -167,9 +165,16 @@ private transient StateInternals<Void> 
sideInputStateInternals = InMemoryStateIn
         // ignore side input watermarks
         return;
       }
+
+      int sideInputIndex = 0;
+      if (t instanceof ApexStreamTuple.DataTuple) {
+        sideInputIndex = ((ApexStreamTuple.DataTuple<?>)t).getUnionTag();
+      }
+
       if (traceTuples) {
-        LOG.debug("\nsideInput {}\n", t.getValue());
+        LOG.debug("\nsideInput {} {}\n", sideInputIndex, t.getValue());
       }
+
       PCollectionView<?> sideInput = sideInputs.get(sideInputIndex);
       sideInputHandler.addSideInputValue(sideInput, t.getValue());
 

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1ec7cd91/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/utils/ApexStreamTuple.java
----------------------------------------------------------------------
diff --git 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/utils/ApexStreamTuple.java
 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/utils/ApexStreamTuple.java
index 06940aa..c9bf6dc 100644
--- 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/utils/ApexStreamTuple.java
+++ 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/utils/ApexStreamTuple.java
@@ -50,15 +50,17 @@ public interface ApexStreamTuple<T>
    */
   class DataTuple<T> implements ApexStreamTuple<T>
   {
+    private int unionTag;
     private T value;
 
     public static <T> DataTuple<T> of(T value) {
-      return new DataTuple<>(value);
+      return new DataTuple<>(value, 0);
     }
 
-    private DataTuple(T value)
+    private DataTuple(T value, int unionTag)
     {
       this.value = value;
+      this.unionTag = unionTag;
     }
 
     @Override
@@ -72,6 +74,16 @@ public interface ApexStreamTuple<T>
       this.value = value;
     }
 
+    public int getUnionTag()
+    {
+      return unionTag;
+    }
+
+    public void setUnionTag(int unionTag)
+    {
+      this.unionTag = unionTag;
+    }
+
     @Override
     public String toString()
     {
@@ -91,7 +103,7 @@ public interface ApexStreamTuple<T>
 
     public TimestampedTuple(long timestamp, T value)
     {
-      super(value);
+      super(value, 0);
       this.timestamp = timestamp;
     }
 
@@ -152,6 +164,7 @@ public interface ApexStreamTuple<T>
         new 
DataOutputStream(outStream).writeLong(((WatermarkTuple<?>)value).getTimestamp());
       } else {
         outStream.write(0);
+        outStream.write(((DataTuple<?>)value).unionTag);
         valueCoder.encode(value.getValue(), outStream, context);
       }
     }
@@ -164,7 +177,8 @@ public interface ApexStreamTuple<T>
       if (b == 1) {
         return new WatermarkTuple<T>(new DataInputStream(inStream).readLong());
       } else {
-        return new DataTuple<T>(valueCoder.decode(inStream, context));
+        int unionTag = inStream.read();
+        return new DataTuple<T>(valueCoder.decode(inStream, context), 
unionTag);
       }
     }
 

Reply via email to