This is an automated email from the ASF dual-hosted git repository.
iemejia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new c45a50f [BEAM-6935] Spark portable runner: implement side inputs
new 16b58ad Merge pull request #8220: [BEAM-6935] Spark portable runner: implement side inputs
c45a50f is described below
commit c45a50fb092171ce4fa5f8b0758a584911d4f50d
Author: Kyle Weaver <[email protected]>
AuthorDate: Thu Mar 28 19:16:51 2019 -0700
[BEAM-6935] Spark portable runner: implement side inputs
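
    This change generalizes the Flink runner's FlinkBatchSideInputHandlerFactory
    into a runner-agnostic BatchSideInputHandlerFactory (runners/java-fn-execution)
    that resolves side input contents through a small SideInputGetter interface.
    The Spark translator collects each side input PCollection, encodes it with its
    runner wire coder, and broadcasts the serialized bytes;
    SparkExecutableStageFunction then decodes the broadcast on the executor and
    serves the values to the SDK harness via a MULTIMAP_SIDE_INPUT state request
    handler.

    A minimal sketch of the new extension point (the materialize() helper below is
    hypothetical; Flink passes runtimeContext::getBroadcastVariable, while Spark
    decodes broadcast byte arrays as in this commit):

        // Build a side input handler factory for one executable stage.
        // `materialize` stands in for a runner-specific lookup from a
        // PCollection id to that collection's collected elements.
        StateRequestHandlers.SideInputHandlerFactory sideInputHandlerFactory =
            BatchSideInputHandlerFactory.forStage(
                executableStage,
                new BatchSideInputHandlerFactory.SideInputGetter() {
                  @Override
                  public <T> List<T> getSideInput(String pCollectionId) {
                    return materialize(pCollectionId);
                  }
                });
        // Wired into state request handling as in the diff (may throw IOException):
        // StateRequestHandlers.forSideInputHandlerFactory(
        //     ProcessBundleDescriptors.getSideInputs(executableStage), sideInputHandlerFactory);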
---
.../functions/FlinkExecutableStageFunction.java | 4 +-
.../translation/BatchSideInputHandlerFactory.java} | 35 ++++++-------
.../BatchSideInputHandlerFactoryTest.java} | 40 +++++++--------
.../runners/spark/translation/BoundedDataset.java | 9 ++++
.../SparkBatchPortablePipelineTranslator.java | 47 +++++++++++++++--
.../translation/SparkExecutableStageFunction.java | 59 +++++++++++++++++++---
.../runners/spark/SparkPortableExecutionTest.java | 36 +++++++++----
.../SparkExecutableStageFunctionTest.java | 15 +++---
8 files changed, 181 insertions(+), 64 deletions(-)
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
index e7dafa8..c02aa65 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
@@ -54,6 +54,7 @@ import org.apache.beam.runners.fnexecution.control.StageBundleFactory;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
import org.apache.beam.runners.fnexecution.state.StateRequestHandlers;
+import org.apache.beam.runners.fnexecution.translation.BatchSideInputHandlerFactory;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.io.FileSystems;
@@ -167,7 +168,8 @@ public class FlinkExecutableStageFunction<InputT> extends AbstractRichFunction
RuntimeContext runtimeContext) {
final StateRequestHandler sideInputHandler;
StateRequestHandlers.SideInputHandlerFactory sideInputHandlerFactory =
- FlinkBatchSideInputHandlerFactory.forStage(executableStage, runtimeContext);
+ BatchSideInputHandlerFactory.forStage(
+ executableStage, runtimeContext::getBroadcastVariable);
try {
sideInputHandler =
StateRequestHandlers.forSideInputHandlerFactory(
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/translation/BatchSideInputHandlerFactory.java
similarity index 87%
rename from runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactory.java
rename to runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/translation/BatchSideInputHandlerFactory.java
index 798c32b..5460898 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactory.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/translation/BatchSideInputHandlerFactory.java
@@ -15,7 +15,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.beam.runners.flink.translation.functions;
+package org.apache.beam.runners.fnexecution.translation;
import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument;
@@ -43,24 +43,25 @@ import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMultimap;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Multimap;
-import org.apache.flink.api.common.functions.RuntimeContext;
-/**
- * {@link StateRequestHandler} that uses a Flink {@link RuntimeContext} to access Flink broadcast
- * variable that represent side inputs.
- */
-class FlinkBatchSideInputHandlerFactory implements SideInputHandlerFactory {
+/** {@link StateRequestHandler} that uses a {@link SideInputGetter} to access side inputs. */
+public class BatchSideInputHandlerFactory implements SideInputHandlerFactory {
// Map from side input id to global PCollection id.
private final Map<SideInputId, PCollectionNode> sideInputToCollection;
- private final RuntimeContext runtimeContext;
+ private final SideInputGetter sideInputGetter;
+
+ /** Returns the value for the side input with the given PCollection id from the runner. */
+ public interface SideInputGetter {
+ <T> List<T> getSideInput(String pCollectionId);
+ }
/**
* Creates a new state handler for the given stage. Note that this requires a traversal of the
* stage itself, so this should only be called once per stage rather than once per bundle.
*/
- static FlinkBatchSideInputHandlerFactory forStage(
- ExecutableStage stage, RuntimeContext runtimeContext) {
+ public static BatchSideInputHandlerFactory forStage(
+ ExecutableStage stage, SideInputGetter sideInputGetter) {
ImmutableMap.Builder<SideInputId, PCollectionNode> sideInputBuilder = ImmutableMap.builder();
for (SideInputReference sideInput : stage.getSideInputs()) {
sideInputBuilder.put(
@@ -70,13 +71,13 @@ class FlinkBatchSideInputHandlerFactory implements SideInputHandlerFactory {
.build(),
sideInput.collection());
}
- return new FlinkBatchSideInputHandlerFactory(sideInputBuilder.build(), runtimeContext);
+ return new BatchSideInputHandlerFactory(sideInputBuilder.build(), sideInputGetter);
}
- private FlinkBatchSideInputHandlerFactory(
- Map<SideInputId, PCollectionNode> sideInputToCollection, RuntimeContext runtimeContext) {
+ private BatchSideInputHandlerFactory(
+ Map<SideInputId, PCollectionNode> sideInputToCollection, SideInputGetter sideInputGetter) {
this.sideInputToCollection = sideInputToCollection;
- this.runtimeContext = runtimeContext;
+ this.sideInputGetter = sideInputGetter;
}
@Override
@@ -96,7 +97,7 @@ class FlinkBatchSideInputHandlerFactory implements SideInputHandlerFactory {
@SuppressWarnings("unchecked") // T == V
Coder<V> outputCoder = (Coder<V>) elementCoder;
return forIterableSideInput(
- runtimeContext.getBroadcastVariable(collectionNode.getId()), outputCoder, windowCoder);
+ sideInputGetter.getSideInput(collectionNode.getId()), outputCoder, windowCoder);
} else if (PTransformTranslation.MULTIMAP_SIDE_INPUT.equals(accessPattern.getUrn())
|| Materializations.MULTIMAP_MATERIALIZATION_URN.equals(accessPattern.getUrn())) {
// TODO: Remove non standard URN.
@@ -104,7 +105,7 @@ class FlinkBatchSideInputHandlerFactory implements SideInputHandlerFactory {
@SuppressWarnings("unchecked") // T == KV<?, V>
KvCoder<?, V> kvCoder = (KvCoder<?, V>) elementCoder;
return forMultimapSideInput(
- runtimeContext.getBroadcastVariable(collectionNode.getId()),
+ sideInputGetter.getSideInput(collectionNode.getId()),
kvCoder.getKeyCoder(),
kvCoder.getValueCoder(),
windowCoder);
@@ -202,7 +203,7 @@ class FlinkBatchSideInputHandlerFactory implements SideInputHandlerFactory {
@AutoValue
abstract static class SideInputKey {
static SideInputKey of(Object key, Object window) {
- return new AutoValue_FlinkBatchSideInputHandlerFactory_SideInputKey(key, window);
+ return new AutoValue_BatchSideInputHandlerFactory_SideInputKey(key, window);
}
@Nullable
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactoryTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/translation/BatchSideInputHandlerFactoryTest.java
similarity index 89%
rename from runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactoryTest.java
rename to runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/translation/BatchSideInputHandlerFactoryTest.java
index 897289f..f664aa9 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactoryTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/translation/BatchSideInputHandlerFactoryTest.java
@@ -15,7 +15,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.beam.runners.flink.translation.functions;
+package org.apache.beam.runners.fnexecution.translation;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
@@ -50,7 +50,6 @@ import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCod
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
-import org.apache.flink.api.common.functions.RuntimeContext;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.joda.time.Instant;
@@ -63,9 +62,9 @@ import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
-/** Tests for {@link FlinkBatchSideInputHandlerFactory}. */
+/** Tests for {@link BatchSideInputHandlerFactory}. */
@RunWith(JUnit4.class)
-public class FlinkBatchSideInputHandlerFactoryTest {
+public class BatchSideInputHandlerFactoryTest {
private static final String TRANSFORM_ID = "transform-id";
private static final String SIDE_INPUT_NAME = "side-input";
@@ -87,7 +86,7 @@ public class FlinkBatchSideInputHandlerFactoryTest {
@Rule public ExpectedException thrown = ExpectedException.none();
- @Mock private RuntimeContext context;
+ @Mock private BatchSideInputHandlerFactory.SideInputGetter context;
@Before
public void setUpMocks() {
@@ -97,8 +96,7 @@ public class FlinkBatchSideInputHandlerFactoryTest {
@Test
public void invalidSideInputThrowsException() {
ExecutableStage stage = createExecutableStage(Collections.emptyList());
- FlinkBatchSideInputHandlerFactory factory =
- FlinkBatchSideInputHandlerFactory.forStage(stage, context);
+ BatchSideInputHandlerFactory factory = BatchSideInputHandlerFactory.forStage(stage, context);
thrown.expect(instanceOf(IllegalArgumentException.class));
factory.forSideInput(
"transform-id",
@@ -110,8 +108,8 @@ public class FlinkBatchSideInputHandlerFactoryTest {
@Test
public void emptyResultForEmptyCollection() {
- FlinkBatchSideInputHandlerFactory factory =
- FlinkBatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
+ BatchSideInputHandlerFactory factory =
+ BatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
SideInputHandler<Integer, GlobalWindow> handler =
factory.forSideInput(
TRANSFORM_ID,
@@ -127,12 +125,12 @@ public class FlinkBatchSideInputHandlerFactoryTest {
@Test
public void singleElementForCollection() {
- when(context.getBroadcastVariable(COLLECTION_ID))
+ when(context.getSideInput(COLLECTION_ID))
.thenReturn(
Arrays.asList(WindowedValue.valueInGlobalWindow(KV.<Void, Integer>of(null, 3))));
- FlinkBatchSideInputHandlerFactory factory =
- FlinkBatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
+ BatchSideInputHandlerFactory factory =
+ BatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
SideInputHandler<Integer, GlobalWindow> handler =
factory.forSideInput(
TRANSFORM_ID,
@@ -146,15 +144,15 @@ public class FlinkBatchSideInputHandlerFactoryTest {
@Test
public void groupsValuesByKey() {
- when(context.getBroadcastVariable(COLLECTION_ID))
+ when(context.getSideInput(COLLECTION_ID))
.thenReturn(
Arrays.asList(
WindowedValue.valueInGlobalWindow(KV.of("foo", 2)),
WindowedValue.valueInGlobalWindow(KV.of("bar", 3)),
WindowedValue.valueInGlobalWindow(KV.of("foo", 5))));
- FlinkBatchSideInputHandlerFactory factory =
- FlinkBatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
+ BatchSideInputHandlerFactory factory =
+ BatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
SideInputHandler<Integer, GlobalWindow> handler =
factory.forSideInput(
TRANSFORM_ID,
@@ -173,7 +171,7 @@ public class FlinkBatchSideInputHandlerFactoryTest {
Instant instantC = new DateTime(2018, 1, 1, 1, 3, DateTimeZone.UTC).toInstant();
IntervalWindow windowA = new IntervalWindow(instantA, instantB);
IntervalWindow windowB = new IntervalWindow(instantB, instantC);
- when(context.getBroadcastVariable(COLLECTION_ID))
+ when(context.getSideInput(COLLECTION_ID))
.thenReturn(
Arrays.asList(
WindowedValue.of(KV.of("foo", 1), instantA, windowA,
PaneInfo.NO_FIRING),
@@ -183,8 +181,8 @@ public class FlinkBatchSideInputHandlerFactoryTest {
WindowedValue.of(KV.of("bar", 5), instantB, windowB,
PaneInfo.NO_FIRING),
WindowedValue.of(KV.of("foo", 6), instantB, windowB,
PaneInfo.NO_FIRING)));
- FlinkBatchSideInputHandlerFactory factory =
- FlinkBatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
+ BatchSideInputHandlerFactory factory =
+ BatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
SideInputHandler<Integer, IntervalWindow> handler =
factory.forSideInput(
TRANSFORM_ID,
@@ -205,7 +203,7 @@ public class FlinkBatchSideInputHandlerFactoryTest {
Instant instantC = new DateTime(2018, 1, 1, 1, 3, DateTimeZone.UTC).toInstant();
IntervalWindow windowA = new IntervalWindow(instantA, instantB);
IntervalWindow windowB = new IntervalWindow(instantB, instantC);
- when(context.getBroadcastVariable(COLLECTION_ID))
+ when(context.getSideInput(COLLECTION_ID))
.thenReturn(
Arrays.asList(
WindowedValue.of(1, instantA, windowA, PaneInfo.NO_FIRING),
@@ -213,8 +211,8 @@ public class FlinkBatchSideInputHandlerFactoryTest {
WindowedValue.of(3, instantB, windowB, PaneInfo.NO_FIRING),
WindowedValue.of(4, instantB, windowB, PaneInfo.NO_FIRING)));
- FlinkBatchSideInputHandlerFactory factory =
- FlinkBatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
+ BatchSideInputHandlerFactory factory =
+ BatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
SideInputHandler<Integer, IntervalWindow> handler =
factory.forSideInput(
TRANSFORM_ID,
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java
index 1e620e7..c81c5f4 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java
@@ -46,6 +46,7 @@ public class BoundedDataset<T> implements Dataset {
private Iterable<WindowedValue<T>> windowedValues;
private Coder<T> coder;
private JavaRDD<WindowedValue<T>> rdd;
+ private List<byte[]> clientBytes;
BoundedDataset(JavaRDD<WindowedValue<T>> rdd) {
this.rdd = rdd;
@@ -69,6 +70,14 @@ public class BoundedDataset<T> implements Dataset {
return rdd;
}
+ List<byte[]> getBytes(WindowedValue.WindowedValueCoder<T> wvCoder) {
+ if (clientBytes == null) {
+ JavaRDDLike<byte[], ?> bytesRDD = rdd.map(CoderHelpers.toByteFunction(wvCoder));
+ clientBytes = bytesRDD.collect();
+ }
+ return clientBytes;
+ }
+
Iterable<WindowedValue<T>> getValues(PCollection<T> pcollection) {
if (windowedValues == null) {
WindowFn<?, ?> windowFn = pcollection.getWindowingStrategy().getWindowFn();
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
index c65caa4..82557ae 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
@@ -22,10 +22,12 @@ import static org.apache.beam.runners.fnexecution.translation.PipelineTranslator
import java.io.IOException;
import java.util.Collections;
+import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.runners.core.SystemReduceFn;
import org.apache.beam.runners.core.construction.PTransformTranslation;
@@ -54,6 +56,8 @@ import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;
import org.apache.spark.HashPartitioner;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.broadcast.Broadcast;
+import scala.Tuple2;
/** Translates a bounded portable pipeline into a Spark job. */
public class SparkBatchPortablePipelineTranslator {
@@ -163,7 +167,7 @@ public class SparkBatchPortablePipelineTranslator {
context.pushDataset(getOutputId(transformNode), new BoundedDataset<>(groupedByKeyAndWindow));
}
- private static <InputT, OutputT> void translateExecutableStage(
+ private static <InputT, OutputT, SideInputT> void translateExecutableStage(
PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
RunnerApi.ExecutableStagePayload stagePayload;
@@ -180,8 +184,22 @@ public class SparkBatchPortablePipelineTranslator {
Map<String, String> outputs = transformNode.getTransform().getOutputsMap();
BiMap<String, Integer> outputMap = createOutputMap(outputs.values());
- SparkExecutableStageFunction<InputT> function =
- new SparkExecutableStageFunction<>(stagePayload, context.jobInfo, outputMap);
+ ImmutableMap.Builder<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
+ broadcastVariablesBuilder = ImmutableMap.builder();
+ for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
+ RunnerApi.Components components = stagePayload.getComponents();
+ String collectionId =
+ components
+ .getTransformsOrThrow(sideInputId.getTransformId())
+ .getInputsOrThrow(sideInputId.getLocalName());
+ Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> tuple2 = broadcastSideInput(collectionId, components, context);
+ broadcastVariablesBuilder.put(collectionId, tuple2);
+ }
+
+ SparkExecutableStageFunction<InputT, SideInputT> function =
+ new SparkExecutableStageFunction<>(
+ stagePayload, context.jobInfo, outputMap, broadcastVariablesBuilder.build());
JavaRDD<RawUnionValue> staged = inputRdd.mapPartitions(function);
for (String outputId : outputs.values()) {
@@ -191,6 +209,29 @@ public class SparkBatchPortablePipelineTranslator {
}
}
+ /**
+ * Collect and serialize the data and then broadcast the result. *This can be expensive.*
+ *
+ * @return Spark broadcast variable and coder to decode its contents
+ */
+ private static <T> Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<T>> broadcastSideInput(
+ String collectionId, RunnerApi.Components components, SparkTranslationContext context) {
+ PCollection collection = components.getPcollectionsOrThrow(collectionId);
+ @SuppressWarnings("unchecked")
+ BoundedDataset<T> dataset = (BoundedDataset<T>) context.popDataset(collectionId);
+ PCollectionNode collectionNode = PipelineNode.pCollection(collectionId, collection);
+ WindowedValueCoder<T> coder;
+ try {
+ coder =
+ (WindowedValueCoder<T>) WireCoders.instantiateRunnerWireCoder(collectionNode, components);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ List<byte[]> bytes = dataset.getBytes(coder);
+ Broadcast<List<byte[]>> broadcast = context.getSparkContext().broadcast(bytes);
+ return new Tuple2<>(broadcast, coder);
+ }
+
@Nullable
private static Partitioner getPartitioner(SparkTranslationContext context) {
Long bundleSize =
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunction.java
index 93250bc..e9ff511 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunction.java
@@ -17,12 +17,15 @@
*/
package org.apache.beam.runners.spark.translation;
+import java.io.IOException;
import java.io.Serializable;
import java.util.EnumMap;
import java.util.Iterator;
+import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.stream.Collectors;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
@@ -33,17 +36,23 @@ import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
import org.apache.beam.runners.fnexecution.control.DefaultJobBundleFactory;
import org.apache.beam.runners.fnexecution.control.JobBundleFactory;
import org.apache.beam.runners.fnexecution.control.OutputReceiverFactory;
+import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors;
import org.apache.beam.runners.fnexecution.control.RemoteBundle;
import org.apache.beam.runners.fnexecution.control.StageBundleFactory;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
import org.apache.beam.runners.fnexecution.state.StateRequestHandlers;
+import org.apache.beam.runners.fnexecution.translation.BatchSideInputHandlerFactory;
+import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
import org.apache.spark.api.java.function.FlatMapFunction;
+import org.apache.spark.broadcast.Broadcast;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import scala.Tuple2;
/**
* Spark function that passes its input through an SDK-executed {@link
@@ -54,7 +63,7 @@ import org.slf4j.LoggerFactory;
* The resulting data set should be further processed by a {@link
* SparkExecutableStageExtractionFunction}.
*/
-public class SparkExecutableStageFunction<InputT>
+public class SparkExecutableStageFunction<InputT, SideInputT>
implements FlatMapFunction<Iterator<WindowedValue<InputT>>, RawUnionValue> {
private static final Logger LOG =
LoggerFactory.getLogger(SparkExecutableStageFunction.class);
@@ -62,21 +71,27 @@ public class SparkExecutableStageFunction<InputT>
private final RunnerApi.ExecutableStagePayload stagePayload;
private final Map<String, Integer> outputMap;
private final JobBundleFactoryCreator jobBundleFactoryCreator;
+ // map from pCollection id to tuple of serialized bytes and coder to decode the bytes
+ private final Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
+ sideInputs;
SparkExecutableStageFunction(
RunnerApi.ExecutableStagePayload stagePayload,
JobInfo jobInfo,
- Map<String, Integer> outputMap) {
- this(stagePayload, outputMap, () -> DefaultJobBundleFactory.create(jobInfo));
+ Map<String, Integer> outputMap,
+ Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>> sideInputs) {
+ this(stagePayload, outputMap, () -> DefaultJobBundleFactory.create(jobInfo), sideInputs);
}
SparkExecutableStageFunction(
RunnerApi.ExecutableStagePayload stagePayload,
Map<String, Integer> outputMap,
- JobBundleFactoryCreator jobBundleFactoryCreator) {
+ JobBundleFactoryCreator jobBundleFactoryCreator,
+ Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>> sideInputs) {
this.stagePayload = stagePayload;
this.outputMap = outputMap;
this.jobBundleFactoryCreator = jobBundleFactoryCreator;
+ this.sideInputs = sideInputs;
}
@Override
@@ -86,10 +101,8 @@ public class SparkExecutableStageFunction<InputT>
try (StageBundleFactory stageBundleFactory = jobBundleFactory.forStage(executableStage)) {
ConcurrentLinkedQueue<RawUnionValue> collector = new ConcurrentLinkedQueue<>();
ReceiverFactory receiverFactory = new ReceiverFactory(collector, outputMap);
- EnumMap<TypeCase, StateRequestHandler> handlers = new EnumMap<>(StateKey.TypeCase.class);
- // TODO add state request handlers
StateRequestHandler stateRequestHandler =
- StateRequestHandlers.delegateBasedUponType(handlers);
+ getStateRequestHandler(executableStage, stageBundleFactory.getProcessBundleDescriptor());
SparkBundleProgressHandler bundleProgressHandler = new SparkBundleProgressHandler();
try (RemoteBundle bundle =
stageBundleFactory.getBundle(
@@ -109,6 +122,38 @@ public class SparkExecutableStageFunction<InputT>
}
}
+ private StateRequestHandler getStateRequestHandler(
+ ExecutableStage executableStage,
+ ProcessBundleDescriptors.ExecutableProcessBundleDescriptor processBundleDescriptor) {
+ EnumMap<TypeCase, StateRequestHandler> handlerMap = new EnumMap<>(StateKey.TypeCase.class);
+ final StateRequestHandler sideInputHandler;
+ StateRequestHandlers.SideInputHandlerFactory sideInputHandlerFactory =
+ BatchSideInputHandlerFactory.forStage(
+ executableStage,
+ new BatchSideInputHandlerFactory.SideInputGetter() {
+ @Override
+ public <T> List<T> getSideInput(String pCollectionId) {
+ Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> tuple2 =
+ sideInputs.get(pCollectionId);
+ Broadcast<List<byte[]>> broadcast = tuple2._1;
+ WindowedValueCoder<SideInputT> coder = tuple2._2;
+ return (List<T>)
+ broadcast.value().stream()
+ .map(bytes -> CoderHelpers.fromByteArray(bytes, coder))
+ .collect(Collectors.toList());
+ }
+ });
+ try {
+ sideInputHandler =
+ StateRequestHandlers.forSideInputHandlerFactory(
+ ProcessBundleDescriptors.getSideInputs(executableStage), sideInputHandlerFactory);
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to setup state handler", e);
+ }
+ handlerMap.put(StateKey.TypeCase.MULTIMAP_SIDE_INPUT, sideInputHandler);
+ return StateRequestHandlers.delegateBasedUponType(handlerMap);
+ }
+
interface JobBundleFactoryCreator extends Serializable {
JobBundleFactory create();
}
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
index ad97ec0..38bdd1f 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
@@ -34,9 +34,11 @@ import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.Impulse;
import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.vendor.guava.v20_0.com.google.common.util.concurrent.ListeningExecutorService;
import org.apache.beam.vendor.guava.v20_0.com.google.common.util.concurrent.MoreExecutors;
import org.junit.AfterClass;
@@ -80,6 +82,20 @@ public class SparkPortableExecutionTest implements Serializable {
.setDefaultEnvironmentType(Environments.ENVIRONMENT_EMBEDDED);
Pipeline p = Pipeline.create(options);
+
+ final PCollectionView<Integer> view =
+ p.apply("impulse23", Impulse.create())
+ .apply(
+ "create23",
+ ParDo.of(
+ new DoFn<byte[], Integer>() {
+ @ProcessElement
+ public void process(ProcessContext context) {
+ context.output(23);
+ }
+ }))
+ .apply(View.asSingleton());
+
PCollection<KV<String, Iterable<Long>>> result =
p.apply("impulse", Impulse.create())
.apply(
@@ -108,15 +124,17 @@ public class SparkPortableExecutionTest implements Serializable {
.apply(
"print",
ParDo.of(
- new DoFn<KV<String, Iterable<Long>>, KV<String, Long>>() {
- @ProcessElement
- public void process(ProcessContext context) {
- LOG.info("Output element: {}", context.element());
- for (Long i : context.element().getValue()) {
- context.output(KV.of(context.element().getKey(), i));
- }
- }
- }))
+ new DoFn<KV<String, Iterable<Long>>, KV<String, Long>>() {
+ @ProcessElement
+ public void process(ProcessContext context) {
+ LOG.info("Side input: {}", context.sideInput(view));
+ LOG.info("Output element: {}", context.element());
+ for (Long i : context.element().getValue()) {
+ context.output(KV.of(context.element().getKey(), i));
+ }
+ }
+ })
+ .withSideInputs(view))
// Second GBK forces the output to be materialized
.apply("gbk", GroupByKey.create());
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java
index 8f1bdca..bba1ea4 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java
@@ -89,14 +89,14 @@ public class SparkExecutableStageFunctionTest {
@Test(expected = Exception.class)
public void sdkErrorsSurfaceOnClose() throws Exception {
- SparkExecutableStageFunction<Integer> function = getFunction(Collections.emptyMap());
+ SparkExecutableStageFunction<Integer, ?> function = getFunction(Collections.emptyMap());
doThrow(new Exception()).when(remoteBundle).close();
function.call(Collections.emptyIterator());
}
@Test
public void expectedInputsAreSent() throws Exception {
- SparkExecutableStageFunction<Integer> function = getFunction(Collections.emptyMap());
+ SparkExecutableStageFunction<Integer, ?> function = getFunction(Collections.emptyMap());
RemoteBundle bundle = Mockito.mock(RemoteBundle.class);
when(stageBundleFactory.getBundle(any(), any(), any())).thenReturn(bundle);
@@ -178,7 +178,7 @@ public class SparkExecutableStageFunctionTest {
};
when(jobBundleFactory.forStage(any())).thenReturn(stageBundleFactory);
- SparkExecutableStageFunction<Integer> function = getFunction(outputTagMap);
+ SparkExecutableStageFunction<Integer, ?> function = getFunction(outputTagMap);
Iterator<RawUnionValue> iterator = function.call(Collections.emptyIterator());
Iterable<RawUnionValue> iterable = () -> iterator;
@@ -190,14 +190,17 @@ public class SparkExecutableStageFunctionTest {
@Test
public void testStageBundleClosed() throws Exception {
- SparkExecutableStageFunction<Integer> function = getFunction(Collections.emptyMap());
+ SparkExecutableStageFunction<Integer, ?> function = getFunction(Collections.emptyMap());
function.call(Collections.emptyIterator());
verify(stageBundleFactory).getBundle(any(), any(), any());
+ verify(stageBundleFactory).getProcessBundleDescriptor();
verify(stageBundleFactory).close();
verifyNoMoreInteractions(stageBundleFactory);
}
- private <T> SparkExecutableStageFunction<T> getFunction(Map<String, Integer> outputMap) {
- return new SparkExecutableStageFunction<>(stagePayload, outputMap, jobBundleFactoryCreator);
+ private <InputT, SideInputT> SparkExecutableStageFunction<InputT, SideInputT> getFunction(
+ Map<String, Integer> outputMap) {
+ return new SparkExecutableStageFunction<>(
+ stagePayload, outputMap, jobBundleFactoryCreator, Collections.emptyMap());
}
}
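
For reference, the usage pattern exercised by the updated SparkPortableExecutionTest
condenses to the following (a sketch; pipeline and PortablePipelineOptions setup are
elided, assuming `p` is a Pipeline configured for the Spark portable runner):

    // Materialize a singleton side input, then read it from another ParDo.
    PCollectionView<Integer> view =
        p.apply("impulse23", Impulse.create())
            .apply(
                "create23",
                ParDo.of(
                    new DoFn<byte[], Integer>() {
                      @ProcessElement
                      public void process(ProcessContext context) {
                        context.output(23);
                      }
                    }))
            .apply(View.asSingleton());

    p.apply("impulse", Impulse.create())
        .apply(
            "readSideInput",
            ParDo.of(
                    new DoFn<byte[], Integer>() {
                      @ProcessElement
                      public void process(ProcessContext context) {
                        // Resolved through the broadcast machinery added in this commit.
                        context.output(context.sideInput(view));
                      }
                    })
                .withSideInputs(view));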