Repository: beam
Updated Branches:
  refs/heads/master 8479425b6 -> ea7940d88


BEAM-1037 Support for new State API in ApexRunner


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/575e36e5
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/575e36e5
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/575e36e5

Branch: refs/heads/master
Commit: 575e36e5cac603ec8a02250e3e4e2dc58af21379
Parents: aa2604a
Author: Thomas Weise <[email protected]>
Authored: Fri Sep 22 19:11:56 2017 -0700
Committer: Thomas Weise <[email protected]>
Committed: Sat Sep 30 08:29:23 2017 -0700

----------------------------------------------------------------------
 runners/apex/pom.xml                            |  1 -
 .../apex/translation/ParDoTranslator.java       | 24 ++------------------
 .../operators/ApexParDoOperator.java            | 22 ++++++++++++++----
 .../translation/utils/ApexStateInternals.java   | 13 ++++-------
 .../FlattenPCollectionTranslatorTest.java       |  1 -
 .../apex/translation/ParDoTranslatorTest.java   |  4 +---
 6 files changed, 25 insertions(+), 40 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/575e36e5/runners/apex/pom.xml
----------------------------------------------------------------------
diff --git a/runners/apex/pom.xml b/runners/apex/pom.xml
index 11d2f5c..0011788 100644
--- a/runners/apex/pom.xml
+++ b/runners/apex/pom.xml
@@ -218,7 +218,6 @@
               <groups>org.apache.beam.sdk.testing.ValidatesRunner</groups>
               <excludedGroups>
                 org.apache.beam.sdk.testing.FlattenWithHeterogeneousCoders,
-                org.apache.beam.sdk.testing.UsesStatefulParDo,
                 org.apache.beam.sdk.testing.UsesTimersInParDo,
                 org.apache.beam.sdk.testing.UsesSplittableParDo,
                 org.apache.beam.sdk.testing.UsesAttemptedMetrics,

http://git-wip-us.apache.org/repos/asf/beam/blob/575e36e5/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java
----------------------------------------------------------------------
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java
index be11b02..dd4bd67 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java
@@ -31,14 +31,11 @@ import java.util.Map.Entry;
 import org.apache.beam.runners.apex.ApexRunner;
 import org.apache.beam.runners.apex.translation.operators.ApexParDoOperator;
 import 
org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems.ProcessElements;
-import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
-import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
-import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.PValue;
@@ -64,15 +61,6 @@ class ParDoTranslator<InputT, OutputT>
           String.format(
               "%s does not support splittable DoFn: %s", 
ApexRunner.class.getSimpleName(), doFn));
     }
-    if (signature.stateDeclarations().size() > 0) {
-      throw new UnsupportedOperationException(
-          String.format(
-              "Found %s annotations on %s, but %s cannot yet be used with 
state in the %s.",
-              DoFn.StateId.class.getSimpleName(),
-              doFn.getClass().getName(),
-              DoFn.class.getSimpleName(),
-              ApexRunner.class.getSimpleName()));
-    }
 
     if (signature.timerDeclarations().size() > 0) {
       throw new UnsupportedOperationException(
@@ -87,10 +75,6 @@ class ParDoTranslator<InputT, OutputT>
     Map<TupleTag<?>, PValue> outputs = context.getOutputs();
     PCollection<InputT> input = context.getInput();
     List<PCollectionView<?>> sideInputs = transform.getSideInputs();
-    Coder<InputT> inputCoder = input.getCoder();
-    WindowedValueCoder<InputT> wvInputCoder =
-        FullWindowedValueCoder.of(
-            inputCoder, 
input.getWindowingStrategy().getWindowFn().windowCoder());
 
     ApexParDoOperator<InputT, OutputT> operator = new ApexParDoOperator<>(
             context.getPipelineOptions(),
@@ -99,7 +83,7 @@ class ParDoTranslator<InputT, OutputT>
             transform.getAdditionalOutputTags().getAll(),
             input.getWindowingStrategy(),
             sideInputs,
-            wvInputCoder,
+            input.getCoder(),
             context.getStateBackend());
 
     Map<PCollection<?>, OutputPort<?>> ports = 
Maps.newHashMapWithExpectedSize(outputs.size());
@@ -144,10 +128,6 @@ class ParDoTranslator<InputT, OutputT>
       Map<TupleTag<?>, PValue> outputs = context.getOutputs();
       PCollection<InputT> input = context.getInput();
       List<PCollectionView<?>> sideInputs = transform.getSideInputs();
-      Coder<InputT> inputCoder = input.getCoder();
-      WindowedValueCoder<InputT> wvInputCoder =
-          FullWindowedValueCoder.of(
-              inputCoder, 
input.getWindowingStrategy().getWindowFn().windowCoder());
 
       @SuppressWarnings({ "rawtypes", "unchecked" })
       DoFn<InputT, OutputT> doFn = (DoFn) 
transform.newProcessFn(transform.getFn());
@@ -158,7 +138,7 @@ class ParDoTranslator<InputT, OutputT>
               transform.getAdditionalOutputTags().getAll(),
               input.getWindowingStrategy(),
               sideInputs,
-              wvInputCoder,
+              input.getCoder(),
               context.getStateBackend());
 
       Map<PCollection<?>, OutputPort<?>> ports = 
Maps.newHashMapWithExpectedSize(outputs.size());

http://git-wip-us.apache.org/repos/asf/beam/blob/575e36e5/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
----------------------------------------------------------------------
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
index 4dc807d..a66bb5b 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
@@ -73,11 +73,14 @@ import org.apache.beam.sdk.state.TimeDomain;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
 import org.apache.beam.sdk.util.UserCodeException;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
 import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollectionView;
@@ -133,7 +136,7 @@ public class ApexParDoOperator<InputT, OutputT> extends BaseOperator implements
       List<TupleTag<?>> additionalOutputTags,
       WindowingStrategy<?, ?> windowingStrategy,
       List<PCollectionView<?>> sideInputs,
-      Coder<WindowedValue<InputT>> inputCoder,
+      Coder<InputT> linputCoder,
       ApexStateBackend stateBackend
       ) {
     this.pipelineOptions = new SerializablePipelineOptions(pipelineOptions);
@@ -151,10 +154,13 @@ public class ApexParDoOperator<InputT, OutputT> extends BaseOperator implements
       throw new UnsupportedOperationException(msg);
     }
 
-    Coder<List<WindowedValue<InputT>>> listCoder = ListCoder.of(inputCoder);
+    WindowedValueCoder<InputT> wvCoder =
+        FullWindowedValueCoder.of(
+            linputCoder, this.windowingStrategy.getWindowFn().windowCoder());
+    Coder<List<WindowedValue<InputT>>> listCoder = ListCoder.of(wvCoder);
     this.pushedBack = new ValueAndCoderKryoSerializable<>(new 
ArrayList<WindowedValue<InputT>>(),
         listCoder);
-    this.inputCoder = inputCoder;
+    this.inputCoder = wvCoder;
 
     TimerInternals.TimerDataCoder timerCoder =
         
TimerInternals.TimerDataCoder.of(windowingStrategy.getWindowFn().windowCoder());
@@ -165,8 +171,16 @@ public class ApexParDoOperator<InputT, OutputT> extends BaseOperator implements
       Coder<?> keyCoder = StringUtf8Coder.of();
       this.currentKeyStateInternals = new StateInternalsProxy<>(
           stateBackend.newStateInternalsFactory(keyCoder));
+    } else {
+      DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
+      if (signature.usesState()) {
+        checkArgument(linputCoder instanceof KvCoder, "keyed input required 
for stateful DoFn");
+        @SuppressWarnings("rawtypes")
+        Coder<?> keyCoder = ((KvCoder) linputCoder).getKeyCoder();
+        this.currentKeyStateInternals = new StateInternalsProxy<>(
+            stateBackend.newStateInternalsFactory(keyCoder));
+      }
     }
-
   }
 
   @SuppressWarnings("unused") // for Kryo

http://git-wip-us.apache.org/repos/asf/beam/blob/575e36e5/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternals.java
----------------------------------------------------------------------
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternals.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternals.java
index e23601d..978a793 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternals.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternals.java
@@ -37,7 +37,6 @@ import org.apache.beam.runners.core.StateNamespace;
 import org.apache.beam.runners.core.StateTag;
 import org.apache.beam.runners.core.StateTag.StateBinder;
 import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.Coder.Context;
 import org.apache.beam.sdk.coders.CoderException;
 import org.apache.beam.sdk.coders.InstantCoder;
 import org.apache.beam.sdk.coders.ListCoder;
@@ -141,7 +140,6 @@ public class ApexStateInternals<K> implements StateInternals {
           namespace,
           address,
           accumCoder,
-          key,
           combineFn
           );
     }
@@ -184,7 +182,7 @@ public class ApexStateInternals<K> implements StateInternals {
         // TODO: reuse input
         Input input = new Input(buf);
         try {
-          return coder.decode(input, Context.OUTER);
+          return coder.decode(input);
         } catch (IOException e) {
           throw new RuntimeException(e);
         }
@@ -195,7 +193,7 @@ public class ApexStateInternals<K> implements StateInternals {
     public void writeValue(T input) {
       ByteArrayOutputStream output = new ByteArrayOutputStream();
       try {
-        coder.encode(input, output, Context.OUTER);
+        coder.encode(input, output);
         stateTable.put(namespace.stringKey(), address.getId(), 
output.toByteArray());
       } catch (IOException e) {
         throw new RuntimeException(e);
@@ -306,15 +304,13 @@ public class ApexStateInternals<K> implements StateInternals {
   private final class ApexCombiningState<K, InputT, AccumT, OutputT>
       extends AbstractState<AccumT>
       implements CombiningState<InputT, AccumT, OutputT> {
-    private final K key;
     private final CombineFn<InputT, AccumT, OutputT> combineFn;
 
     private ApexCombiningState(StateNamespace namespace,
         StateTag<CombiningState<InputT, AccumT, OutputT>> address,
         Coder<AccumT> coder,
-        K key, CombineFn<InputT, AccumT, OutputT> combineFn) {
+        CombineFn<InputT, AccumT, OutputT> combineFn) {
       super(namespace, address, coder);
-      this.key = key;
       this.combineFn = combineFn;
     }
 
@@ -330,8 +326,7 @@ public class ApexStateInternals<K> implements StateInternals {
 
     @Override
     public void add(InputT input) {
-      AccumT accum = getAccum();
-      combineFn.addInput(accum, input);
+      AccumT accum = combineFn.addInput(getAccum(), input);
       writeValue(accum);
     }
 

http://git-wip-us.apache.org/repos/asf/beam/blob/575e36e5/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/FlattenPCollectionTranslatorTest.java
----------------------------------------------------------------------
diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/FlattenPCollectionTranslatorTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/FlattenPCollectionTranslatorTest.java
index 929778a..1ad9622 100644
--- a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/FlattenPCollectionTranslatorTest.java
+++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/FlattenPCollectionTranslatorTest.java
@@ -53,7 +53,6 @@ public class FlattenPCollectionTranslatorTest {
   @Test
   public void test() throws Exception {
     ApexPipelineOptions options = 
PipelineOptionsFactory.as(ApexPipelineOptions.class);
-    options.setApplicationName("FlattenPCollection");
     options.setRunner(ApexRunner.class);
     Pipeline p = Pipeline.create(options);
 

http://git-wip-us.apache.org/repos/asf/beam/blob/575e36e5/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java
----------------------------------------------------------------------
diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java
index 736b0e7..73382e3 100644
--- a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java
+++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java
@@ -42,7 +42,6 @@ import org.apache.beam.runners.apex.translation.operators.ApexReadUnboundedInput
 import org.apache.beam.runners.apex.translation.utils.ApexStateInternals;
 import org.apache.beam.runners.apex.translation.utils.ApexStreamTuple;
 import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.coders.VoidCoder;
@@ -202,7 +201,6 @@ public class ParDoTranslatorTest {
         .as(ApexPipelineOptions.class);
     options.setRunner(TestApexRunner.class);
     Pipeline pipeline = Pipeline.create(options);
-    Coder<WindowedValue<Integer>> coder = 
WindowedValue.getValueOnlyCoder(VarIntCoder.of());
 
     PCollectionView<Integer> singletonView = pipeline.apply(Create.of(1))
             .apply(Sum.integersGlobally().asSingletonView());
@@ -215,7 +213,7 @@ public class ParDoTranslatorTest {
             TupleTagList.empty().getAll(),
             WindowingStrategy.globalDefault(),
             Collections.<PCollectionView<?>>singletonList(singletonView),
-            coder,
+            VarIntCoder.of(),
             new ApexStateInternals.ApexStateBackend());
     operator.setup(null);
     operator.beginWindow(0);

Reply via email to