[
https://issues.apache.org/jira/browse/BEAM-4076?focusedWorklogId=126962&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-126962
]
ASF GitHub Bot logged work on BEAM-4076:
----------------------------------------
Author: ASF GitHub Bot
Created on: 24/Jul/18 23:45
Start Date: 24/Jul/18 23:45
Worklog Time Spent: 10m
Work Description: reuvenlax closed pull request #5955: [BEAM-4076] Enable
schemas for more runners
URL: https://github.com/apache/beam/pull/5955
This is a pull request merged from a forked repository. Because GitHub
hides the original diff of a foreign (forked) pull request once it is
merged, the full diff is reproduced below for the sake of provenance:
diff --git a/runners/apex/build.gradle b/runners/apex/build.gradle
index dbe19e8efa4..a2bfdec355a 100644
--- a/runners/apex/build.gradle
+++ b/runners/apex/build.gradle
@@ -93,7 +93,6 @@ task validatesRunnerBatch(type: Test) {
excludeCategories 'org.apache.beam.sdk.testing.UsesCommittedMetrics'
excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse'
excludeCategories 'org.apache.beam.sdk.testing.UsesParDoLifecycle'
- excludeCategories 'org.apache.beam.sdk.testing.UsesSchema'
excludeCategories 'org.apache.beam.sdk.testing.UsesTestStream'
excludeCategories 'org.apache.beam.sdk.testing.UsesTimersInParDo'
excludeCategories 'org.apache.beam.sdk.testing.UsesMetricsPusher'
diff --git
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java
b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java
index 32113a97630..d44d18c849c 100644
---
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java
+++
b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java
@@ -28,9 +28,11 @@
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
+import java.util.stream.Collectors;
import org.apache.beam.runners.apex.ApexRunner;
import org.apache.beam.runners.apex.translation.operators.ApexParDoOperator;
import
org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems.ProcessElements;
+import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
@@ -76,6 +78,13 @@ public void translate(ParDo.MultiOutput<InputT, OutputT>
transform, TranslationC
PCollection<InputT> input = context.getInput();
List<PCollectionView<?>> sideInputs = transform.getSideInputs();
+ Map<TupleTag<?>, Coder<?>> outputCoders =
+ outputs
+ .entrySet()
+ .stream()
+ .filter(e -> e.getValue() instanceof PCollection)
+ .collect(
+ Collectors.toMap(e -> e.getKey(), e -> ((PCollection)
e.getValue()).getCoder()));
ApexParDoOperator<InputT, OutputT> operator =
new ApexParDoOperator<>(
context.getPipelineOptions(),
@@ -85,6 +94,7 @@ public void translate(ParDo.MultiOutput<InputT, OutputT>
transform, TranslationC
input.getWindowingStrategy(),
sideInputs,
input.getCoder(),
+ outputCoders,
context.getStateBackend());
Map<PCollection<?>, OutputPort<?>> ports =
Maps.newHashMapWithExpectedSize(outputs.size());
@@ -130,6 +140,14 @@ public void translate(
PCollection<InputT> input = context.getInput();
List<PCollectionView<?>> sideInputs = transform.getSideInputs();
+ Map<TupleTag<?>, Coder<?>> outputCoders =
+ outputs
+ .entrySet()
+ .stream()
+ .filter(e -> e.getValue() instanceof PCollection)
+ .collect(
+ Collectors.toMap(e -> e.getKey(), e -> ((PCollection)
e.getValue()).getCoder()));
+
@SuppressWarnings({"rawtypes", "unchecked"})
DoFn<InputT, OutputT> doFn = (DoFn)
transform.newProcessFn(transform.getFn());
ApexParDoOperator<InputT, OutputT> operator =
@@ -140,7 +158,8 @@ public void translate(
transform.getAdditionalOutputTags().getAll(),
input.getWindowingStrategy(),
sideInputs,
- null,
+ input.getCoder(),
+ outputCoders,
context.getStateBackend());
Map<PCollection<?>, OutputPort<?>> ports =
Maps.newHashMapWithExpectedSize(outputs.size());
diff --git
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
index f9d20520e2b..577835238e4 100644
---
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
+++
b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
@@ -117,7 +117,13 @@
private final List<PCollectionView<?>> sideInputs;
@Bind(JavaSerializer.class)
- private final Coder<WindowedValue<InputT>> inputCoder;
+ private final Coder<WindowedValue<InputT>> windowedInputCoder;
+
+ @Bind(JavaSerializer.class)
+ private final Coder<InputT> inputCoder;
+
+ @Bind(JavaSerializer.class)
+ private final Map<TupleTag<?>, Coder<?>> outputCoders;
private StateInternalsProxy<?> currentKeyStateInternals;
private final ApexTimerInternals<Object> currentKeyTimerInternals;
@@ -142,6 +148,7 @@ public ApexParDoOperator(
WindowingStrategy<?, ?> windowingStrategy,
List<PCollectionView<?>> sideInputs,
Coder<InputT> inputCoder,
+ Map<TupleTag<?>, Coder<?>> outputCoders,
ApexStateBackend stateBackend) {
this.pipelineOptions = new SerializablePipelineOptions(pipelineOptions);
this.doFn = doFn;
@@ -164,7 +171,9 @@ public ApexParDoOperator(
FullWindowedValueCoder.of(inputCoder,
this.windowingStrategy.getWindowFn().windowCoder());
Coder<List<WindowedValue<InputT>>> listCoder = ListCoder.of(wvCoder);
this.pushedBack = new ValueAndCoderKryoSerializable<>(new ArrayList<>(),
listCoder);
- this.inputCoder = wvCoder;
+ this.windowedInputCoder = wvCoder;
+ this.inputCoder = inputCoder;
+ this.outputCoders = outputCoders;
TimerInternals.TimerDataCoder timerCoder =
TimerInternals.TimerDataCoder.of(windowingStrategy.getWindowFn().windowCoder());
@@ -197,7 +206,9 @@ private ApexParDoOperator() {
this.sideInputs = null;
this.pushedBack = null;
this.sideInputStateInternals = null;
+ this.windowedInputCoder = null;
this.inputCoder = null;
+ this.outputCoders = Collections.emptyMap();
this.currentKeyTimerInternals = null;
}
@@ -310,7 +321,7 @@ public void
process(ApexStreamTuple<WindowedValue<Iterable<?>>> t) {
final Object key;
final Coder<Object> keyCoder;
@SuppressWarnings({"rawtypes", "unchecked"})
- WindowedValueCoder<InputT> wvCoder = (WindowedValueCoder) inputCoder;
+ WindowedValueCoder<InputT> wvCoder = (WindowedValueCoder)
windowedInputCoder;
if (value instanceof KeyedWorkItem) {
key = ((KeyedWorkItem) value).key();
@SuppressWarnings({"rawtypes", "unchecked"})
@@ -453,8 +464,8 @@ public TimerInternals timerInternals() {
mainOutputTag,
additionalOutputTags,
stepContext,
- null,
- Collections.emptyMap(),
+ inputCoder,
+ outputCoders,
windowingStrategy);
doFnInvoker = DoFnInvokers.invokerFor(doFn);
diff --git
a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java
b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java
index 3bb0bc9dd1b..a2642b4a58f 100644
---
a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java
+++
b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java
@@ -212,6 +212,7 @@ public void testSerialization() throws Exception {
WindowingStrategy.globalDefault(),
Collections.singletonList(singletonView),
VarIntCoder.of(),
+ Collections.emptyMap(),
new ApexStateInternals.ApexStateBackend());
operator.setup(null);
operator.beginWindow(0);
diff --git
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
index efb0c96e440..26e5656a545 100644
---
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
+++
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
@@ -34,6 +34,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
@@ -300,6 +301,15 @@ public static TupleTagList
getAdditionalOutputTags(AppliedPTransform<?, ?, ?> ap
return TupleTagList.of(additionalOutputTags);
}
+ public static Map<TupleTag<?>, Coder<?>>
getOutputCoders(AppliedPTransform<?, ?, ?> application) {
+ return application
+ .getOutputs()
+ .entrySet()
+ .stream()
+ .filter(e -> e.getValue() instanceof PCollection)
+ .collect(Collectors.toMap(e -> e.getKey(), e -> ((PCollection)
e.getValue()).getCoder()));
+ }
+
public static List<PCollectionView<?>> getSideInputs(AppliedPTransform<?, ?,
?> application)
throws IOException {
PTransform<?, ?> transform = application.getTransform();
diff --git a/runners/flink/build.gradle b/runners/flink/build.gradle
index 0a80ad4e04d..e19529c2a16 100644
--- a/runners/flink/build.gradle
+++ b/runners/flink/build.gradle
@@ -113,7 +113,6 @@ def createValidatesRunnerTask(Map m) {
} else {
excludeCategories
'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
}
- excludeCategories 'org.apache.beam.sdk.testing.UsesSchema'
excludeCategories 'org.apache.beam.sdk.testing.UsesTestStream'
}
}
diff --git
a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
index 78e0f39efb4..08e520c97b2 100644
---
a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
+++
b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
@@ -571,12 +571,11 @@ public void translateNode(
throw new RuntimeException(e);
}
+ Map<TupleTag<?>, Coder<?>> outputCoderMap = context.getOutputCoders();
+
String fullName = getCurrentTransformName(context);
if (usesStateOrTimers) {
- // Based on the fact that the signature is stateful, DoFnSignatures
ensures
- // that it is also keyed
- KvCoder<?, InputT> inputCoder = (KvCoder<?, InputT>)
context.getInput(transform).getCoder();
-
+ KvCoder<?, ?> inputCoder = (KvCoder<?, ?>)
context.getInput(transform).getCoder();
FlinkStatefulDoFnFunction<?, ?, OutputT> doFnWrapper =
new FlinkStatefulDoFnFunction<>(
(DoFn) doFn,
@@ -585,8 +584,12 @@ public void translateNode(
sideInputStrategies,
context.getPipelineOptions(),
outputMap,
- (TupleTag<OutputT>) mainOutputTag);
+ (TupleTag<OutputT>) mainOutputTag,
+ inputCoder,
+ outputCoderMap);
+ // Based on the fact that the signature is stateful, DoFnSignatures
ensures
+ // that it is also keyed.
Grouping<WindowedValue<InputT>> grouping =
inputDataSet.groupBy(new KvKeySelector(inputCoder.getKeyCoder()));
@@ -601,7 +604,9 @@ public void translateNode(
sideInputStrategies,
context.getPipelineOptions(),
outputMap,
- mainOutputTag);
+ mainOutputTag,
+ context.getInput(transform).getCoder(),
+ outputCoderMap);
outputDataSet =
new MapPartitionOperator<>(inputDataSet, typeInformation,
doFnWrapper, fullName);
diff --git
a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java
b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java
index 5a2823f5d3d..5020db98e89 100644
---
a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java
+++
b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java
@@ -20,6 +20,7 @@
import com.google.common.collect.Iterables;
import java.util.HashMap;
import java.util.Map;
+import java.util.stream.Collectors;
import org.apache.beam.runners.core.construction.TransformInputs;
import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
import org.apache.beam.sdk.coders.Coder;
@@ -108,6 +109,15 @@ public void setCurrentTransform(AppliedPTransform<?, ?, ?>
currentTransform) {
return currentTransform;
}
+ public Map<TupleTag<?>, Coder<?>> getOutputCoders() {
+ return currentTransform
+ .getOutputs()
+ .entrySet()
+ .stream()
+ .filter(e -> e.getValue() instanceof PCollection)
+ .collect(Collectors.toMap(e -> e.getKey(), e -> ((PCollection)
e.getValue()).getCoder()));
+ }
+
@SuppressWarnings("unchecked")
public <T> DataSet<T> getSideInputDataSet(PCollectionView<?> value) {
return (DataSet<T>) broadcastDataSets.get(value);
diff --git
a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java
b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java
index 25bf9ab7139..d376e58ec4b 100644
---
a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java
+++
b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java
@@ -296,9 +296,9 @@ public void flatMap(T t, Collector<T> collector) {
e);
}
- WindowedValueCoder<KV<K, V>> inputCoder =
+ WindowedValueCoder<KV<K, V>> windowedInputCoder =
(WindowedValueCoder) instantiateCoder(inputPCollectionId,
pipeline.getComponents());
- KvCoder<K, V> inputElementCoder = (KvCoder<K, V>)
inputCoder.getValueCoder();
+ KvCoder<K, V> inputElementCoder = (KvCoder<K, V>)
windowedInputCoder.getValueCoder();
SingletonKeyedWorkItemCoder<K, V> workItemCoder =
SingletonKeyedWorkItemCoder.of(
@@ -470,7 +470,7 @@ private void translateImpulse(
context.getDataStreamOrThrow(inputPCollectionId);
// TODO: coder for side input push back
- final Coder<WindowedValue<InputT>> inputCoder = null;
+ final Coder<WindowedValue<InputT>> windowedInputCoder = null;
CoderTypeInformation<WindowedValue<OutputT>> outputTypeInformation =
(!outputs.isEmpty())
? new CoderTypeInformation(outputCoders.get(mainOutputTag.getId()))
@@ -491,7 +491,9 @@ private void translateImpulse(
DoFnOperator<InputT, OutputT> doFnOperator =
new ExecutableStageDoFnOperator<>(
transform.getUniqueName(),
- inputCoder,
+ windowedInputCoder,
+ null,
+ Collections.emptyMap(),
mainOutputTag,
additionalOutputTags,
outputManagerFactory,
diff --git
a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java
b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java
index a4634679582..7564f291537 100644
---
a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java
+++
b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java
@@ -410,7 +410,9 @@ public RawUnionValue map(T o) throws Exception {
Map<TupleTag<?>, OutputTag<WindowedValue<?>>> tagsToOutputTags,
Map<TupleTag<?>, Coder<WindowedValue<?>>> tagsToCoders,
Map<TupleTag<?>, Integer> tagsToIds,
- Coder<WindowedValue<InputT>> inputCoder,
+ Coder<WindowedValue<InputT>> windowedInputCoder,
+ Coder<InputT> inputCoder,
+ Map<TupleTag<?>, Coder<?>> outputCoders,
Coder keyCoder,
KeySelector<WindowedValue<InputT>, ?> keySelector,
Map<Integer, PCollectionView<?>> transformedSideInputs);
@@ -448,14 +450,17 @@ public RawUnionValue map(T o) throws Exception {
entry.getKey().getId(),
(TypeInformation) context.getTypeInfo((PCollection<?>)
entry.getValue())));
tagsToCoders.put(
- entry.getKey(), (Coder) context.getCoder((PCollection<OutputT>)
entry.getValue()));
+ entry.getKey(),
+ (Coder) context.getWindowedInputCoder((PCollection<OutputT>)
entry.getValue()));
tagsToIds.put(entry.getKey(), idCount++);
}
}
SingleOutputStreamOperator<WindowedValue<OutputT>> outputStream;
- Coder<WindowedValue<InputT>> inputCoder = context.getCoder(input);
+ Coder<WindowedValue<InputT>> windowedInputCoder =
context.getWindowedInputCoder(input);
+ Coder<InputT> inputCoder = context.getInputCoder(input);
+ Map<TupleTag<?>, Coder<?>> outputCoders = context.getOutputCoders();
DataStream<WindowedValue<InputT>> inputDataStream =
context.getInputDataStream(input);
@@ -478,7 +483,7 @@ public RawUnionValue map(T o) throws Exception {
CoderTypeInformation<WindowedValue<OutputT>> outputTypeInformation =
new CoderTypeInformation<>(
- context.getCoder((PCollection<OutputT>)
outputs.get(mainOutputTag)));
+ context.getWindowedInputCoder((PCollection<OutputT>)
outputs.get(mainOutputTag)));
if (sideInputs.isEmpty()) {
DoFnOperator<InputT, OutputT> doFnOperator =
@@ -493,7 +498,9 @@ public RawUnionValue map(T o) throws Exception {
tagsToOutputTags,
tagsToCoders,
tagsToIds,
+ windowedInputCoder,
inputCoder,
+ outputCoders,
keyCoder,
keySelector,
new HashMap<>() /* side-input mapping */);
@@ -517,7 +524,9 @@ public RawUnionValue map(T o) throws Exception {
tagsToOutputTags,
tagsToCoders,
tagsToIds,
+ windowedInputCoder,
inputCoder,
+ outputCoders,
keyCoder,
keySelector,
transformedSideInputs.f0);
@@ -625,14 +634,18 @@ public void translateNode(
tagsToOutputTags,
tagsToCoders,
tagsToIds,
+ windowedInputCoder,
inputCoder,
+ outputCoders1,
keyCoder,
keySelector,
transformedSideInputs) ->
new DoFnOperator<>(
doFn1,
stepName,
+ windowedInputCoder,
inputCoder,
+ outputCoders1,
mainOutputTag1,
additionalOutputTags1,
new DoFnOperator.MultiOutputOutputManagerFactory<>(
@@ -677,14 +690,18 @@ public void translateNode(
tagsToOutputTags,
tagsToCoders,
tagsToIds,
+ windowedInputCoder,
inputCoder,
+ outputCoders1,
keyCoder,
keySelector,
transformedSideInputs) ->
new SplittableDoFnOperator<>(
doFn,
stepName,
+ windowedInputCoder,
inputCoder,
+ outputCoders1,
mainOutputTag,
additionalOutputTags,
new DoFnOperator.MultiOutputOutputManagerFactory<>(
@@ -817,7 +834,7 @@ public void translateNode(
SystemReduceFn.buffering(inputKvCoder.getValueCoder());
Coder<WindowedValue<KV<K, Iterable<InputT>>>> outputCoder =
- context.getCoder(context.getOutput(transform));
+ context.getWindowedInputCoder(context.getOutput(transform));
TypeInformation<WindowedValue<KV<K, Iterable<InputT>>>> outputTypeInfo =
context.getTypeInfo(context.getOutput(transform));
@@ -919,7 +936,7 @@ public void translateNode(
combineFn, input.getPipeline().getCoderRegistry(),
inputKvCoder));
Coder<WindowedValue<KV<K, OutputT>>> outputCoder =
- context.getCoder(context.getOutput(transform));
+ context.getWindowedInputCoder(context.getOutput(transform));
TypeInformation<WindowedValue<KV<K, OutputT>>> outputTypeInfo =
context.getTypeInfo(context.getOutput(transform));
diff --git
a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java
b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java
index 8b926d9ed42..fef667acef2 100644
---
a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java
+++
b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java
@@ -22,6 +22,7 @@
import com.google.common.collect.Iterables;
import java.util.HashMap;
import java.util.Map;
+import java.util.stream.Collectors;
import org.apache.beam.runners.core.construction.TransformInputs;
import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
import org.apache.beam.sdk.coders.Coder;
@@ -90,13 +91,26 @@ public void setCurrentTransform(AppliedPTransform<?, ?, ?>
currentTransform) {
this.currentTransform = currentTransform;
}
- public <T> Coder<WindowedValue<T>> getCoder(PCollection<T> collection) {
+ public <T> Coder<WindowedValue<T>> getWindowedInputCoder(PCollection<T>
collection) {
Coder<T> valueCoder = collection.getCoder();
return WindowedValue.getFullCoder(
valueCoder,
collection.getWindowingStrategy().getWindowFn().windowCoder());
}
+ public <T> Coder<T> getInputCoder(PCollection<T> collection) {
+ return collection.getCoder();
+ }
+
+ public Map<TupleTag<?>, Coder<?>> getOutputCoders() {
+ return currentTransform
+ .getOutputs()
+ .entrySet()
+ .stream()
+ .filter(e -> e.getValue() instanceof PCollection)
+ .collect(Collectors.toMap(e -> e.getKey(), e -> ((PCollection)
e.getValue()).getCoder()));
+ }
+
@SuppressWarnings("unchecked")
public <T> TypeInformation<WindowedValue<T>> getTypeInfo(PCollection<T>
collection) {
Coder<T> valueCoder = collection.getCoder();
diff --git
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java
index 21d6a1769f3..a0d3f7c550f 100644
---
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java
+++
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java
@@ -18,7 +18,6 @@
package org.apache.beam.runners.flink.translation.functions;
import com.google.common.collect.Lists;
-import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.beam.runners.core.DoFnRunner;
@@ -26,6 +25,7 @@
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.flink.FlinkPipelineOptions;
import org.apache.beam.runners.flink.metrics.DoFnRunnerWithMetricsUpdate;
+import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
@@ -61,6 +61,8 @@
private final Map<TupleTag<?>, Integer> outputMap;
private final TupleTag<OutputT> mainOutputTag;
+ private final Coder<InputT> inputCoder;
+ private final Map<TupleTag<?>, Coder<?>> outputCoderMap;
private transient DoFnInvoker<InputT, OutputT> doFnInvoker;
@@ -71,7 +73,9 @@ public FlinkDoFnFunction(
Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs,
PipelineOptions options,
Map<TupleTag<?>, Integer> outputMap,
- TupleTag<OutputT> mainOutputTag) {
+ TupleTag<OutputT> mainOutputTag,
+ Coder<InputT> inputCoder,
+ Map<TupleTag<?>, Coder<?>> outputCoderMap) {
this.doFn = doFn;
this.stepName = stepName;
@@ -80,6 +84,8 @@ public FlinkDoFnFunction(
this.windowingStrategy = windowingStrategy;
this.outputMap = outputMap;
this.mainOutputTag = mainOutputTag;
+ this.inputCoder = inputCoder;
+ this.outputCoderMap = outputCoderMap;
}
@Override
@@ -108,8 +114,8 @@ public void mapPartition(
mainOutputTag,
additionalOutputTags,
new FlinkNoOpStepContext(),
- null,
- Collections.emptyMap(),
+ inputCoder,
+ outputCoderMap,
windowingStrategy);
if
((serializedOptions.get().as(FlinkPipelineOptions.class)).getEnableMetrics()) {
diff --git
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java
index 4dec3134238..9be2ec3bb54 100644
---
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java
+++
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java
@@ -20,7 +20,6 @@
import static org.apache.flink.util.Preconditions.checkArgument;
import com.google.common.collect.Lists;
-import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
@@ -35,6 +34,7 @@
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.flink.FlinkPipelineOptions;
import org.apache.beam.runners.flink.metrics.DoFnRunnerWithMetricsUpdate;
+import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
@@ -63,6 +63,8 @@
private final SerializablePipelineOptions serializedOptions;
private final Map<TupleTag<?>, Integer> outputMap;
private final TupleTag<OutputT> mainOutputTag;
+ private final Coder<KV<K, V>> inputCoder;
+ private final Map<TupleTag<?>, Coder<?>> outputCoderMap;
private transient DoFnInvoker doFnInvoker;
public FlinkStatefulDoFnFunction(
@@ -72,7 +74,9 @@ public FlinkStatefulDoFnFunction(
Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs,
PipelineOptions pipelineOptions,
Map<TupleTag<?>, Integer> outputMap,
- TupleTag<OutputT> mainOutputTag) {
+ TupleTag<OutputT> mainOutputTag,
+ Coder<KV<K, V>> inputCoder,
+ Map<TupleTag<?>, Coder<?>> outputCoderMap) {
this.dofn = dofn;
this.stepName = stepName;
@@ -81,6 +85,8 @@ public FlinkStatefulDoFnFunction(
this.serializedOptions = new SerializablePipelineOptions(pipelineOptions);
this.outputMap = outputMap;
this.mainOutputTag = mainOutputTag;
+ this.inputCoder = inputCoder;
+ this.outputCoderMap = outputCoderMap;
}
@Override
@@ -134,8 +140,8 @@ public TimerInternals timerInternals() {
return timerInternals;
}
},
- null,
- Collections.emptyMap(),
+ inputCoder,
+ outputCoderMap,
windowingStrategy);
if
((serializedOptions.get().as(FlinkPipelineOptions.class)).getEnableMetrics()) {
diff --git
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
index ffc997ea6a4..1bfe830a5ef 100644
---
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
+++
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
@@ -28,7 +28,6 @@
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
-import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
@@ -148,7 +147,11 @@
private final String stepName;
- private final Coder<WindowedValue<InputT>> inputCoder;
+ private final Coder<WindowedValue<InputT>> windowedInputCoder;
+
+ private final Coder<InputT> inputCoder;
+
+ private final Map<TupleTag<?>, Coder<?>> outputCoders;
private final Coder<?> keyCoder;
@@ -179,7 +182,9 @@
public DoFnOperator(
DoFn<InputT, OutputT> doFn,
String stepName,
- Coder<WindowedValue<InputT>> inputCoder,
+ Coder<WindowedValue<InputT>> inputWindowedCoder,
+ Coder<InputT> inputCoder,
+ Map<TupleTag<?>, Coder<?>> outputCoders,
TupleTag<OutputT> mainOutputTag,
List<TupleTag<?>> additionalOutputTags,
OutputManagerFactory<OutputT> outputManagerFactory,
@@ -191,7 +196,9 @@ public DoFnOperator(
KeySelector<WindowedValue<InputT>, ?> keySelector) {
this.doFn = doFn;
this.stepName = stepName;
+ this.windowedInputCoder = inputWindowedCoder;
this.inputCoder = inputCoder;
+ this.outputCoders = outputCoders;
this.mainOutputTag = mainOutputTag;
this.additionalOutputTags = additionalOutputTags;
this.sideInputTagMapping = sideInputTagMapping;
@@ -265,7 +272,8 @@ public void initializeState(StateInitializationContext
context) throws Exception
super.initializeState(context);
ListStateDescriptor<WindowedValue<InputT>> pushedBackStateDescriptor =
- new ListStateDescriptor<>("pushed-back-elements", new
CoderTypeSerializer<>(inputCoder));
+ new ListStateDescriptor<>(
+ "pushed-back-elements", new
CoderTypeSerializer<>(windowedInputCoder));
if (keySelector != null) {
pushedBackElementsHandler =
@@ -344,9 +352,8 @@ public void initializeState(StateInitializationContext
context) throws Exception
mainOutputTag,
additionalOutputTags,
stepContext,
- // TODO: fix
- null,
- Collections.emptyMap(),
+ inputCoder,
+ outputCoders,
windowingStrategy);
doFnRunner = createWrappingDoFnRunner(doFnRunner);
diff --git
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
index cc0ed8849a5..7b9c7590857 100644
---
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
+++
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
@@ -75,7 +75,9 @@
public ExecutableStageDoFnOperator(
String stepName,
- Coder<WindowedValue<InputT>> inputCoder,
+ Coder<WindowedValue<InputT>> windowedInputCoder,
+ Coder<InputT> inputCoder,
+ Map<TupleTag<?>, Coder<?>> outputCoders,
TupleTag<OutputT> mainOutputTag,
List<TupleTag<?>> additionalOutputTags,
OutputManagerFactory<OutputT> outputManagerFactory,
@@ -89,7 +91,9 @@ public ExecutableStageDoFnOperator(
super(
new NoOpDoFn(),
stepName,
+ windowedInputCoder,
inputCoder,
+ outputCoders,
mainOutputTag,
additionalOutputTags,
outputManagerFactory,
diff --git
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java
index 0834c795449..58d80e28b1b 100644
---
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java
+++
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java
@@ -67,7 +67,9 @@
public SplittableDoFnOperator(
DoFn<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> doFn,
String stepName,
- Coder<WindowedValue<KeyedWorkItem<String, KV<InputT, RestrictionT>>>>
inputCoder,
+ Coder<WindowedValue<KeyedWorkItem<String, KV<InputT, RestrictionT>>>>
windowedInputCoder,
+ Coder<KeyedWorkItem<String, KV<InputT, RestrictionT>>> inputCoder,
+ Map<TupleTag<?>, Coder<?>> outputCoders,
TupleTag<OutputT> mainOutputTag,
List<TupleTag<?>> additionalOutputTags,
OutputManagerFactory<OutputT> outputManagerFactory,
@@ -80,7 +82,9 @@ public SplittableDoFnOperator(
super(
doFn,
stepName,
+ windowedInputCoder,
inputCoder,
+ outputCoders,
mainOutputTag,
additionalOutputTags,
outputManagerFactory,
diff --git
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java
index 35d062bdd24..a9b57ab3a92 100644
---
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java
+++
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java
@@ -53,7 +53,7 @@
public WindowDoFnOperator(
SystemReduceFn<K, InputT, ?, OutputT, BoundedWindow> systemReduceFn,
String stepName,
- Coder<WindowedValue<KeyedWorkItem<K, InputT>>> inputCoder,
+ Coder<WindowedValue<KeyedWorkItem<K, InputT>>> windowedInputCoder,
TupleTag<KV<K, OutputT>> mainOutputTag,
List<TupleTag<?>> additionalOutputTags,
OutputManagerFactory<KV<K, OutputT>> outputManagerFactory,
@@ -66,7 +66,9 @@ public WindowDoFnOperator(
super(
null,
stepName,
- inputCoder,
+ windowedInputCoder,
+ null,
+ Collections.emptyMap(),
mainOutputTag,
additionalOutputTags,
outputManagerFactory,
diff --git
a/runners/flink/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java
b/runners/flink/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java
index 61fce04624e..2edf3c73d1e 100644
---
a/runners/flink/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java
+++
b/runners/flink/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java
@@ -67,6 +67,8 @@ public void parDoBaseClassPipelineOptionsNullTest() {
new TestDoFn(),
"stepName",
coder,
+ null,
+ Collections.emptyMap(),
mainTag,
Collections.emptyList(),
new DoFnOperator.MultiOutputOutputManagerFactory<>(mainTag, coder),
@@ -90,6 +92,8 @@ public void parDoBaseClassPipelineOptionsSerializationTest()
throws Exception {
new TestDoFn(),
"stepName",
coder,
+ null,
+ Collections.emptyMap(),
mainTag,
Collections.emptyList(),
new DoFnOperator.MultiOutputOutputManagerFactory<>(mainTag, coder),
diff --git
a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java
b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java
index 4dd9850e391..33832ce0594 100644
---
a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java
+++
b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java
@@ -118,6 +118,8 @@ public void testSingleOutput() throws Exception {
new IdentityDoFn<>(),
"stepName",
coder,
+ null,
+ Collections.emptyMap(),
outputTag,
Collections.emptyList(),
new DoFnOperator.MultiOutputOutputManagerFactory<>(outputTag,
coder),
@@ -175,6 +177,8 @@ public void testMultiOutputOutput() throws Exception {
new MultiOutputDoFn(additionalOutput1, additionalOutput2),
"stepName",
coder,
+ null,
+ Collections.emptyMap(),
mainOutput,
ImmutableList.of(additionalOutput1, additionalOutput2),
new DoFnOperator.MultiOutputOutputManagerFactory(
@@ -267,6 +271,8 @@ public void onEventTime(OnTimerContext context) {
fn,
"stepName",
inputCoder,
+ null,
+ Collections.emptyMap(),
outputTag,
Collections.emptyList(),
new DoFnOperator.MultiOutputOutputManagerFactory<>(outputTag,
outputCoder),
@@ -346,6 +352,8 @@ public void processElement(ProcessContext context) {
fn,
"stepName",
inputCoder,
+ null,
+ Collections.emptyMap(),
outputTag,
Collections.emptyList(),
new DoFnOperator.MultiOutputOutputManagerFactory<>(outputTag,
outputCoder),
@@ -451,6 +459,8 @@ public void onTimer(OnTimerContext context,
@StateId(stateId) ValueState<String>
fn,
"stepName",
coder,
+ null,
+ Collections.emptyMap(),
outputTag,
Collections.emptyList(),
new DoFnOperator.MultiOutputOutputManagerFactory<>(outputTag,
coder),
@@ -552,6 +562,8 @@ void testSideInputs(boolean keyed) throws Exception {
new IdentityDoFn<>(),
"stepName",
coder,
+ null,
+ Collections.emptyMap(),
outputTag,
Collections.emptyList(),
new DoFnOperator.MultiOutputOutputManagerFactory<>(outputTag,
coder),
@@ -720,6 +732,8 @@ public void nonKeyedParDoSideInputCheckpointing() throws
Exception {
new IdentityDoFn<>(),
"stepName",
coder,
+ null,
+ Collections.emptyMap(),
outputTag,
Collections.emptyList(),
new
DoFnOperator.MultiOutputOutputManagerFactory<>(outputTag, coder),
@@ -755,6 +769,8 @@ public void keyedParDoSideInputCheckpointing() throws
Exception {
new IdentityDoFn<>(),
"stepName",
coder,
+ null,
+ Collections.emptyMap(),
outputTag,
Collections.emptyList(),
new
DoFnOperator.MultiOutputOutputManagerFactory<>(outputTag, coder),
@@ -849,6 +865,8 @@ public void nonKeyedParDoPushbackDataCheckpointing() throws
Exception {
new IdentityDoFn<>(),
"stepName",
coder,
+ null,
+ Collections.emptyMap(),
outputTag,
Collections.emptyList(),
new
DoFnOperator.MultiOutputOutputManagerFactory<>(outputTag, coder),
@@ -885,6 +903,8 @@ public void keyedParDoPushbackDataCheckpointing() throws
Exception {
new IdentityDoFn<>(),
"stepName",
coder,
+ null,
+ Collections.emptyMap(),
outputTag,
Collections.emptyList(),
new
DoFnOperator.MultiOutputOutputManagerFactory<>(outputTag, coder),
@@ -1057,6 +1077,8 @@ public void onEventTime(OnTimerContext context) {
fn,
"stepName",
inputCoder,
+ null,
+ Collections.emptyMap(),
outputTag,
Collections.emptyList(),
new DoFnOperator.MultiOutputOutputManagerFactory<>(outputTag,
outputCoder),
@@ -1101,6 +1123,8 @@ public void finishBundle(FinishBundleContext context) {
doFn,
"stepName",
windowedValueCoder,
+ null,
+ Collections.emptyMap(),
outputTag,
Collections.emptyList(),
outputManagerFactory,
@@ -1140,6 +1164,8 @@ public void finishBundle(FinishBundleContext context) {
doFn,
"stepName",
windowedValueCoder,
+ null,
+ Collections.emptyMap(),
outputTag,
Collections.emptyList(),
outputManagerFactory,
diff --git
a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/ExecutableStageDoFnOperatorTest.java
b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/ExecutableStageDoFnOperatorTest.java
index 234684a10a9..8860dd2eba1 100644
---
a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/ExecutableStageDoFnOperatorTest.java
+++
b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/ExecutableStageDoFnOperatorTest.java
@@ -319,6 +319,8 @@ public void testSerialization() {
new ExecutableStageDoFnOperator<>(
"transform",
null,
+ null,
+ Collections.emptyMap(),
mainOutput,
ImmutableList.of(additionalOutput),
outputManagerFactory,
@@ -353,6 +355,8 @@ public void testSerialization() {
new ExecutableStageDoFnOperator<>(
"transform",
null,
+ null,
+ Collections.emptyMap(),
mainOutput,
additionalOutputs,
outputManagerFactory,
diff --git a/runners/gearpump/build.gradle b/runners/gearpump/build.gradle
index 10866e08025..3a0ceb15e7c 100644
--- a/runners/gearpump/build.gradle
+++ b/runners/gearpump/build.gradle
@@ -66,6 +66,7 @@ task validatesRunnerStreaming(type: Test) {
"--streaming=true",
])
+
classpath = configurations.validatesRunner
testClassesDirs =
files(project(":beam-sdks-java-core").sourceSets.test.output.classesDirs)
useJUnit {
diff --git a/runners/samza/build.gradle b/runners/samza/build.gradle
index 41dfd8bd478..3f9147afbea 100644
--- a/runners/samza/build.gradle
+++ b/runners/samza/build.gradle
@@ -82,7 +82,6 @@ task validatesRunner(type: Test) {
excludeCategories 'org.apache.beam.sdk.testing.UsesAttemptedMetrics'
excludeCategories 'org.apache.beam.sdk.testing.UsesCommittedMetrics'
excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse'
- excludeCategories 'org.apache.beam.sdk.testing.UsesSchema'
excludeCategories
'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
excludeCategories 'org.apache.beam.sdk.testing.UsesTestStream'
excludeCategories 'org.apache.beam.sdk.testing.UsesTimersInParDo'
diff --git
a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java
b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java
index 058efd10797..3c704f9853b 100644
---
a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java
+++
b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java
@@ -92,11 +92,15 @@
private transient DoFnSignature signature;
private transient TaskContext context;
private transient SamzaPipelineOptions pipelineOptions;
+ private Coder<InT> inputCoder;
+ private Map<TupleTag<?>, Coder<?>> outputCoders;
public DoFnOp(
TupleTag<FnOutT> mainOutputTag,
DoFn<InT, FnOutT> doFn,
Coder<?> keyCoder,
+ Coder<InT> inputCoder,
+ Map<TupleTag<?>, Coder<?>> outputCoders,
Collection<PCollectionView<?>> sideInputs,
List<TupleTag<?>> sideOutputTags,
WindowingStrategy windowingStrategy,
@@ -107,6 +111,8 @@ public DoFnOp(
this.doFn = doFn;
this.sideInputs = sideInputs;
this.sideOutputTags = sideOutputTags;
+ this.inputCoder = inputCoder;
+ this.outputCoders = outputCoders;
this.windowingStrategy = windowingStrategy;
this.idToViewMap = new HashMap<>(idToViewMap);
this.outputManagerFactory = outputManagerFactory;
@@ -150,6 +156,8 @@ public void open(
outputManagerFactory.create(emitter),
mainOutputTag,
sideOutputTags,
+ inputCoder,
+ outputCoders,
stateInternalsFactory,
timerInternalsFactory,
windowingStrategy,
diff --git
a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnRunnerWithKeyedInternals.java
b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnRunnerWithKeyedInternals.java
index 9563578a822..2bd7792ec2e 100644
---
a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnRunnerWithKeyedInternals.java
+++
b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnRunnerWithKeyedInternals.java
@@ -20,8 +20,8 @@
import static com.google.common.base.Preconditions.checkState;
-import java.util.Collections;
import java.util.List;
+import java.util.Map;
import org.apache.beam.runners.core.DoFnRunner;
import org.apache.beam.runners.core.DoFnRunners;
import org.apache.beam.runners.core.KeyedWorkItem;
@@ -58,6 +58,8 @@
DoFnRunners.OutputManager outputManager,
TupleTag<OutputT> mainOutputTag,
List<TupleTag<?>> additionalOutputTags,
+ Coder<InputT> inputCoder,
+ Map<TupleTag<?>, Coder<?>> outputCoders,
SamzaStoreStateInternals.Factory<?> stateInternalsFactory,
SamzaTimerInternalsFactory<?> timerInternalsFactory,
WindowingStrategy<?, ?> windowingStrategy,
@@ -88,9 +90,8 @@
mainOutputTag,
additionalOutputTags,
createStepContext(stateInternals, timerInternals),
- // TODO: fix.
- null,
- Collections.emptyMap(),
+ inputCoder,
+ outputCoders,
windowingStrategy);
final DoFnRunner<InputT, OutputT> doFnRunnerWithMetrics =
diff --git
a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java
b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java
index 83edd5a45df..b6b03331a8a 100644
---
a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java
+++
b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java
@@ -58,6 +58,14 @@ public void translate(
TransformHierarchy.Node node,
TranslationContext ctx) {
final PCollection<? extends InT> input = ctx.getInput(transform);
+ final Map<TupleTag<?>, Coder<?>> outputCoders =
+ ctx.getCurrentTransform()
+ .getOutputs()
+ .entrySet()
+ .stream()
+ .filter(e -> e.getValue() instanceof PCollection)
+ .collect(
+ Collectors.toMap(e -> e.getKey(), e -> ((PCollection<?>)
e.getValue()).getCoder()));
final DoFnSignature signature =
DoFnSignatures.getSignature(transform.getFn().getClass());
final Coder<?> keyCoder =
@@ -105,6 +113,8 @@ public void translate(
transform.getMainOutputTag(),
transform.getFn(),
keyCoder,
+ (Coder<InT>) input.getCoder(),
+ outputCoders,
transform.getSideInputs(),
transform.getAdditionalOutputTags().getAll(),
input.getWindowingStrategy(),
diff --git a/runners/spark/build.gradle b/runners/spark/build.gradle
index 6ea29be2d15..5a750403249 100644
--- a/runners/spark/build.gradle
+++ b/runners/spark/build.gradle
@@ -124,7 +124,6 @@ task validatesRunnerBatch(type: Test) {
includeCategories 'org.apache.beam.runners.spark.UsesCheckpointRecovery'
excludeCategories 'org.apache.beam.sdk.testing.UsesAttemptedMetrics'
excludeCategories 'org.apache.beam.sdk.testing.UsesCommittedMetrics'
- excludeCategories 'org.apache.beam.sdk.testing.UsesSchema'
excludeCategories 'org.apache.beam.sdk.testing.UsesTestStream'
excludeCategories 'org.apache.beam.sdk.testing.UsesCustomWindowMerging'
excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse'
@@ -148,7 +147,6 @@ task validatesRunnerStreaming(type: Test) {
maxParallelForks 4
useJUnit {
includeCategories 'org.apache.beam.runners.spark.StreamingTest'
- excludeCategories 'org.apache.beam.sdk.testing.UsesSchema'
}
}
diff --git
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
index 0a2e386fb10..6d1cdea295c 100644
---
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
+++
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
@@ -26,6 +26,7 @@
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
+import java.util.stream.Collectors;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.construction.TransformInputs;
import org.apache.beam.runners.spark.SparkPipelineOptions;
@@ -131,6 +132,15 @@ public void setCurrentTransform(AppliedPTransform<?, ?, ?>
transform) {
return currentTransform.getOutputs();
}
+ public Map<TupleTag<?>, Coder<?>> getOutputCoders() {
+ return currentTransform
+ .getOutputs()
+ .entrySet()
+ .stream()
+ .filter(e -> e.getValue() instanceof PCollection)
+ .collect(Collectors.toMap(e -> e.getKey(), e -> ((PCollection)
e.getValue()).getCoder()));
+ }
+
private boolean shouldCache(PValue pvalue) {
if ((pvalue instanceof PCollection)
&& cacheCandidates.containsKey(pvalue)
diff --git
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
index f585c54901a..d864e624d4a 100644
---
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
+++
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
@@ -38,6 +38,7 @@
import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
import org.apache.beam.runners.spark.util.SideInputBroadcast;
import org.apache.beam.runners.spark.util.SparkSideInputReader;
+import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -67,6 +68,8 @@
private final SerializablePipelineOptions options;
private final TupleTag<OutputT> mainOutputTag;
private final List<TupleTag<?>> additionalOutputTags;
+ private final Coder<InputT> inputCoder;
+ private final Map<TupleTag<?>, Coder<?>> outputCoders;
private final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>,
SideInputBroadcast<?>>> sideInputs;
private final WindowingStrategy<?, ?> windowingStrategy;
private final boolean stateful;
@@ -77,6 +80,8 @@
* @param options The {@link SerializablePipelineOptions}.
* @param mainOutputTag The main output {@link TupleTag}.
* @param additionalOutputTags Additional {@link TupleTag output tags}.
+ * @param inputCoder The coder for the input.
+ * @param outputCoders A map of all output coders.
* @param sideInputs Side inputs used in this {@link DoFn}.
* @param windowingStrategy Input {@link WindowingStrategy}.
* @param stateful Stateful {@link DoFn}.
@@ -88,6 +93,8 @@ public MultiDoFnFunction(
SerializablePipelineOptions options,
TupleTag<OutputT> mainOutputTag,
List<TupleTag<?>> additionalOutputTags,
+ Coder<InputT> inputCoder,
+ Map<TupleTag<?>, Coder<?>> outputCoders,
Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
sideInputs,
WindowingStrategy<?, ?> windowingStrategy,
boolean stateful) {
@@ -97,6 +104,8 @@ public MultiDoFnFunction(
this.options = options;
this.mainOutputTag = mainOutputTag;
this.additionalOutputTags = additionalOutputTags;
+ this.inputCoder = inputCoder;
+ this.outputCoders = outputCoders;
this.sideInputs = sideInputs;
this.windowingStrategy = windowingStrategy;
this.stateful = stateful;
@@ -150,8 +159,8 @@ public TimerInternals timerInternals() {
mainOutputTag,
additionalOutputTags,
context,
- null,
- Collections.emptyMap(),
+ inputCoder,
+ outputCoders,
windowingStrategy);
DoFnRunnerWithMetrics<InputT, OutputT> doFnRunnerWithMetrics =
diff --git
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index 5de56d8cae4..3f508b1edd6 100644
---
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -344,7 +344,8 @@ public void evaluate(
WindowingStrategy<?, ?> windowingStrategy =
context.getInput(transform).getWindowingStrategy();
Accumulator<MetricsContainerStepMap> metricsAccum =
MetricsAccumulator.getInstance();
-
+ Coder<InputT> inputCoder = (Coder<InputT>)
context.getInput(transform).getCoder();
+ Map<TupleTag<?>, Coder<?>> outputCoders = context.getOutputCoders();
JavaPairRDD<TupleTag<?>, WindowedValue<?>> all;
DoFnSignature signature =
DoFnSignatures.getSignature(transform.getFn().getClass());
@@ -359,6 +360,8 @@ public void evaluate(
context.getSerializableOptions(),
transform.getMainOutputTag(),
transform.getAdditionalOutputTags().getAll(),
+ inputCoder,
+ outputCoders,
TranslationUtils.getSideInputs(transform.getSideInputs(),
context),
windowingStrategy,
stateful);
diff --git
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index b654dfaf992..a307cc96827 100644
---
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -387,6 +387,8 @@ public void evaluate(
final SparkPCollectionView pviews = context.getPViews();
final WindowingStrategy<?, ?> windowingStrategy =
context.getInput(transform).getWindowingStrategy();
+ Coder<InputT> inputCoder = (Coder<InputT>)
context.getInput(transform).getCoder();
+ Map<TupleTag<?>, Coder<?>> outputCoders = context.getOutputCoders();
@SuppressWarnings("unchecked")
UnboundedDataset<InputT> unboundedDataset =
@@ -414,6 +416,8 @@ public void evaluate(
options,
transform.getMainOutputTag(),
transform.getAdditionalOutputTags().getAll(),
+ inputCoder,
+ outputCoders,
sideInputs,
windowingStrategy,
false));
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
Issue Time Tracking
-------------------
Worklog Id: (was: 126962)
Time Spent: 15h 50m (was: 15h 40m)
> Schema followups
> ----------------
>
> Key: BEAM-4076
> URL: https://issues.apache.org/jira/browse/BEAM-4076
> Project: Beam
> Issue Type: Improvement
> Components: beam-model, dsl-sql, sdk-java-core
> Reporter: Kenneth Knowles
> Priority: Major
> Time Spent: 15h 50m
> Remaining Estimate: 0h
>
> This umbrella bug contains subtasks with follow-ups for Beam schemas, which
> were moved from SQL to the core Java SDK and made to be type-name-based
> rather than coder-based.
--
This message was sent by Atlassian JIRA
(v7.6.3#76005)