[
https://issues.apache.org/jira/browse/BEAM-4681?focusedWorklogId=167434&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-167434
]
ASF GitHub Bot logged work on BEAM-4681:
----------------------------------------
Author: ASF GitHub Bot
Created on: 19/Nov/18 16:04
Start Date: 19/Nov/18 16:04
Worklog Time Spent: 10m
Work Description: mxm closed pull request #7008: [BEAM-4681] Add support for portable timers in Flink batch mode
URL: https://github.com/apache/beam/pull/7008
This is a PR merged from a forked repository. As GitHub hides the original diff of a foreign (forked) pull request on merge, it is reproduced below for the sake of provenance:
diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
index ff1d38b859f..0147085ab34 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
@@ -243,8 +243,6 @@ class BeamModulePlugin implements Plugin<Project> {
excludeCategories 'org.apache.beam.sdk.testing.UsesMapState'
excludeCategories 'org.apache.beam.sdk.testing.UsesSetState'
excludeCategories 'org.apache.beam.sdk.testing.UsesTestStream'
- // TODO Enable test once timer-support for batch is merged
- excludeCategories 'org.apache.beam.sdk.testing.UsesTimersInParDo'
//SplitableDoFnTests
        excludeCategories 'org.apache.beam.sdk.testing.UsesBoundedSplittableParDo'
        excludeCategories 'org.apache.beam.sdk.testing.UsesSplittableParDoWithWindowedSideInputs'
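The two deleted lines above opt the Flink portable batch ValidatesRunner suite into Beam's timer tests, which are gated behind the UsesTimersInParDo category. A minimal sketch of the kind of test this enables (an illustrative class, not one taken from the Beam test suite):

    import org.apache.beam.sdk.state.TimeDomain;
    import org.apache.beam.sdk.state.Timer;
    import org.apache.beam.sdk.state.TimerSpec;
    import org.apache.beam.sdk.state.TimerSpecs;
    import org.apache.beam.sdk.testing.UsesTimersInParDo;
    import org.apache.beam.sdk.testing.ValidatesRunner;
    import org.apache.beam.sdk.transforms.DoFn;
    import org.apache.beam.sdk.values.KV;
    import org.junit.Test;
    import org.junit.experimental.categories.Category;

    public class TimersInParDoSketch {

      @Test
      @Category({ValidatesRunner.class, UsesTimersInParDo.class})
      public void eventTimeTimerFires() {
        DoFn<KV<String, Integer>, Integer> fn =
            new DoFn<KV<String, Integer>, Integer>() {
              @TimerId("timer")
              private final TimerSpec timerSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);

              @ProcessElement
              public void process(ProcessContext context, @TimerId("timer") Timer timer) {
                // Fire as soon as the watermark passes the element's timestamp.
                timer.set(context.timestamp());
              }

              @OnTimer("timer")
              public void onTimer(OnTimerContext context) {
                context.output(42);
              }
            };
        // ... apply fn with TestPipeline and assert the output via PAssert.
      }
    }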
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
index 2ed31b8ba48..3f26d5e2d99 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
@@ -18,6 +18,8 @@
package org.apache.beam.runners.flink;
import static com.google.common.base.Preconditions.checkArgument;
+import static org.apache.beam.runners.flink.translation.utils.FlinkPipelineTranslatorUtils.createOutputMap;
+import static org.apache.beam.runners.flink.translation.utils.FlinkPipelineTranslatorUtils.getWindowingStrategy;
import static org.apache.beam.runners.flink.translation.utils.FlinkPipelineTranslatorUtils.instantiateCoder;
import com.google.common.collect.BiMap;
@@ -55,7 +57,6 @@
import org.apache.beam.runners.flink.translation.functions.FlinkReduceFunction;
import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
import org.apache.beam.runners.flink.translation.types.KvKeySelector;
-import org.apache.beam.runners.flink.translation.utils.FlinkPipelineTranslatorUtils;
import org.apache.beam.runners.flink.translation.wrappers.ImpulseInputFormat;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.runners.fnexecution.wire.WireCoders;
@@ -299,16 +300,13 @@ public void translate(BatchTranslationContext context, RunnerApi.Pipeline pipeli
  private static <InputT> void translateExecutableStage(
      PTransformNode transform, RunnerApi.Pipeline pipeline, BatchTranslationContext context) {
- // TODO: Fail on stateful DoFns for now.
- // TODO: Support stateful DoFns by inserting group-by-keys where necessary.
// TODO: Fail on splittable DoFns.
// TODO: Special-case single outputs to avoid multiplexing PCollections.
RunnerApi.Components components = pipeline.getComponents();
Map<String, String> outputs = transform.getTransform().getOutputsMap();
// Mapping from PCollection id to coder tag id.
- BiMap<String, Integer> outputMap =
- FlinkPipelineTranslatorUtils.createOutputMap(outputs.values());
+ BiMap<String, Integer> outputMap = createOutputMap(outputs.values());
    // Collect all output Coders and create a UnionCoder for our tagged outputs.
List<Coder<?>> unionCoders = Lists.newArrayList();
// Enforce tuple tag sorting by union tag index.
@@ -338,21 +336,22 @@ public void translate(BatchTranslationContext context, RunnerApi.Pipeline pipeli
}
String inputPCollectionId = stagePayload.getInput();
+ Coder<WindowedValue<InputT>> windowedInputCoder =
+ instantiateCoder(inputPCollectionId, components);
+
DataSet<WindowedValue<InputT>> inputDataSet =
context.getDataSetOrThrow(inputPCollectionId);
- final boolean stateful = stagePayload.getUserStatesCount() > 0;
final FlinkExecutableStageFunction<InputT> function =
new FlinkExecutableStageFunction<>(
stagePayload,
context.getJobInfo(),
outputMap,
FlinkExecutableStageContext.factory(context.getPipelineOptions()),
- stateful);
+            getWindowingStrategy(inputPCollectionId, components).getWindowFn().windowCoder());
final SingleInputUdfOperator taggedDataset;
- if (stateful) {
- Coder<WindowedValue<InputT>> windowedInputCoder =
- instantiateCoder(inputPCollectionId, components);
+    if (stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0) {
+
      Coder valueCoder =
          ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder();
      // Stateful stages are only allowed for KV inputs, to be able to group on the key
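The essential routing change: a stage now goes down the grouped (GroupReduceFunction) path when it declares user state or timers, since timers, like state, are scoped to a key, and all of a key's input must meet that key's timers on the same task. Note also that the constructor now receives the input's window coder instead of a boolean stateful flag; the function needs the coder to build window namespaces for timers. The predicate itself, pulled out into a self-contained sketch (the helper name is ours, not the translator's):

    import org.apache.beam.model.pipeline.v1.RunnerApi;

    public class StageRoutingSketch {

      /** True when a stage must run keyed/grouped rather than via plain mapPartition. */
      static boolean requiresKeyedExecution(RunnerApi.ExecutableStagePayload stagePayload) {
        return stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0;
      }

      public static void main(String[] args) {
        // A payload with no user state and no timers keeps the old mapPartition path.
        RunnerApi.ExecutableStagePayload statelessStage =
            RunnerApi.ExecutableStagePayload.getDefaultInstance();
        System.out.println(requiresKeyedExecution(statelessStage)); // false
      }
    }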
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
index 17b7e53aef8..f311e2a7fd7 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
@@ -17,26 +17,32 @@
*/
package org.apache.beam.runners.flink.translation.functions;
-import static com.google.common.base.Preconditions.checkArgument;
-import static com.google.common.base.Preconditions.checkState;
-
-import com.google.common.collect.Iterables;
import java.io.IOException;
import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
import java.util.EnumMap;
+import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
+import java.util.Locale;
import java.util.Map;
+import java.util.function.BiConsumer;
+import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.InMemoryStateInternals;
+import org.apache.beam.runners.core.InMemoryTimerInternals;
import org.apache.beam.runners.core.StateInternals;
import org.apache.beam.runners.core.StateNamespace;
import org.apache.beam.runners.core.StateNamespaces;
import org.apache.beam.runners.core.StateTag;
import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.core.TimerInternals;
+import org.apache.beam.runners.core.construction.Timer;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
+import org.apache.beam.runners.core.construction.graph.TimerReference;
import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
import org.apache.beam.runners.fnexecution.control.OutputReceiverFactory;
import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors;
@@ -50,15 +56,20 @@
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
import org.apache.flink.api.common.functions.AbstractRichFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -85,7 +96,7 @@
  // Map from PCollection id to the union tag used to represent this PCollection in the output.
private final Map<String, Integer> outputMap;
private final FlinkExecutableStageContext.Factory contextFactory;
- private final boolean stateful;
+ private final Coder windowCoder;
  // Worker-local fields. These should only be constructed and consumed on Flink TaskManagers.
private transient RuntimeContext runtimeContext;
@@ -95,18 +106,21 @@
private transient BundleProgressHandler progressHandler;
// Only initialized when the ExecutableStage is stateful
private transient InMemoryBagUserStateFactory bagUserStateHandlerFactory;
+ private transient ExecutableStage executableStage;
+ // In state
+ private transient Object currentTimerKey;
public FlinkExecutableStageFunction(
RunnerApi.ExecutableStagePayload stagePayload,
JobInfo jobInfo,
Map<String, Integer> outputMap,
FlinkExecutableStageContext.Factory contextFactory,
- boolean stateful) {
+ Coder windowCoder) {
this.stagePayload = stagePayload;
this.jobInfo = jobInfo;
this.outputMap = outputMap;
this.contextFactory = contextFactory;
- this.stateful = stateful;
+ this.windowCoder = windowCoder;
}
@Override
@@ -114,7 +128,7 @@ public void open(Configuration parameters) throws Exception {
// Register standard file systems.
// TODO Use actual pipeline options.
FileSystems.setDefaultPipelineOptions(PipelineOptionsFactory.create());
-    ExecutableStage executableStage = ExecutableStage.fromPayload(stagePayload);
+    executableStage = ExecutableStage.fromPayload(stagePayload);
runtimeContext = getRuntimeContext();
// TODO: Wire this into the distributed cache and make it pluggable.
stageContext = contextFactory.get(jobInfo);
@@ -144,7 +158,7 @@ private StateRequestHandler getStateRequestHandler(
}
final StateRequestHandler userStateHandler;
- if (stateful) {
+ if (executableStage.getUserStates().size() > 0) {
bagUserStateHandlerFactory = new InMemoryBagUserStateFactory();
userStateHandler =
StateRequestHandlers.forBagUserStateHandlerFactory(
@@ -166,40 +180,132 @@ private StateRequestHandler getStateRequestHandler(
  public void mapPartition(
      Iterable<WindowedValue<InputT>> iterable, Collector<RawUnionValue> collector)
      throws Exception {
-    processElements(iterable, collector);
+
+    ReceiverFactory receiverFactory = new ReceiverFactory(collector, outputMap);
+    try (RemoteBundle bundle =
+        stageBundleFactory.getBundle(receiverFactory, stateRequestHandler, progressHandler)) {
+      processElements(iterable, bundle);
+    }
  }
-  /** For stateful processing via a GroupReduceFunction. */
+  /** For stateful and timer processing via a GroupReduceFunction. */
  @Override
  public void reduce(Iterable<WindowedValue<InputT>> iterable, Collector<RawUnionValue> collector)
      throws Exception {
-    bagUserStateHandlerFactory.resetForNewKey();
-    processElements(iterable, collector);
+
+    // Need to discard the old key's state
+    if (bagUserStateHandlerFactory != null) {
+      bagUserStateHandlerFactory.resetForNewKey();
+    }
+
+    // Used with Batch, we know that all the data is available for this key. We can't use the
+    // timer manager from the context because it doesn't exist. So we create one and advance
+    // time to the end after processing all elements.
+    final InMemoryTimerInternals timerInternals = new InMemoryTimerInternals();
+    timerInternals.advanceProcessingTime(Instant.now());
+    timerInternals.advanceSynchronizedProcessingTime(Instant.now());
+
+    ReceiverFactory receiverFactory =
+        new ReceiverFactory(
+            collector,
+            outputMap,
+            new TimerReceiverFactory(
+                stageBundleFactory,
+                executableStage.getTimers(),
+                stageBundleFactory.getProcessBundleDescriptor().getTimerSpecs(),
+                (WindowedValue timerElement, TimerInternals.TimerData timerData) -> {
+                  currentTimerKey = (((KV) timerElement.getValue()).getKey());
+                  timerInternals.setTimer(timerData);
+                },
+                windowCoder));
+
+ // First process all elements and make sure no more elements can arrive
+    try (RemoteBundle bundle =
+        stageBundleFactory.getBundle(receiverFactory, stateRequestHandler, progressHandler)) {
+      processElements(iterable, bundle);
+    }
+
+    // Finish any pending windows by advancing the input watermark to infinity.
+    timerInternals.advanceInputWatermark(BoundedWindow.TIMESTAMP_MAX_VALUE);
+    // Finally, advance the processing time to infinity to fire any timers.
+    timerInternals.advanceProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
+    timerInternals.advanceSynchronizedProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
+
+    // Now we fire the timers and process elements generated by timers (which may themselves
+    // be timers)
+    try (RemoteBundle bundle =
+        stageBundleFactory.getBundle(receiverFactory, stateRequestHandler, progressHandler)) {
+
+      fireEligibleTimers(
+          timerInternals,
+          (String timerId, WindowedValue timerValue) -> {
+            FnDataReceiver<WindowedValue<?>> fnTimerReceiver =
+                bundle.getInputReceivers().get(timerId);
+            Preconditions.checkNotNull(fnTimerReceiver, "No FnDataReceiver found for %s", timerId);
+            try {
+              fnTimerReceiver.accept(timerValue);
+            } catch (Exception e) {
+              throw new RuntimeException(
+                  String.format(Locale.ENGLISH, "Failed to process timer: %s", timerValue));
+            }
+          });
+    }
}
-  private void processElements(
-      Iterable<WindowedValue<InputT>> iterable, Collector<RawUnionValue> collector)
+  private void processElements(Iterable<WindowedValue<InputT>> iterable, RemoteBundle bundle)
      throws Exception {
-    checkState(
-        runtimeContext == getRuntimeContext(),
-        "RuntimeContext changed from under us. State handler invalid.");
-    checkState(
-        stageBundleFactory != null, "%s not yet prepared", StageBundleFactory.class.getName());
-    checkState(
-        stateRequestHandler != null, "%s not yet prepared", StateRequestHandler.class.getName());
+    Preconditions.checkArgument(bundle != null, "RemoteBundle must not be null");
+
+ String inputPCollectionId = executableStage.getInputPCollection().getId();
+ FnDataReceiver<WindowedValue<?>> mainReceiver =
+ Preconditions.checkNotNull(
+ bundle.getInputReceivers().get(inputPCollectionId),
+ "Main input receiver for %s could not be initialized",
+ inputPCollectionId);
+ for (WindowedValue<InputT> input : iterable) {
+ mainReceiver.accept(input);
+ }
+ }
-    try (RemoteBundle bundle =
-        stageBundleFactory.getBundle(
-            new ReceiverFactory(collector, outputMap), stateRequestHandler, progressHandler)) {
-      // TODO(BEAM-4681): Add support to Flink to support portable timers.
-      FnDataReceiver<WindowedValue<?>> receiver =
-          Iterables.getOnlyElement(bundle.getInputReceivers().values());
-      for (WindowedValue<InputT> input : iterable) {
-        receiver.accept(input);
+  /**
+   * Fires all timers which are ready to be fired. This is done in a loop because firing a timer
+   * may schedule further timers.
+   */
+  private void fireEligibleTimers(
+      InMemoryTimerInternals timerInternals, BiConsumer<String, WindowedValue> timerConsumer) {
+
+ boolean hasFired;
+ do {
+ hasFired = false;
+ TimerInternals.TimerData timer;
+
+ while ((timer = timerInternals.removeNextEventTimer()) != null) {
+ hasFired = true;
+ fireTimer(timer, timerConsumer);
}
- }
-    // NOTE: RemoteBundle.close() blocks on completion of all data receivers. This is necessary to
-    // safely reference the partition-scoped Collector from receivers.
+      while ((timer = timerInternals.removeNextProcessingTimer()) != null) {
+        hasFired = true;
+        fireTimer(timer, timerConsumer);
+      }
+      while ((timer = timerInternals.removeNextSynchronizedProcessingTimer()) != null) {
+ hasFired = true;
+ fireTimer(timer, timerConsumer);
+ }
+ } while (hasFired);
+ }
+
+  private void fireTimer(
+      TimerInternals.TimerData timer, BiConsumer<String, WindowedValue> timerConsumer) {
+    StateNamespace namespace = timer.getNamespace();
+    Preconditions.checkArgument(namespace instanceof StateNamespaces.WindowNamespace);
+    BoundedWindow window = ((StateNamespaces.WindowNamespace) namespace).getWindow();
+ Instant timestamp = timer.getTimestamp();
+ WindowedValue<KV<Object, Timer>> timerValue =
+ WindowedValue.of(
+ KV.of(currentTimerKey, Timer.of(timestamp, new byte[0])),
+ timestamp,
+ Collections.singleton(window),
+ PaneInfo.NO_FIRING);
+ timerConsumer.accept(timer.getTimerId(), timerValue);
}
@Override
@@ -220,7 +326,7 @@ public void close() throws Exception {
  /**
   * Receiver factory that wraps outgoing elements with the corresponding union tag for a
-   * multiplexed PCollection.
+   * multiplexed PCollection and optionally handles timer items.
   */
private static class ReceiverFactory implements OutputReceiverFactory {
@@ -230,20 +336,92 @@ public void close() throws Exception {
private final Collector<RawUnionValue> collector;
private final Map<String, Integer> outputMap;
+ @Nullable private final TimerReceiverFactory timerReceiverFactory;
    ReceiverFactory(Collector<RawUnionValue> collector, Map<String, Integer> outputMap) {
+ this(collector, outputMap, null);
+ }
+
+ ReceiverFactory(
+ Collector<RawUnionValue> collector,
+ Map<String, Integer> outputMap,
+ @Nullable TimerReceiverFactory timerReceiverFactory) {
this.collector = collector;
this.outputMap = outputMap;
+ this.timerReceiverFactory = timerReceiverFactory;
}
@Override
public <OutputT> FnDataReceiver<OutputT> create(String collectionId) {
Integer unionTag = outputMap.get(collectionId);
- checkArgument(unionTag != null, "Unknown PCollection id: %s",
collectionId);
- int tagInt = unionTag;
+ if (unionTag != null) {
+ int tagInt = unionTag;
+ return receivedElement -> {
+ synchronized (collectorLock) {
+ collector.collect(new RawUnionValue(tagInt, receivedElement));
+ }
+ };
+ } else if (timerReceiverFactory != null) {
+ // Delegate to TimerReceiverFactory
+ return timerReceiverFactory.create(collectionId);
+ } else {
+        throw new IllegalStateException(
+            String.format(Locale.ENGLISH, "Unknown PCollectionId %s", collectionId));
+ }
+ }
+ }
+
+ private static class TimerReceiverFactory implements OutputReceiverFactory {
+
+ private final StageBundleFactory stageBundleFactory;
+    /** Timer PCollection id => TimerReference. */
+    private final HashMap<String, ProcessBundleDescriptors.TimerSpec> timerOutputIdToSpecMap;
+    /** Timer PCollection id => timer name => TimerSpec. */
+    private final Map<String, Map<String, ProcessBundleDescriptors.TimerSpec>> timerSpecMap;
+
+    private final BiConsumer<WindowedValue, TimerInternals.TimerData> timerDataConsumer;
+ private final Coder windowCoder;
+
+    TimerReceiverFactory(
+        StageBundleFactory stageBundleFactory,
+        Collection<TimerReference> timerReferenceCollection,
+        Map<String, Map<String, ProcessBundleDescriptors.TimerSpec>> timerSpecMap,
+        BiConsumer<WindowedValue, TimerInternals.TimerData> timerDataConsumer,
+        Coder windowCoder) {
+      this.stageBundleFactory = stageBundleFactory;
+      this.timerOutputIdToSpecMap = new HashMap<>();
+      // Gather all timers from all transforms by their output pCollectionId which is unique
+      for (Map<String, ProcessBundleDescriptors.TimerSpec> transformTimerMap :
+          stageBundleFactory.getProcessBundleDescriptor().getTimerSpecs().values()) {
+        for (ProcessBundleDescriptors.TimerSpec timerSpec : transformTimerMap.values()) {
+          timerOutputIdToSpecMap.put(timerSpec.outputCollectionId(), timerSpec);
+        }
+      }
+      this.timerSpecMap = timerSpecMap;
+      this.timerDataConsumer = timerDataConsumer;
+      this.windowCoder = windowCoder;
+    }
+
+ @Override
+ public <OutputT> FnDataReceiver<OutputT> create(String pCollectionId) {
+ final ProcessBundleDescriptors.TimerSpec timerSpec =
+ timerOutputIdToSpecMap.get(pCollectionId);
+
      return receivedElement -> {
-        synchronized (collectorLock) {
-          collector.collect(new RawUnionValue(tagInt, receivedElement));
+        WindowedValue windowedValue = (WindowedValue) receivedElement;
+        Timer timer =
+            Preconditions.checkNotNull(
+                (Timer) ((KV) windowedValue.getValue()).getValue(),
+                "Received null Timer from SDK harness: %s",
+                receivedElement);
+        LOG.debug("Timer received: {} {}", pCollectionId, timer);
+        for (Object window : windowedValue.getWindows()) {
+          StateNamespace namespace = StateNamespaces.window(windowCoder, (BoundedWindow) window);
+          TimeDomain timeDomain = timerSpec.getTimerSpec().getTimeDomain();
+          String timerId = timerSpec.inputCollectionId();
+          TimerInternals.TimerData timerData =
+              TimerInternals.TimerData.of(timerId, namespace, timer.getTimestamp(), timeDomain);
+          timerDataConsumer.accept(windowedValue, timerData);
        }
      };
    }
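The core of the batch implementation is visible in reduce() above: process all of a key's elements in one bundle, push the watermark and processing time to the end of time, then drain timers until quiescence, since a firing timer may register further timers. A self-contained toy driver of that InMemoryTimerInternals pattern (timer id, namespace, and timestamp are arbitrary here, not taken from the runner):

    import org.apache.beam.runners.core.InMemoryTimerInternals;
    import org.apache.beam.runners.core.StateNamespace;
    import org.apache.beam.runners.core.StateNamespaces;
    import org.apache.beam.runners.core.TimerInternals.TimerData;
    import org.apache.beam.sdk.state.TimeDomain;
    import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
    import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
    import org.joda.time.Instant;

    public class BatchTimerDrainSketch {

      public static void main(String[] args) throws Exception {
        InMemoryTimerInternals timerInternals = new InMemoryTimerInternals();
        StateNamespace namespace =
            StateNamespaces.window(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE);
        timerInternals.setTimer(
            TimerData.of("myTimer", namespace, new Instant(42), TimeDomain.EVENT_TIME));

        // Mirrors reduce(): after the last element of the key, advance to the end of time
        // so every event-time timer becomes eligible.
        timerInternals.advanceInputWatermark(BoundedWindow.TIMESTAMP_MAX_VALUE);

        // Drain in a loop, because firing a timer may set further timers.
        boolean hasFired;
        do {
          hasFired = false;
          TimerData timer;
          while ((timer = timerInternals.removeNextEventTimer()) != null) {
            hasFired = true;
            System.out.println("Fired " + timer.getTimerId() + " at " + timer.getTimestamp());
          }
        } while (hasFired);
      }
    }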
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
index 57650ab0991..1e42d7e6afa 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
@@ -82,6 +82,9 @@
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapState;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.runtime.state.KeyedStateBackend;
import org.apache.flink.runtime.state.StateInitializationContext;
@@ -711,6 +714,7 @@ public void fireTimer(InternalTimer<?, TimerData> timer) {
// This is a user timer, so namespace must be WindowNamespace
checkArgument(namespace instanceof WindowNamespace);
BoundedWindow window = ((WindowNamespace) namespace).getWindow();
+ timerInternals.cleanupPendingTimer(timerData);
pushbackDoFnRunner.onTimer(
        timerData.getTimerId(), window, timerData.getTimestamp(), timerData.getDomain());
}
@@ -918,6 +922,20 @@ public TimerInternals timerInternals() {
class FlinkTimerInternals implements TimerInternals {
+    /**
+     * Pending Timers (=not been fired yet) by context id. The id is generated from the state
+     * namespace of the timer and the timer's id. Necessary for supporting removal of existing
+     * timers. In Flink removal of timers can only be done by providing id and time of the timer.
+     */
+    private final MapState<String, TimerData> pendingTimersById;
+
+ private FlinkTimerInternals() {
+      MapStateDescriptor<String, TimerData> pendingTimersByIdStateDescriptor =
+          new MapStateDescriptor<>(
+              "pending-timers", new StringSerializer(), new CoderTypeSerializer<>(timerCoder));
+      this.pendingTimersById = getKeyedStateStore().getMapState(pendingTimersByIdStateDescriptor);
+ }
+
@Override
public void setTimer(
        StateNamespace namespace, String timerId, Instant target, TimeDomain timeDomain) {
@@ -927,22 +945,57 @@ public void setTimer(
    /** @deprecated use {@link #setTimer(StateNamespace, String, Instant, TimeDomain)}. */
@Deprecated
@Override
- public void setTimer(TimerData timerKey) {
- long time = timerKey.getTimestamp().getMillis();
- switch (timerKey.getDomain()) {
+ public void setTimer(TimerData timer) {
+ try {
+ getKeyedStateBackend().setCurrentKey(getCurrentKey());
+ String contextTimerId = getContextTimerId(timer);
+ // Only one timer can exist at a time for a given timer id and context.
+ // If a timer gets set twice in the same context, the second must
+ // override the first. Thus, we must cancel any pending timers
+ // before we set the new one.
+ cancelPendingTimerById(contextTimerId);
+ registerTimer(timer, contextTimerId);
+ } catch (Exception e) {
+ throw new RuntimeException("Failed to set timer", e);
+ }
+ }
+
+    private void registerTimer(TimerData timer, String contextTimerId) throws Exception {
+ long time = timer.getTimestamp().getMillis();
+ switch (timer.getDomain()) {
case EVENT_TIME:
- timerService.registerEventTimeTimer(timerKey, time);
+ timerService.registerEventTimeTimer(timer, time);
break;
case PROCESSING_TIME:
case SYNCHRONIZED_PROCESSING_TIME:
- timerService.registerProcessingTimeTimer(timerKey, time);
+ timerService.registerProcessingTimeTimer(timer, time);
break;
default:
-          throw new UnsupportedOperationException(
-              "Unsupported time domain: " + timerKey.getDomain());
+          throw new UnsupportedOperationException("Unsupported time domain: " + timer.getDomain());
+ }
+ pendingTimersById.put(contextTimerId, timer);
+ }
+
+    private void cancelPendingTimerById(String contextTimerId) throws Exception {
+ TimerData oldTimer = pendingTimersById.get(contextTimerId);
+ if (oldTimer != null) {
+ deleteTimer(oldTimer);
}
}
+ void cleanupPendingTimer(TimerData timer) {
+ try {
+ pendingTimersById.remove(getContextTimerId(timer));
+ } catch (Exception e) {
+ throw new RuntimeException("Failed to cleanup state with pending
timers", e);
+ }
+ }
+
+    /** Unique contextual id of a timer. Used to look up any existing timers in a context. */
+ private String getContextTimerId(TimerData timer) {
+ return timer.getTimerId() + timer.getNamespace().stringKey();
+ }
+
    /** @deprecated use {@link #deleteTimer(StateNamespace, String, TimeDomain)}. */
@Deprecated
@Override
@@ -959,6 +1012,7 @@ public void deleteTimer(StateNamespace namespace, String timerId, TimeDomain tim
@Deprecated
@Override
public void deleteTimer(TimerData timerKey) {
+ cleanupPendingTimer(timerKey);
long time = timerKey.getTimestamp().getMillis();
switch (timerKey.getDomain()) {
case EVENT_TIME:
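The pendingTimersById MapState above works around a mismatch between the two timer APIs: Beam lets a DoFn overwrite or delete a timer by its id, while Flink's timer service can only delete a timer when given its exact registration timestamp. The operator therefore remembers, per key, the last TimerData registered under each context id (timer id + namespace key). A minimal sketch of that override rule, with a plain HashMap standing in for Flink's keyed MapState (one key in scope, timestamps as plain longs):

    import java.util.HashMap;
    import java.util.Map;

    public class PendingTimerBookkeepingSketch {

      // contextTimerId -> timestamp of the currently pending timer (per key in the real operator)
      private final Map<String, Long> pendingTimersById = new HashMap<>();

      /** Mirrors getContextTimerId: the timer's id concatenated with its namespace key. */
      private static String contextTimerId(String timerId, String namespaceKey) {
        return timerId + namespaceKey;
      }

      void setTimer(String timerId, String namespaceKey, long timestamp) {
        Long oldTimestamp = pendingTimersById.put(contextTimerId(timerId, namespaceKey), timestamp);
        if (oldTimestamp != null) {
          // The real operator deletes the old Flink timer here, which is only possible
          // because the old registration timestamp was remembered.
          System.out.println("Cancelled pending timer at " + oldTimestamp);
        }
      }

      public static void main(String[] args) {
        PendingTimerBookkeepingSketch sketch = new PendingTimerBookkeepingSketch();
        sketch.setTimer("gc", "/window-1/", 100L);
        sketch.setTimer("gc", "/window-1/", 200L); // replaces the timer registered at 100
      }
    }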
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
index 984817fdb66..20717023ef0 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
@@ -20,7 +20,6 @@
import static org.apache.flink.util.Preconditions.checkNotNull;
import com.google.common.base.Preconditions;
-import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
@@ -57,6 +56,7 @@
import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
import org.apache.beam.runners.fnexecution.state.StateRequestHandlers;
import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.state.BagState;
@@ -65,6 +65,7 @@
import org.apache.beam.sdk.transforms.join.RawUnionValue;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollectionView;
@@ -280,8 +281,9 @@ public void setKeyContextElement1(StreamRecord record) throws Exception {
@Override
public void setCurrentKey(Object key) {
-    // We don't need to set anything, the key is set manually on the state backend
-    // This will be called by HeapInternalTimerService before a timer is fired
+    // We don't need to set anything, the key is set manually on the state backend in
+    // the case of state access. For timers, the key will be extracted from the timer
+    // element, i.e. in HeapInternalTimerService
if (!usesTimers) {
      throw new UnsupportedOperationException(
          "Current key for state backend can only be set by state requests from SDK workers or when processing timers.");
@@ -291,7 +293,7 @@ public void setCurrentKey(Object key) {
@Override
public Object getCurrentKey() {
    // This is the key retrieved by HeapInternalTimerService when setting a Flink timer
- return sdkHarnessRunner.getTimerKeyForRegistration();
+ return sdkHarnessRunner.getCurrentTimerKey();
}
@Override
@@ -300,17 +302,16 @@ public void fireTimer(InternalTimer<?, TimerInternals.TimerData> timer) {
final ByteBuffer encodedKey = (ByteBuffer) timer.getKey();
@SuppressWarnings("ByteBufferBackingArray")
byte[] bytes = encodedKey.array();
- ByteArrayInputStream byteStream = new ByteArrayInputStream(bytes);
final Object decodedKey;
try {
- decodedKey = keyCoder.decode(byteStream);
- } catch (IOException e) {
+ decodedKey = CoderUtils.decodeFromByteArray(keyCoder, bytes);
+ } catch (CoderException e) {
      throw new RuntimeException(
-          String.format(
-              Locale.ENGLISH, "Failed to decode encoded key: %s", Arrays.toString(bytes)));
+          String.format(Locale.ENGLISH, "Failed to decode encoded key: %s", Arrays.toString(bytes)),
+          e);
}
// Prepare the SdkHarnessRunner with the key for the timer
- sdkHarnessRunner.setTimerKeyForFire(decodedKey);
+ sdkHarnessRunner.setCurrentTimerKey(decodedKey);
super.fireTimer(timer);
}
@@ -428,11 +429,11 @@ public void processWatermark(Watermark mark) throws Exception {
private RemoteBundle remoteBundle;
private FnDataReceiver<WindowedValue<?>> mainInputReceiver;
private Runnable bundleFinishedCallback;
- // Timer key set before calling Flink's internal timer service. Used to
- // avoid synchronizing on the state backend.
- private Object keyForTimerToBeSet;
- // Set before calling onTimer
- private Object keyForTimerToBeFired;
+ // Timer key set before calling Flink's internal timer service to register
+    // a timer. The timer service will retrieve this with a call to {@code getCurrentKey}.
+ // Before firing a timer, this will be initialized with the current key
+ // from the timer element.
+ private Object currentTimerKey;
public SdkHarnessDoFnRunner(
String mainInput,
@@ -505,7 +506,7 @@ public void processElement(WindowedValue<InputT> element) {
    public void onTimer(
        String timerId, BoundedWindow window, Instant timestamp, TimeDomain timeDomain) {
      Preconditions.checkNotNull(
-          keyForTimerToBeFired, "Key for timer needs to be set before calling onTimer");
+          currentTimerKey, "Key for timer needs to be set before calling onTimer");
      LOG.debug("timer callback: {} {} {} {}", timerId, window, timestamp, timeDomain);
FnDataReceiver<WindowedValue<?>> timerReceiver =
Preconditions.checkNotNull(
@@ -514,7 +515,7 @@ public void onTimer(
timerId);
WindowedValue<KV<Object, Timer>> timerValue =
WindowedValue.of(
- KV.of(keyForTimerToBeFired, Timer.of(timestamp, new byte[0])),
+ KV.of(currentTimerKey, Timer.of(timestamp, new byte[0])),
timestamp,
Collections.singleton(window),
PaneInfo.NO_FIRING);
@@ -524,7 +525,7 @@ public void onTimer(
throw new RuntimeException(
String.format(Locale.ENGLISH, "Failed to process timer %s",
timerReceiver), e);
} finally {
- keyForTimerToBeFired = null;
+ currentTimerKey = null;
}
}
@@ -547,13 +548,13 @@ public void finishBundle() {
}
/** Key for timer which has not been registered yet. */
- Object getTimerKeyForRegistration() {
- return keyForTimerToBeSet;
+ Object getCurrentTimerKey() {
+ return currentTimerKey;
}
/** Key for timer which is about to be fired. */
- void setTimerKeyForFire(Object key) {
- this.keyForTimerToBeFired = key;
+ void setCurrentTimerKey(Object key) {
+ this.currentTimerKey = key;
}
boolean isBundleInProgress() {
@@ -605,12 +606,12 @@ private void emitResults() {
    private void setTimer(WindowedValue timerElement, TimerInternals.TimerData timerData) {
try {
- keyForTimerToBeSet = keySelector.getKey(timerElement);
+ currentTimerKey = keySelector.getKey(timerElement);
timerInternals.setTimer(timerData);
} catch (Exception e) {
throw new RuntimeException("Couldn't set timer", e);
} finally {
- keyForTimerToBeSet = null;
+ currentTimerKey = null;
}
}
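A side note on the key-decoding cleanup in fireTimer above: CoderUtils.decodeFromByteArray replaces the hand-rolled ByteArrayInputStream decode, and the operator now attaches the original CoderException as the cause of the rethrown RuntimeException. A small round-trip illustration of the decode path (StringUtf8Coder chosen purely for the example):

    import org.apache.beam.sdk.coders.CoderException;
    import org.apache.beam.sdk.coders.StringUtf8Coder;
    import org.apache.beam.sdk.util.CoderUtils;

    public class KeyCoderRoundTripSketch {

      public static void main(String[] args) throws CoderException {
        byte[] encodedKey = CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "key-1");
        // The decode path the operator now uses for a fired timer's key.
        String decodedKey = CoderUtils.decodeFromByteArray(StringUtf8Coder.of(), encodedKey);
        System.out.println(decodedKey); // prints: key-1
      }
    }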
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/PortableTimersExecutionTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/PortableTimersExecutionTest.java
index 1d5226f7e1d..f3f04b0b340 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/PortableTimersExecutionTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/PortableTimersExecutionTest.java
@@ -66,9 +66,8 @@
public class PortableTimersExecutionTest implements Serializable {
@Parameters
- // TODO(mxm) enable tor batch
public static Object[] testModes() {
- return new Object[] {true};
+ return new Object[] {true, false};
}
@Parameter public boolean isStreaming;
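For reference, the surrounding parameterization this one-line change activates: the class runs once with isStreaming = true and once with false, so the same timer pipeline is now validated on both the streaming and the batch path. Condensed shape of the test class (test bodies elided; the class name here is illustrative):

    import java.io.Serializable;
    import org.junit.runner.RunWith;
    import org.junit.runners.Parameterized;
    import org.junit.runners.Parameterized.Parameter;
    import org.junit.runners.Parameterized.Parameters;

    @RunWith(Parameterized.class)
    public class PortableTimersExecutionTestShape implements Serializable {

      @Parameters
      public static Object[] testModes() {
        // Streaming and batch; before this change only {true} was returned.
        return new Object[] {true, false};
      }

      @Parameter public boolean isStreaming;

      // ... test methods configure the portable Flink runner according to isStreaming.
    }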
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java
index 33832ce0594..8c4d930d432 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java
@@ -503,7 +504,8 @@ public void onTimer(OnTimerContext context, @StateId(stateId) ValueState<String>
            WindowedValue.of(
                KV.of("key2", 7 + offset), new Instant(3), window1, PaneInfo.NO_FIRING)));
-    assertEquals(2, testHarness.numKeyedStateEntries());
+    // 2 entries for the elements and 2 for the pending timers
+    assertEquals(4, testHarness.numKeyedStateEntries());
    testHarness.getOutput().clear();
@@ -527,7 +528,7 @@ public void onTimer(OnTimerContext context, @StateId(stateId) ValueState<String>
            WindowedValue.of(
                KV.of("key2", timerOutput), new Instant(9), window1, PaneInfo.NO_FIRING)));
-    // ensure the state was garbage collected
+    // ensure the state was garbage collected and the pending timers have been removed
    assertEquals(0, testHarness.numKeyedStateEntries());
    testHarness.close();
@@ -567,7 +568,7 @@ void testSideInputs(boolean keyed) throws Exception {
            outputTag,
            Collections.emptyList(),
            new DoFnOperator.MultiOutputOutputManagerFactory<>(outputTag, coder),
-            WindowingStrategy.globalDefault(),
+            WindowingStrategy.of(FixedWindows.of(Duration.millis(100))),
            sideInputMapping, /* side-input mapping */
            ImmutableList.of(view1, view2), /* side inputs */
            PipelineOptionsFactory.as(FlinkPipelineOptions.class),
@@ -774,7 +775,7 @@ public void keyedParDoSideInputCheckpointing() throws Exception {
                outputTag,
                Collections.emptyList(),
                new DoFnOperator.MultiOutputOutputManagerFactory<>(outputTag, coder),
-                WindowingStrategy.globalDefault(),
+                WindowingStrategy.of(FixedWindows.of(Duration.millis(100))),
                sideInputMapping, /* side-input mapping */
                ImmutableList.of(view1, view2), /* side inputs */
                PipelineOptionsFactory.as(FlinkPipelineOptions.class),
@@ -870,7 +871,7 @@ public void nonKeyedParDoPushbackDataCheckpointing() throws Exception {
                outputTag,
                Collections.emptyList(),
                new DoFnOperator.MultiOutputOutputManagerFactory<>(outputTag, coder),
-                WindowingStrategy.globalDefault(),
+                WindowingStrategy.of(FixedWindows.of(Duration.millis(100))),
                sideInputMapping, /* side-input mapping */
                ImmutableList.of(view1, view2), /* side inputs */
                PipelineOptionsFactory.as(FlinkPipelineOptions.class),
@@ -908,7 +909,7 @@ public void keyedParDoPushbackDataCheckpointing() throws Exception {
                outputTag,
                Collections.emptyList(),
                new DoFnOperator.MultiOutputOutputManagerFactory<>(outputTag, coder),
-                WindowingStrategy.globalDefault(),
+                WindowingStrategy.of(FixedWindows.of(Duration.millis(100))),
                sideInputMapping, /* side-input mapping */
                ImmutableList.of(view1, view2), /* side inputs */
                PipelineOptionsFactory.as(FlinkPipelineOptions.class),
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java
index 81ebb898e7e..e50dde4ea9c 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.runners.flink.translation.functions;
+import static org.apache.beam.runners.core.construction.PTransformTranslation.PAR_DO_TRANSFORM_URN;
import static org.hamcrest.Matchers.is;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doThrow;
@@ -28,6 +29,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
@@ -46,7 +48,6 @@
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;
-import org.hamcrest.Matchers;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@@ -77,6 +78,7 @@
@Mock private FlinkExecutableStageContext stageContext;
@Mock private StageBundleFactory stageBundleFactory;
@Mock private StateRequestHandler stateRequestHandler;
+  @Mock private ProcessBundleDescriptors.ExecutableProcessBundleDescriptor processBundleDescriptor;
  // NOTE: ExecutableStage.fromPayload expects exactly one input, so we provide one here. These unit
  // tests in general ignore the executable stage itself and mock around it.
@@ -85,17 +87,31 @@
.setInput("input")
.setComponents(
Components.newBuilder()
+          .putTransforms(
+              "transform",
+              RunnerApi.PTransform.newBuilder()
+                  .putInputs("bla", "input")
+                  .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(PAR_DO_TRANSFORM_URN))
+                  .build())
          .putPcollections("input", PCollection.getDefaultInstance())
          .build())
+      .addUserStates(
+          ExecutableStagePayload.UserStateId.newBuilder().setTransformId("transform").build())
      .build();
  private final JobInfo jobInfo =
      JobInfo.create("job-id", "job-name", "retrieval-token", Struct.getDefaultInstance());
@Before
-  public void setUpMocks() {
+  public void setUpMocks() throws Exception {
    MockitoAnnotations.initMocks(this);
    when(runtimeContext.getDistributedCache()).thenReturn(distributedCache);
    when(stageContext.getStageBundleFactory(any())).thenReturn(stageBundleFactory);
+    RemoteBundle remoteBundle = Mockito.mock(RemoteBundle.class);
+    when(stageBundleFactory.getBundle(any(), any(), any())).thenReturn(remoteBundle);
+    ImmutableMap input =
+        ImmutableMap.builder().put("input", Mockito.mock(FnDataReceiver.class)).build();
+    when(remoteBundle.getInputReceivers()).thenReturn(input);
+    when(processBundleDescriptor.getTimerSpecs()).thenReturn(Collections.emptyMap());
  }
@Test
@@ -109,7 +125,7 @@ public void sdkErrorsSurfaceOnClose() throws Exception {
@SuppressWarnings("unchecked")
FnDataReceiver<WindowedValue<?>> receiver =
Mockito.mock(FnDataReceiver.class);
-    when(bundle.getInputReceivers()).thenReturn(ImmutableMap.of("pCollectionId", receiver));
+    when(bundle.getInputReceivers()).thenReturn(ImmutableMap.of("input", receiver));
Exception expected = new Exception();
doThrow(expected).when(bundle).close();
@@ -117,16 +133,6 @@ public void sdkErrorsSurfaceOnClose() throws Exception {
function.mapPartition(Collections.emptyList(), collector);
}
-  @Test
-  public void checksForRuntimeContextChanges() throws Exception {
-    FlinkExecutableStageFunction<Integer> function = getFunction(Collections.emptyMap());
-    function.open(new Configuration());
-    // Change runtime context.
-    function.setRuntimeContext(Mockito.mock(RuntimeContext.class));
-    thrown.expect(Matchers.instanceOf(IllegalStateException.class));
-    function.mapPartition(Collections.emptyList(), collector);
-  }
-
@Test
public void expectedInputsAreSent() throws Exception {
    FlinkExecutableStageFunction<Integer> function = getFunction(Collections.emptyMap());
@@ -138,7 +144,7 @@ public void expectedInputsAreSent() throws Exception {
@SuppressWarnings("unchecked")
FnDataReceiver<WindowedValue<?>> receiver =
Mockito.mock(FnDataReceiver.class);
-    when(bundle.getInputReceivers()).thenReturn(ImmutableMap.of("pCollectionId", receiver));
+    when(bundle.getInputReceivers()).thenReturn(ImmutableMap.of("input", receiver));
WindowedValue<Integer> one = WindowedValue.valueInGlobalWindow(1);
WindowedValue<Integer> two = WindowedValue.valueInGlobalWindow(2);
@@ -165,6 +171,9 @@ public void outputsAreTaggedCorrectly() throws Exception {
    // We use a real StageBundleFactory here in order to exercise the output receiver factory.
StageBundleFactory stageBundleFactory =
new StageBundleFactory() {
+
+ private boolean once;
+
@Override
public RemoteBundle getBundle(
OutputReceiverFactory receiverFactory,
@@ -179,7 +188,7 @@ public String getId() {
@Override
          public Map<String, FnDataReceiver<WindowedValue<?>>> getInputReceivers() {
return ImmutableMap.of(
- "pCollectionId",
+ "input",
input -> {
/* Ignore input*/
});
@@ -187,10 +196,14 @@ public String getId() {
@Override
public void close() throws Exception {
+ if (once) {
+ return;
+ }
// Emit all values to the runner when the bundle is closed.
receiverFactory.create("one").accept(three);
receiverFactory.create("two").accept(four);
receiverFactory.create("three").accept(five);
+ once = true;
}
};
}
@@ -198,7 +211,7 @@ public void close() throws Exception {
@Override
      public ProcessBundleDescriptors.ExecutableProcessBundleDescriptor getProcessBundleDescriptor() {
- return null;
+ return processBundleDescriptor;
}
@Override
@@ -243,8 +256,7 @@ public void testStageBundleClosed() throws Exception {
Mockito.mock(FlinkExecutableStageContext.Factory.class);
when(contextFactory.get(any())).thenReturn(stageContext);
    FlinkExecutableStageFunction<Integer> function =
-        new FlinkExecutableStageFunction<>(
-            stagePayload, jobInfo, outputMap, contextFactory, isStateful);
+        new FlinkExecutableStageFunction<>(stagePayload, jobInfo, outputMap, contextFactory, null);
    function.setRuntimeContext(runtimeContext);
    Whitebox.setInternalState(function, "stateRequestHandler", stateRequestHandler);
return function;
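The recurring "pCollectionId" to "input" renames above follow from processElements now looking up the receiver by the stage's declared input PCollection id instead of taking the only entry of the receiver map. A standalone sketch of the Mockito stubbing this requires (the id matches the test's stage payload):

    import static org.mockito.Mockito.mock;
    import static org.mockito.Mockito.when;

    import com.google.common.collect.ImmutableMap;
    import org.apache.beam.runners.fnexecution.control.RemoteBundle;
    import org.apache.beam.sdk.fn.data.FnDataReceiver;
    import org.apache.beam.sdk.util.WindowedValue;

    public class BundleStubSketch {

      @SuppressWarnings("unchecked")
      static RemoteBundle stubBundle() {
        FnDataReceiver<WindowedValue<?>> receiver = mock(FnDataReceiver.class);
        RemoteBundle bundle = mock(RemoteBundle.class);
        // Keyed by PCollection id; must match the stage's input ("input" in these tests),
        // otherwise processElements fails its checkNotNull on the main receiver.
        when(bundle.getInputReceivers()).thenReturn(ImmutableMap.of("input", receiver));
        return bundle;
      }
    }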
diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
index 942faed0554..4ccf89de8d3 100644
--- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
@@ -107,10 +107,6 @@ def test_read(self):
def test_no_subtransform_composite(self):
raise unittest.SkipTest("BEAM-4781")
- def test_pardo_timers(self):
- # TODO Enable once BEAM-5999 is fixed.
- raise unittest.SkipTest("BEAM-4681 - User timers not yet supported.")
-
def test_assert_that(self):
# We still want to make sure asserts fail, even if the message
# isn't right (BEAM-6019).
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
Issue Time Tracking
-------------------
Worklog Id: (was: 167434)
Time Spent: 25h 10m (was: 25h)
> Integrate support for timers using the portability APIs into Flink
> ------------------------------------------------------------------
>
> Key: BEAM-4681
> URL: https://issues.apache.org/jira/browse/BEAM-4681
> Project: Beam
> Issue Type: Sub-task
> Components: runner-flink
> Reporter: Luke Cwik
> Assignee: Maximilian Michels
> Priority: Major
> Labels: portability, portability-flink
> Fix For: 2.9.0
>
> Time Spent: 25h 10m
> Remaining Estimate: 0h
>
> Consider using the code produced in BEAM-4658 to support timers.
--
This message was sent by Atlassian JIRA
(v7.6.3#76005)