http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java new file mode 100644 index 0000000..48c783d --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java @@ -0,0 +1,594 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation; + +import org.apache.beam.runners.flink.io.ConsoleIO; +import org.apache.beam.runners.flink.translation.functions.FlinkCoGroupKeyedListAggregator; +import org.apache.beam.runners.flink.translation.functions.FlinkCreateFunction; +import org.apache.beam.runners.flink.translation.functions.FlinkDoFnFunction; +import org.apache.beam.runners.flink.translation.functions.FlinkKeyedListAggregationFunction; +import org.apache.beam.runners.flink.translation.functions.FlinkMultiOutputDoFnFunction; +import org.apache.beam.runners.flink.translation.functions.FlinkMultiOutputPruningFunction; +import org.apache.beam.runners.flink.translation.functions.FlinkPartialReduceFunction; +import org.apache.beam.runners.flink.translation.functions.FlinkReduceFunction; +import org.apache.beam.runners.flink.translation.functions.UnionCoder; +import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.runners.flink.translation.types.KvCoderTypeInformation; +import org.apache.beam.runners.flink.translation.wrappers.SinkOutputFormat; +import org.apache.beam.runners.flink.translation.wrappers.SourceInputFormat; +import com.google.api.client.util.Maps; +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.io.AvroIO; +import com.google.cloud.dataflow.sdk.io.BoundedSource; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import 
com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.Write; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResultSchema; +import com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey; +import com.google.cloud.dataflow.sdk.transforms.join.KeyedPCollectionTuple; +import com.google.cloud.dataflow.sdk.transforms.join.RawUnionValue; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.Lists; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.operators.Keys; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.io.AvroInputFormat; +import org.apache.flink.api.java.io.AvroOutputFormat; +import org.apache.flink.api.java.io.TextInputFormat; +import org.apache.flink.api.java.operators.CoGroupOperator; +import org.apache.flink.api.java.operators.DataSink; +import org.apache.flink.api.java.operators.DataSource; +import org.apache.flink.api.java.operators.FlatMapOperator; +import org.apache.flink.api.java.operators.GroupCombineOperator; +import org.apache.flink.api.java.operators.GroupReduceOperator; +import org.apache.flink.api.java.operators.Grouping; +import org.apache.flink.api.java.operators.MapPartitionOperator; +import org.apache.flink.api.java.operators.UnsortedGrouping; +import org.apache.flink.core.fs.Path; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.lang.reflect.Field; +import java.util.HashMap; +import 
java.util.List; +import java.util.Map; + +/** + * Translators for transforming + * Dataflow {@link com.google.cloud.dataflow.sdk.transforms.PTransform}s to + * Flink {@link org.apache.flink.api.java.DataSet}s + */ +public class FlinkBatchTransformTranslators { + + // -------------------------------------------------------------------------------------------- + // Transform Translator Registry + // -------------------------------------------------------------------------------------------- + + @SuppressWarnings("rawtypes") + private static final Map<Class<? extends PTransform>, FlinkBatchPipelineTranslator.BatchTransformTranslator> TRANSLATORS = new HashMap<>(); + + // register the known translators + static { + TRANSLATORS.put(View.CreatePCollectionView.class, new CreatePCollectionViewTranslatorBatch()); + + TRANSLATORS.put(Combine.PerKey.class, new CombinePerKeyTranslatorBatch()); + // we don't need this because we translate the Combine.PerKey directly + //TRANSLATORS.put(Combine.GroupedValues.class, new CombineGroupedValuesTranslator()); + + TRANSLATORS.put(Create.Values.class, new CreateTranslatorBatch()); + + TRANSLATORS.put(Flatten.FlattenPCollectionList.class, new FlattenPCollectionTranslatorBatch()); + + TRANSLATORS.put(GroupByKey.GroupByKeyOnly.class, new GroupByKeyOnlyTranslatorBatch()); + // TODO we're currently ignoring windows here but that has to change in the future + TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslatorBatch()); + + TRANSLATORS.put(ParDo.BoundMulti.class, new ParDoBoundMultiTranslatorBatch()); + TRANSLATORS.put(ParDo.Bound.class, new ParDoBoundTranslatorBatch()); + + TRANSLATORS.put(CoGroupByKey.class, new CoGroupByKeyTranslatorBatch()); + + TRANSLATORS.put(AvroIO.Read.Bound.class, new AvroIOReadTranslatorBatch()); + TRANSLATORS.put(AvroIO.Write.Bound.class, new AvroIOWriteTranslatorBatch()); + + TRANSLATORS.put(Read.Bounded.class, new ReadSourceTranslatorBatch()); + TRANSLATORS.put(Write.Bound.class, new 
WriteSinkTranslatorBatch()); + + TRANSLATORS.put(TextIO.Read.Bound.class, new TextIOReadTranslatorBatch()); + TRANSLATORS.put(TextIO.Write.Bound.class, new TextIOWriteTranslatorBatch()); + + // Flink-specific + TRANSLATORS.put(ConsoleIO.Write.Bound.class, new ConsoleIOWriteTranslatorBatch()); + + } + + + public static FlinkBatchPipelineTranslator.BatchTransformTranslator<?> getTranslator(PTransform<?, ?> transform) { + return TRANSLATORS.get(transform.getClass()); + } + + private static class ReadSourceTranslatorBatch<T> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<Read.Bounded<T>> { + + @Override + public void translateNode(Read.Bounded<T> transform, FlinkBatchTranslationContext context) { + String name = transform.getName(); + BoundedSource<T> source = transform.getSource(); + PCollection<T> output = context.getOutput(transform); + Coder<T> coder = output.getCoder(); + + TypeInformation<T> typeInformation = context.getTypeInfo(output); + + DataSource<T> dataSource = new DataSource<>(context.getExecutionEnvironment(), + new SourceInputFormat<>(source, context.getPipelineOptions()), typeInformation, name); + + context.setOutputDataSet(output, dataSource); + } + } + + private static class AvroIOReadTranslatorBatch<T> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<AvroIO.Read.Bound<T>> { + private static final Logger LOG = LoggerFactory.getLogger(AvroIOReadTranslatorBatch.class); + + @Override + public void translateNode(AvroIO.Read.Bound<T> transform, FlinkBatchTranslationContext context) { + String path = transform.getFilepattern(); + String name = transform.getName(); +// Schema schema = transform.getSchema(); + PValue output = context.getOutput(transform); + + TypeInformation<T> typeInformation = context.getTypeInfo(output); + + // This is super hacky, but unfortunately we cannot get the type otherwise + Class<T> extractedAvroType; + try { + Field typeField = transform.getClass().getDeclaredField("type"); + 
typeField.setAccessible(true); + @SuppressWarnings("unchecked") + Class<T> avroType = (Class<T>) typeField.get(transform); + extractedAvroType = avroType; + } catch (NoSuchFieldException | IllegalAccessException e) { + // we know that the field is there and it is accessible + throw new RuntimeException("Could not access type from AvroIO.Bound", e); + } + + DataSource<T> source = new DataSource<>(context.getExecutionEnvironment(), + new AvroInputFormat<>(new Path(path), extractedAvroType), + typeInformation, name); + + context.setOutputDataSet(output, source); + } + } + + private static class AvroIOWriteTranslatorBatch<T> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<AvroIO.Write.Bound<T>> { + private static final Logger LOG = LoggerFactory.getLogger(AvroIOWriteTranslatorBatch.class); + + @Override + public void translateNode(AvroIO.Write.Bound<T> transform, FlinkBatchTranslationContext context) { + DataSet<T> inputDataSet = context.getInputDataSet(context.getInput(transform)); + String filenamePrefix = transform.getFilenamePrefix(); + String filenameSuffix = transform.getFilenameSuffix(); + int numShards = transform.getNumShards(); + String shardNameTemplate = transform.getShardNameTemplate(); + + // TODO: Implement these. We need Flink support for this. + LOG.warn("Translation of TextIO.Write.filenameSuffix not yet supported. Is: {}.", + filenameSuffix); + LOG.warn("Translation of TextIO.Write.shardNameTemplate not yet supported. 
Is: {}.", shardNameTemplate); + + // This is super hacky, but unfortunately we cannot get the type otherwise + Class<T> extractedAvroType; + try { + Field typeField = transform.getClass().getDeclaredField("type"); + typeField.setAccessible(true); + @SuppressWarnings("unchecked") + Class<T> avroType = (Class<T>) typeField.get(transform); + extractedAvroType = avroType; + } catch (NoSuchFieldException | IllegalAccessException e) { + // we know that the field is there and it is accessible + throw new RuntimeException("Could not access type from AvroIO.Bound", e); + } + + DataSink<T> dataSink = inputDataSet.output(new AvroOutputFormat<>(new Path + (filenamePrefix), extractedAvroType)); + + if (numShards > 0) { + dataSink.setParallelism(numShards); + } + } + } + + private static class TextIOReadTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator<TextIO.Read.Bound<String>> { + private static final Logger LOG = LoggerFactory.getLogger(TextIOReadTranslatorBatch.class); + + @Override + public void translateNode(TextIO.Read.Bound<String> transform, FlinkBatchTranslationContext context) { + String path = transform.getFilepattern(); + String name = transform.getName(); + + TextIO.CompressionType compressionType = transform.getCompressionType(); + boolean needsValidation = transform.needsValidation(); + + // TODO: Implement these. We need Flink support for this. + LOG.warn("Translation of TextIO.CompressionType not yet supported. Is: {}.", compressionType); + LOG.warn("Translation of TextIO.Read.needsValidation not yet supported. 
Is: {}.", needsValidation); + + PValue output = context.getOutput(transform); + + TypeInformation<String> typeInformation = context.getTypeInfo(output); + DataSource<String> source = new DataSource<>(context.getExecutionEnvironment(), new TextInputFormat(new Path(path)), typeInformation, name); + + context.setOutputDataSet(output, source); + } + } + + private static class TextIOWriteTranslatorBatch<T> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<TextIO.Write.Bound<T>> { + private static final Logger LOG = LoggerFactory.getLogger(TextIOWriteTranslatorBatch.class); + + @Override + public void translateNode(TextIO.Write.Bound<T> transform, FlinkBatchTranslationContext context) { + PValue input = context.getInput(transform); + DataSet<T> inputDataSet = context.getInputDataSet(input); + + String filenamePrefix = transform.getFilenamePrefix(); + String filenameSuffix = transform.getFilenameSuffix(); + boolean needsValidation = transform.needsValidation(); + int numShards = transform.getNumShards(); + String shardNameTemplate = transform.getShardNameTemplate(); + + // TODO: Implement these. We need Flink support for this. + LOG.warn("Translation of TextIO.Write.needsValidation not yet supported. Is: {}.", needsValidation); + LOG.warn("Translation of TextIO.Write.filenameSuffix not yet supported. Is: {}.", filenameSuffix); + LOG.warn("Translation of TextIO.Write.shardNameTemplate not yet supported. 
Is: {}.", shardNameTemplate); + + //inputDataSet.print(); + DataSink<T> dataSink = inputDataSet.writeAsText(filenamePrefix); + + if (numShards > 0) { + dataSink.setParallelism(numShards); + } + } + } + + private static class ConsoleIOWriteTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator<ConsoleIO.Write.Bound> { + @Override + public void translateNode(ConsoleIO.Write.Bound transform, FlinkBatchTranslationContext context) { + PValue input = context.getInput(transform); + DataSet<?> inputDataSet = context.getInputDataSet(input); + inputDataSet.printOnTaskManager(transform.getName()); + } + } + + private static class WriteSinkTranslatorBatch<T> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<Write.Bound<T>> { + + @Override + public void translateNode(Write.Bound<T> transform, FlinkBatchTranslationContext context) { + String name = transform.getName(); + PValue input = context.getInput(transform); + DataSet<T> inputDataSet = context.getInputDataSet(input); + + inputDataSet.output(new SinkOutputFormat<>(transform, context.getPipelineOptions())).name(name); + } + } + + private static class GroupByKeyOnlyTranslatorBatch<K, V> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<GroupByKey.GroupByKeyOnly<K, V>> { + + @Override + public void translateNode(GroupByKey.GroupByKeyOnly<K, V> transform, FlinkBatchTranslationContext context) { + DataSet<KV<K, V>> inputDataSet = context.getInputDataSet(context.getInput(transform)); + GroupReduceFunction<KV<K, V>, KV<K, Iterable<V>>> groupReduceFunction = new FlinkKeyedListAggregationFunction<>(); + + TypeInformation<KV<K, Iterable<V>>> typeInformation = context.getTypeInfo(context.getOutput(transform)); + + Grouping<KV<K, V>> grouping = new UnsortedGrouping<>(inputDataSet, new Keys.ExpressionKeys<>(new String[]{"key"}, inputDataSet.getType())); + + GroupReduceOperator<KV<K, V>, KV<K, Iterable<V>>> outputDataSet = + new GroupReduceOperator<>(grouping, typeInformation, 
groupReduceFunction, transform.getName()); + context.setOutputDataSet(context.getOutput(transform), outputDataSet); + } + } + + /** + * Translates a GroupByKey while ignoring window assignments. This is identical to the {@link GroupByKeyOnlyTranslatorBatch} + */ + private static class GroupByKeyTranslatorBatch<K, V> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<GroupByKey<K, V>> { + + @Override + public void translateNode(GroupByKey<K, V> transform, FlinkBatchTranslationContext context) { + DataSet<KV<K, V>> inputDataSet = context.getInputDataSet(context.getInput(transform)); + GroupReduceFunction<KV<K, V>, KV<K, Iterable<V>>> groupReduceFunction = new FlinkKeyedListAggregationFunction<>(); + + TypeInformation<KV<K, Iterable<V>>> typeInformation = context.getTypeInfo(context.getOutput(transform)); + + Grouping<KV<K, V>> grouping = new UnsortedGrouping<>(inputDataSet, new Keys.ExpressionKeys<>(new String[]{"key"}, inputDataSet.getType())); + + GroupReduceOperator<KV<K, V>, KV<K, Iterable<V>>> outputDataSet = + new GroupReduceOperator<>(grouping, typeInformation, groupReduceFunction, transform.getName()); + + context.setOutputDataSet(context.getOutput(transform), outputDataSet); + } + } + + private static class CombinePerKeyTranslatorBatch<K, VI, VA, VO> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<Combine.PerKey<K, VI, VO>> { + + @Override + public void translateNode(Combine.PerKey<K, VI, VO> transform, FlinkBatchTranslationContext context) { + DataSet<KV<K, VI>> inputDataSet = context.getInputDataSet(context.getInput(transform)); + + @SuppressWarnings("unchecked") + Combine.KeyedCombineFn<K, VI, VA, VO> keyedCombineFn = (Combine.KeyedCombineFn<K, VI, VA, VO>) transform.getFn(); + + KvCoder<K, VI> inputCoder = (KvCoder<K, VI>) context.getInput(transform).getCoder(); + + Coder<VA> accumulatorCoder = + null; + try { + accumulatorCoder = 
keyedCombineFn.getAccumulatorCoder(context.getInput(transform).getPipeline().getCoderRegistry(), inputCoder.getKeyCoder(), inputCoder.getValueCoder()); + } catch (CannotProvideCoderException e) { + e.printStackTrace(); + // TODO + } + + TypeInformation<KV<K, VI>> kvCoderTypeInformation = new KvCoderTypeInformation<>(inputCoder); + TypeInformation<KV<K, VA>> partialReduceTypeInfo = new KvCoderTypeInformation<>(KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder)); + + Grouping<KV<K, VI>> inputGrouping = new UnsortedGrouping<>(inputDataSet, new Keys.ExpressionKeys<>(new String[]{"key"}, kvCoderTypeInformation)); + + FlinkPartialReduceFunction<K, VI, VA> partialReduceFunction = new FlinkPartialReduceFunction<>(keyedCombineFn); + + // Partially GroupReduce the values into the intermediate format VA (combine) + GroupCombineOperator<KV<K, VI>, KV<K, VA>> groupCombine = + new GroupCombineOperator<>(inputGrouping, partialReduceTypeInfo, partialReduceFunction, + "GroupCombine: " + transform.getName()); + + // Reduce fully to VO + GroupReduceFunction<KV<K, VA>, KV<K, VO>> reduceFunction = new FlinkReduceFunction<>(keyedCombineFn); + + TypeInformation<KV<K, VO>> reduceTypeInfo = context.getTypeInfo(context.getOutput(transform)); + + Grouping<KV<K, VA>> intermediateGrouping = new UnsortedGrouping<>(groupCombine, new Keys.ExpressionKeys<>(new String[]{"key"}, groupCombine.getType())); + + // Fully reduce the values and create output format VO + GroupReduceOperator<KV<K, VA>, KV<K, VO>> outputDataSet = + new GroupReduceOperator<>(intermediateGrouping, reduceTypeInfo, reduceFunction, transform.getName()); + + context.setOutputDataSet(context.getOutput(transform), outputDataSet); + } + } + +// private static class CombineGroupedValuesTranslator<K, VI, VO> implements FlinkPipelineTranslator.TransformTranslator<Combine.GroupedValues<K, VI, VO>> { +// +// @Override +// public void translateNode(Combine.GroupedValues<K, VI, VO> transform, TranslationContext context) { +// 
DataSet<KV<K, VI>> inputDataSet = context.getInputDataSet(transform.getInput()); +// +// Combine.KeyedCombineFn<? super K, ? super VI, ?, VO> keyedCombineFn = transform.getFn(); +// +// GroupReduceFunction<KV<K, VI>, KV<K, VO>> groupReduceFunction = new FlinkCombineFunction<>(keyedCombineFn); +// +// TypeInformation<KV<K, VO>> typeInformation = context.getTypeInfo(transform.getOutput()); +// +// Grouping<KV<K, VI>> grouping = new UnsortedGrouping<>(inputDataSet, new Keys.ExpressionKeys<>(new String[]{""}, inputDataSet.getType())); +// +// GroupReduceOperator<KV<K, VI>, KV<K, VO>> outputDataSet = +// new GroupReduceOperator<>(grouping, typeInformation, groupReduceFunction, transform.getName()); +// context.setOutputDataSet(transform.getOutput(), outputDataSet); +// } +// } + + private static class ParDoBoundTranslatorBatch<IN, OUT> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<ParDo.Bound<IN, OUT>> { + private static final Logger LOG = LoggerFactory.getLogger(ParDoBoundTranslatorBatch.class); + + @Override + public void translateNode(ParDo.Bound<IN, OUT> transform, FlinkBatchTranslationContext context) { + DataSet<IN> inputDataSet = context.getInputDataSet(context.getInput(transform)); + + final DoFn<IN, OUT> doFn = transform.getFn(); + + TypeInformation<OUT> typeInformation = context.getTypeInfo(context.getOutput(transform)); + + FlinkDoFnFunction<IN, OUT> doFnWrapper = new FlinkDoFnFunction<>(doFn, context.getPipelineOptions()); + MapPartitionOperator<IN, OUT> outputDataSet = new MapPartitionOperator<>(inputDataSet, typeInformation, doFnWrapper, transform.getName()); + + transformSideInputs(transform.getSideInputs(), outputDataSet, context); + + context.setOutputDataSet(context.getOutput(transform), outputDataSet); + } + } + + private static class ParDoBoundMultiTranslatorBatch<IN, OUT> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<ParDo.BoundMulti<IN, OUT>> { + private static final Logger LOG = 
LoggerFactory.getLogger(ParDoBoundMultiTranslatorBatch.class); + + @Override + public void translateNode(ParDo.BoundMulti<IN, OUT> transform, FlinkBatchTranslationContext context) { + DataSet<IN> inputDataSet = context.getInputDataSet(context.getInput(transform)); + + final DoFn<IN, OUT> doFn = transform.getFn(); + + Map<TupleTag<?>, PCollection<?>> outputs = context.getOutput(transform).getAll(); + + Map<TupleTag<?>, Integer> outputMap = Maps.newHashMap(); + // put the main output at index 0, FlinkMultiOutputDoFnFunction also expects this + outputMap.put(transform.getMainOutputTag(), 0); + int count = 1; + for (TupleTag<?> tag: outputs.keySet()) { + if (!outputMap.containsKey(tag)) { + outputMap.put(tag, count++); + } + } + + // collect all output Coders and create a UnionCoder for our tagged outputs + List<Coder<?>> outputCoders = Lists.newArrayList(); + for (PCollection<?> coll: outputs.values()) { + outputCoders.add(coll.getCoder()); + } + + UnionCoder unionCoder = UnionCoder.of(outputCoders); + + @SuppressWarnings("unchecked") + TypeInformation<RawUnionValue> typeInformation = new CoderTypeInformation<>(unionCoder); + + @SuppressWarnings("unchecked") + FlinkMultiOutputDoFnFunction<IN, OUT> doFnWrapper = new FlinkMultiOutputDoFnFunction(doFn, context.getPipelineOptions(), outputMap); + MapPartitionOperator<IN, RawUnionValue> outputDataSet = new MapPartitionOperator<>(inputDataSet, typeInformation, doFnWrapper, transform.getName()); + + transformSideInputs(transform.getSideInputs(), outputDataSet, context); + + for (Map.Entry<TupleTag<?>, PCollection<?>> output: outputs.entrySet()) { + TypeInformation<Object> outputType = context.getTypeInfo(output.getValue()); + int outputTag = outputMap.get(output.getKey()); + FlinkMultiOutputPruningFunction<Object> pruningFunction = new FlinkMultiOutputPruningFunction<>(outputTag); + FlatMapOperator<RawUnionValue, Object> pruningOperator = new + FlatMapOperator<>(outputDataSet, outputType, + pruningFunction, 
output.getValue().getName()); + context.setOutputDataSet(output.getValue(), pruningOperator); + + } + } + } + + private static class FlattenPCollectionTranslatorBatch<T> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<Flatten.FlattenPCollectionList<T>> { + + @Override + public void translateNode(Flatten.FlattenPCollectionList<T> transform, FlinkBatchTranslationContext context) { + List<PCollection<T>> allInputs = context.getInput(transform).getAll(); + DataSet<T> result = null; + for(PCollection<T> collection : allInputs) { + DataSet<T> current = context.getInputDataSet(collection); + if (result == null) { + result = current; + } else { + result = result.union(current); + } + } + context.setOutputDataSet(context.getOutput(transform), result); + } + } + + private static class CreatePCollectionViewTranslatorBatch<R, T> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<View.CreatePCollectionView<R, T>> { + @Override + public void translateNode(View.CreatePCollectionView<R, T> transform, FlinkBatchTranslationContext context) { + DataSet<T> inputDataSet = context.getInputDataSet(context.getInput(transform)); + PCollectionView<T> input = transform.apply(null); + context.setSideInputDataSet(input, inputDataSet); + } + } + + private static class CreateTranslatorBatch<OUT> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<Create.Values<OUT>> { + + @Override + public void translateNode(Create.Values<OUT> transform, FlinkBatchTranslationContext context) { + TypeInformation<OUT> typeInformation = context.getOutputTypeInfo(); + Iterable<OUT> elements = transform.getElements(); + + // we need to serialize the elements to byte arrays, since they might contain + // elements that are not serializable by Java serialization. We deserialize them + // in the FlatMap function using the Coder. 
+ + List<byte[]> serializedElements = Lists.newArrayList(); + Coder<OUT> coder = context.getOutput(transform).getCoder(); + for (OUT element: elements) { + ByteArrayOutputStream bao = new ByteArrayOutputStream(); + try { + coder.encode(element, bao, Coder.Context.OUTER); + serializedElements.add(bao.toByteArray()); + } catch (IOException e) { + throw new RuntimeException("Could not serialize Create elements using Coder: " + e); + } + } + + DataSet<Integer> initDataSet = context.getExecutionEnvironment().fromElements(1); + FlinkCreateFunction<Integer, OUT> flatMapFunction = new FlinkCreateFunction<>(serializedElements, coder); + FlatMapOperator<Integer, OUT> outputDataSet = new FlatMapOperator<>(initDataSet, typeInformation, flatMapFunction, transform.getName()); + + context.setOutputDataSet(context.getOutput(transform), outputDataSet); + } + } + + private static void transformSideInputs(List<PCollectionView<?>> sideInputs, + MapPartitionOperator<?, ?> outputDataSet, + FlinkBatchTranslationContext context) { + // get corresponding Flink broadcast DataSets + for(PCollectionView<?> input : sideInputs) { + DataSet<?> broadcastSet = context.getSideInputDataSet(input); + outputDataSet.withBroadcastSet(broadcastSet, input.getTagInternal().getId()); + } + } + +// Disabled because it depends on a pending pull request to the DataFlowSDK + /** + * Special composite transform translator. Only called if the CoGroup is two dimensional. 
+ * @param <K> + */ + private static class CoGroupByKeyTranslatorBatch<K, V1, V2> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<CoGroupByKey<K>> { + + @Override + public void translateNode(CoGroupByKey<K> transform, FlinkBatchTranslationContext context) { + KeyedPCollectionTuple<K> input = context.getInput(transform); + + CoGbkResultSchema schema = input.getCoGbkResultSchema(); + List<KeyedPCollectionTuple.TaggedKeyedPCollection<K, ?>> keyedCollections = input.getKeyedCollections(); + + KeyedPCollectionTuple.TaggedKeyedPCollection<K, ?> taggedCollection1 = keyedCollections.get(0); + KeyedPCollectionTuple.TaggedKeyedPCollection<K, ?> taggedCollection2 = keyedCollections.get(1); + + TupleTag<?> tupleTag1 = taggedCollection1.getTupleTag(); + TupleTag<?> tupleTag2 = taggedCollection2.getTupleTag(); + + PCollection<? extends KV<K, ?>> collection1 = taggedCollection1.getCollection(); + PCollection<? extends KV<K, ?>> collection2 = taggedCollection2.getCollection(); + + DataSet<KV<K,V1>> inputDataSet1 = context.getInputDataSet(collection1); + DataSet<KV<K,V2>> inputDataSet2 = context.getInputDataSet(collection2); + + TypeInformation<KV<K,CoGbkResult>> typeInfo = context.getOutputTypeInfo(); + + FlinkCoGroupKeyedListAggregator<K,V1,V2> aggregator = new FlinkCoGroupKeyedListAggregator<>(schema, tupleTag1, tupleTag2); + + Keys.ExpressionKeys<KV<K,V1>> keySelector1 = new Keys.ExpressionKeys<>(new String[]{"key"}, inputDataSet1.getType()); + Keys.ExpressionKeys<KV<K,V2>> keySelector2 = new Keys.ExpressionKeys<>(new String[]{"key"}, inputDataSet2.getType()); + + DataSet<KV<K, CoGbkResult>> out = new CoGroupOperator<>(inputDataSet1, inputDataSet2, + keySelector1, keySelector2, + aggregator, typeInfo, null, transform.getName()); + context.setOutputDataSet(context.getOutput(transform), out); + } + } + + // -------------------------------------------------------------------------------------------- + // Miscellaneous + // 
-------------------------------------------------------------------------------------------- + + private FlinkBatchTransformTranslators() {} +}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTranslationContext.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTranslationContext.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTranslationContext.java new file mode 100644 index 0000000..2294318 --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTranslationContext.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation; + +import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.runners.flink.translation.types.KvCoderTypeInformation; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.cloud.dataflow.sdk.values.TypedPValue; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.typeutils.GenericTypeInfo; + +import java.util.HashMap; +import java.util.Map; + +public class FlinkBatchTranslationContext { + + private final Map<PValue, DataSet<?>> dataSets; + private final Map<PCollectionView<?>, DataSet<?>> broadcastDataSets; + + private final ExecutionEnvironment env; + private final PipelineOptions options; + + private AppliedPTransform<?, ?, ?> currentTransform; + + // ------------------------------------------------------------------------ + + public FlinkBatchTranslationContext(ExecutionEnvironment env, PipelineOptions options) { + this.env = env; + this.options = options; + this.dataSets = new HashMap<>(); + this.broadcastDataSets = new HashMap<>(); + } + + // ------------------------------------------------------------------------ + + public ExecutionEnvironment getExecutionEnvironment() { + return env; + } + + public PipelineOptions getPipelineOptions() { + return options; + } + + @SuppressWarnings("unchecked") + public <T> DataSet<T> getInputDataSet(PValue value) 
{ + return (DataSet<T>) dataSets.get(value); + } + + public void setOutputDataSet(PValue value, DataSet<?> set) { + if (!dataSets.containsKey(value)) { + dataSets.put(value, set); + } + } + + /** + * Sets the AppliedPTransform which carries input/output. + * @param currentTransform + */ + public void setCurrentTransform(AppliedPTransform<?, ?, ?> currentTransform) { + this.currentTransform = currentTransform; + } + + @SuppressWarnings("unchecked") + public <T> DataSet<T> getSideInputDataSet(PCollectionView<?> value) { + return (DataSet<T>) broadcastDataSets.get(value); + } + + public void setSideInputDataSet(PCollectionView<?> value, DataSet<?> set) { + if (!broadcastDataSets.containsKey(value)) { + broadcastDataSets.put(value, set); + } + } + + @SuppressWarnings("unchecked") + public <T> TypeInformation<T> getTypeInfo(PInput output) { + if (output instanceof TypedPValue) { + Coder<?> outputCoder = ((TypedPValue) output).getCoder(); + if (outputCoder instanceof KvCoder) { + return new KvCoderTypeInformation((KvCoder) outputCoder); + } else { + return new CoderTypeInformation(outputCoder); + } + } + return new GenericTypeInfo<>((Class<T>)Object.class); + } + + public <T> TypeInformation<T> getInputTypeInfo() { + return getTypeInfo(currentTransform.getInput()); + } + + public <T> TypeInformation<T> getOutputTypeInfo() { + return getTypeInfo((PValue) currentTransform.getOutput()); + } + + @SuppressWarnings("unchecked") + <I extends PInput> I getInput(PTransform<I, ?> transform) { + return (I) currentTransform.getInput(); + } + + @SuppressWarnings("unchecked") + <O extends POutput> O getOutput(PTransform<?, O> transform) { + return (O) currentTransform.getOutput(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkPipelineTranslator.java ---------------------------------------------------------------------- diff --git 
a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkPipelineTranslator.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkPipelineTranslator.java new file mode 100644 index 0000000..9407bf5 --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkPipelineTranslator.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation; + +import com.google.cloud.dataflow.sdk.Pipeline; + +/** + * The role of this class is to translate the Beam operators to + * their Flink counterparts. If we have a streaming job, this is instantiated as a + * {@link FlinkStreamingPipelineTranslator}. In other case, i.e. for a batch job, + * a {@link FlinkBatchPipelineTranslator} is created. Correspondingly, the + * {@link com.google.cloud.dataflow.sdk.values.PCollection}-based user-provided job is translated into + * a {@link org.apache.flink.streaming.api.datastream.DataStream} (for streaming) or a + * {@link org.apache.flink.api.java.DataSet} (for batch) one. 
 */
public abstract class FlinkPipelineTranslator implements Pipeline.PipelineVisitor {

  /**
   * Translates the given pipeline by traversing it topologically; the visitor
   * callbacks (implemented by the concrete subclass) perform the actual mapping
   * of each Beam transform to its Flink counterpart.
   */
  public void translate(Pipeline pipeline) {
    pipeline.traverseTopologically(this);
  }
}
+ */ +package org.apache.beam.runners.flink.translation; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.TransformTreeNode; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PValue; +import org.apache.beam.runners.flink.FlinkPipelineRunner; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This is a {@link FlinkPipelineTranslator} for streaming jobs. Its role is to translate the user-provided + * {@link com.google.cloud.dataflow.sdk.values.PCollection}-based job into a + * {@link org.apache.flink.streaming.api.datastream.DataStream} one. + * + * This is based on {@link com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator} + * */ +public class FlinkStreamingPipelineTranslator extends FlinkPipelineTranslator { + + private static final Logger LOG = LoggerFactory.getLogger(FlinkStreamingPipelineTranslator.class); + + /** The necessary context in the case of a straming job. */ + private final FlinkStreamingTranslationContext streamingContext; + + private int depth = 0; + + /** Composite transform that we want to translate before proceeding with other transforms. 
*/ + private PTransform<?, ?> currentCompositeTransform; + + public FlinkStreamingPipelineTranslator(StreamExecutionEnvironment env, PipelineOptions options) { + this.streamingContext = new FlinkStreamingTranslationContext(env, options); + } + + // -------------------------------------------------------------------------------------------- + // Pipeline Visitor Methods + // -------------------------------------------------------------------------------------------- + + @Override + public void enterCompositeTransform(TransformTreeNode node) { + LOG.info(genSpaces(this.depth) + "enterCompositeTransform- " + formatNodeName(node)); + + PTransform<?, ?> transform = node.getTransform(); + if (transform != null && currentCompositeTransform == null) { + + StreamTransformTranslator<?> translator = FlinkStreamingTransformTranslators.getTranslator(transform); + if (translator != null) { + currentCompositeTransform = transform; + } + } + this.depth++; + } + + @Override + public void leaveCompositeTransform(TransformTreeNode node) { + PTransform<?, ?> transform = node.getTransform(); + if (transform != null && currentCompositeTransform == transform) { + + StreamTransformTranslator<?> translator = FlinkStreamingTransformTranslators.getTranslator(transform); + if (translator != null) { + LOG.info(genSpaces(this.depth) + "doingCompositeTransform- " + formatNodeName(node)); + applyStreamingTransform(transform, node, translator); + currentCompositeTransform = null; + } else { + throw new IllegalStateException("Attempted to translate composite transform " + + "but no translator was found: " + currentCompositeTransform); + } + } + this.depth--; + LOG.info(genSpaces(this.depth) + "leaveCompositeTransform- " + formatNodeName(node)); + } + + @Override + public void visitTransform(TransformTreeNode node) { + LOG.info(genSpaces(this.depth) + "visitTransform- " + formatNodeName(node)); + if (currentCompositeTransform != null) { + // ignore it + return; + } + + // get the transformation 
corresponding to hte node we are + // currently visiting and translate it into its Flink alternative. + + PTransform<?, ?> transform = node.getTransform(); + StreamTransformTranslator<?> translator = FlinkStreamingTransformTranslators.getTranslator(transform); + if (translator == null) { + LOG.info(node.getTransform().getClass().toString()); + throw new UnsupportedOperationException("The transform " + transform + " is currently not supported."); + } + applyStreamingTransform(transform, node, translator); + } + + @Override + public void visitValue(PValue value, TransformTreeNode producer) { + // do nothing here + } + + private <T extends PTransform<?, ?>> void applyStreamingTransform(PTransform<?, ?> transform, TransformTreeNode node, StreamTransformTranslator<?> translator) { + + @SuppressWarnings("unchecked") + T typedTransform = (T) transform; + + @SuppressWarnings("unchecked") + StreamTransformTranslator<T> typedTranslator = (StreamTransformTranslator<T>) translator; + + // create the applied PTransform on the streamingContext + streamingContext.setCurrentTransform(AppliedPTransform.of( + node.getFullName(), node.getInput(), node.getOutput(), (PTransform) transform)); + typedTranslator.translateNode(typedTransform, streamingContext); + } + + /** + * The interface that every Flink translator of a Beam operator should implement. + * This interface is for <b>streaming</b> jobs. For examples of such translators see + * {@link FlinkStreamingTransformTranslators}. 
+ */ + public interface StreamTransformTranslator<Type extends PTransform> { + void translateNode(Type transform, FlinkStreamingTranslationContext context); + } + + private static String genSpaces(int n) { + String s = ""; + for (int i = 0; i < n; i++) { + s += "| "; + } + return s; + } + + private static String formatNodeName(TransformTreeNode node) { + return node.toString().split("@")[1] + node.getTransform(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTransformTranslators.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTransformTranslators.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTransformTranslators.java new file mode 100644 index 0000000..bdefeaf --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTransformTranslators.java @@ -0,0 +1,406 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
 */

package org.apache.beam.runners.flink.translation;

import org.apache.beam.runners.flink.translation.functions.UnionCoder;
import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
import org.apache.beam.runners.flink.translation.wrappers.streaming.*;
import org.apache.beam.runners.flink.translation.wrappers.streaming.io.FlinkStreamingCreateFunction;
import org.apache.beam.runners.flink.translation.wrappers.streaming.io.UnboundedFlinkSource;
import org.apache.beam.runners.flink.translation.wrappers.streaming.io.UnboundedSourceWrapper;
import com.google.api.client.util.Maps;
import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.KvCoder;
import com.google.cloud.dataflow.sdk.io.Read;
import com.google.cloud.dataflow.sdk.io.TextIO;
import com.google.cloud.dataflow.sdk.transforms.*;
import com.google.cloud.dataflow.sdk.transforms.join.RawUnionValue;
import com.google.cloud.dataflow.sdk.transforms.windowing.*;
import com.google.cloud.dataflow.sdk.util.WindowedValue;
import com.google.cloud.dataflow.sdk.util.WindowingStrategy;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.dataflow.sdk.values.PValue;
import com.google.cloud.dataflow.sdk.values.TupleTag;
import com.google.common.collect.Lists;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.core.fs.FileSystem;
import org.apache.flink.streaming.api.datastream.*;
import org.apache.flink.util.Collector;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.*;

/**
 * This class contains all the mappings between Beam and Flink
 * <b>streaming</b> transformations. The {@link FlinkStreamingPipelineTranslator}
 * traverses the Beam job and comes here to translate the encountered Beam transformations
 * into Flink ones, based on the mapping available in this class.
 */
public class FlinkStreamingTransformTranslators {

  // --------------------------------------------------------------------------------------------
  //  Transform Translator Registry
  // --------------------------------------------------------------------------------------------

  /** Registry keyed by the concrete PTransform class; consulted by the pipeline visitor. */
  @SuppressWarnings("rawtypes")
  private static final Map<Class<? extends PTransform>, FlinkStreamingPipelineTranslator.StreamTransformTranslator> TRANSLATORS = new HashMap<>();

  // here you can find all the available translators.
  static {
    TRANSLATORS.put(Create.Values.class, new CreateStreamingTranslator());
    TRANSLATORS.put(Read.Unbounded.class, new UnboundedReadSourceTranslator());
    TRANSLATORS.put(ParDo.Bound.class, new ParDoBoundStreamingTranslator());
    TRANSLATORS.put(TextIO.Write.Bound.class, new TextIOWriteBoundStreamingTranslator());
    TRANSLATORS.put(Window.Bound.class, new WindowBoundTranslator());
    TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslator());
    TRANSLATORS.put(Combine.PerKey.class, new CombinePerKeyTranslator());
    TRANSLATORS.put(Flatten.FlattenPCollectionList.class, new FlattenPCollectionTranslator());
    TRANSLATORS.put(ParDo.BoundMulti.class, new ParDoBoundMultiStreamingTranslator());
  }

  /** Returns the registered translator for the transform's class, or {@code null} if unsupported. */
  public static FlinkStreamingPipelineTranslator.StreamTransformTranslator<?> getTranslator(PTransform<?, ?> transform) {
    return TRANSLATORS.get(transform.getClass());
  }

  // --------------------------------------------------------------------------------------------
  //  Transformation Implementations
  // --------------------------------------------------------------------------------------------

  /**
   * Translates {@link Create.Values}: the elements are pre-serialized with their coder and
   * a single dummy element is fanned out into them by {@link FlinkStreamingCreateFunction}.
   */
  private static class CreateStreamingTranslator<OUT> implements
      FlinkStreamingPipelineTranslator.StreamTransformTranslator<Create.Values<OUT>> {

    @Override
    public void translateNode(Create.Values<OUT> transform, FlinkStreamingTranslationContext context) {
      PCollection<OUT> output = context.getOutput(transform);
      Iterable<OUT> elements = transform.getElements();

      // we need to serialize the elements to byte arrays, since they might contain
      // elements that are not serializable by Java serialization. We deserialize them
      // in the FlatMap function using the Coder.

      List<byte[]> serializedElements = Lists.newArrayList();
      Coder<OUT> elementCoder = context.getOutput(transform).getCoder();
      for (OUT element: elements) {
        ByteArrayOutputStream bao = new ByteArrayOutputStream();
        try {
          elementCoder.encode(element, bao, Coder.Context.OUTER);
          serializedElements.add(bao.toByteArray());
        } catch (IOException e) {
          throw new RuntimeException("Could not serialize Create elements using Coder: " + e);
        }
      }

      // A one-element seed stream; the create function expands it into the actual values.
      DataStream<Integer> initDataSet = context.getExecutionEnvironment().fromElements(1);

      FlinkStreamingCreateFunction<Integer, OUT> createFunction =
          new FlinkStreamingCreateFunction<>(serializedElements, elementCoder);

      WindowedValue.ValueOnlyWindowedValueCoder<OUT> windowCoder = WindowedValue.getValueOnlyCoder(elementCoder);
      TypeInformation<WindowedValue<OUT>> outputType = new CoderTypeInformation<>(windowCoder);

      DataStream<WindowedValue<OUT>> outputDataStream = initDataSet.flatMap(createFunction)
          .returns(outputType);

      context.setOutputDataStream(context.getOutput(transform), outputDataStream);
    }
  }

  /**
   * Translates {@link TextIO.Write.Bound} to a Flink text sink. Elements are written via
   * their {@code toString()}; several TextIO options are not yet mapped (see warnings below).
   */
  private static class TextIOWriteBoundStreamingTranslator<T> implements FlinkStreamingPipelineTranslator.StreamTransformTranslator<TextIO.Write.Bound<T>> {
    private static final Logger LOG = LoggerFactory.getLogger(TextIOWriteBoundStreamingTranslator.class);

    @Override
    public void translateNode(TextIO.Write.Bound<T> transform, FlinkStreamingTranslationContext context) {
      PValue input = context.getInput(transform);
      DataStream<WindowedValue<T>> inputDataStream = context.getInputDataStream(input);

      String filenamePrefix = transform.getFilenamePrefix();
      String filenameSuffix = transform.getFilenameSuffix();
      boolean needsValidation = transform.needsValidation();
      int numShards = transform.getNumShards();
      String shardNameTemplate = transform.getShardNameTemplate();

      // TODO: Implement these. We need Flink support for this.
      LOG.warn("Translation of TextIO.Write.needsValidation not yet supported. Is: {}.", needsValidation);
      LOG.warn("Translation of TextIO.Write.filenameSuffix not yet supported. Is: {}.", filenameSuffix);
      LOG.warn("Translation of TextIO.Write.shardNameTemplate not yet supported. Is: {}.", shardNameTemplate);

      // Unwrap the WindowedValue and rely on toString() for the textual representation.
      DataStream<String> dataSink = inputDataStream.flatMap(new FlatMapFunction<WindowedValue<T>, String>() {
        @Override
        public void flatMap(WindowedValue<T> value, Collector<String> out) throws Exception {
          out.collect(value.getValue().toString());
        }
      });
      DataStreamSink<String> output = dataSink.writeAsText(filenamePrefix, FileSystem.WriteMode.OVERWRITE);

      // numShards maps onto the sink parallelism; 0 or less means "use the default".
      if (numShards > 0) {
        output.setParallelism(numShards);
      }
    }
  }

  /**
   * Translates {@link Read.Unbounded}. A {@link UnboundedFlinkSource} is unwrapped to its
   * native Flink source (assumed to emit Strings — TODO confirm); any other unbounded source
   * goes through the generic {@link UnboundedSourceWrapper}.
   */
  private static class UnboundedReadSourceTranslator<T> implements FlinkStreamingPipelineTranslator.StreamTransformTranslator<Read.Unbounded<T>> {

    @Override
    public void translateNode(Read.Unbounded<T> transform, FlinkStreamingTranslationContext context) {
      PCollection<T> output = context.getOutput(transform);

      DataStream<WindowedValue<T>> source;
      if (transform.getSource().getClass().equals(UnboundedFlinkSource.class)) {
        UnboundedFlinkSource flinkSource = (UnboundedFlinkSource) transform.getSource();
        source = context.getExecutionEnvironment()
            .addSource(flinkSource.getFlinkSource())
            .flatMap(new FlatMapFunction<String, WindowedValue<String>>() {
              @Override
              public void flatMap(String s, Collector<WindowedValue<String>> collector) throws Exception {
                // Stamp each element with the processing time and the global window.
                collector.collect(WindowedValue.<String>of(s, Instant.now(), GlobalWindow.INSTANCE, PaneInfo.NO_FIRING));
              }
            });
      } else {
        source = context.getExecutionEnvironment()
            .addSource(new UnboundedSourceWrapper<>(context.getPipelineOptions(), transform));
      }
      context.setOutputDataStream(output, source);
    }
  }

  /** Translates {@link ParDo.Bound} into a flatMap over {@link FlinkParDoBoundWrapper}. */
  private static class ParDoBoundStreamingTranslator<IN, OUT> implements FlinkStreamingPipelineTranslator.StreamTransformTranslator<ParDo.Bound<IN, OUT>> {

    @Override
    public void translateNode(ParDo.Bound<IN, OUT> transform, FlinkStreamingTranslationContext context) {
      PCollection<OUT> output = context.getOutput(transform);

      final WindowingStrategy<OUT, ? extends BoundedWindow> windowingStrategy =
          (WindowingStrategy<OUT, ? extends BoundedWindow>)
              context.getOutput(transform).getWindowingStrategy();

      WindowedValue.WindowedValueCoder<OUT> outputStreamCoder = WindowedValue.getFullCoder(output.getCoder(),
          windowingStrategy.getWindowFn().windowCoder());
      CoderTypeInformation<WindowedValue<OUT>> outputWindowedValueCoder =
          new CoderTypeInformation<>(outputStreamCoder);

      FlinkParDoBoundWrapper<IN, OUT> doFnWrapper = new FlinkParDoBoundWrapper<>(
          context.getPipelineOptions(), windowingStrategy, transform.getFn());
      DataStream<WindowedValue<IN>> inputDataStream = context.getInputDataStream(context.getInput(transform));
      SingleOutputStreamOperator<WindowedValue<OUT>> outDataStream = inputDataStream.flatMap(doFnWrapper)
          .returns(outputWindowedValueCoder);

      context.setOutputDataStream(context.getOutput(transform), outDataStream);
    }
  }

  /**
   * Translates {@link Window.Bound} by running a synthetic {@link DoFn} (built in
   * {@link #createWindowAssigner}) that re-emits each element into the windows chosen
   * by the strategy's {@link WindowFn}.
   */
  public static class WindowBoundTranslator<T> implements FlinkStreamingPipelineTranslator.StreamTransformTranslator<Window.Bound<T>> {

    @Override
    public void translateNode(Window.Bound<T> transform, FlinkStreamingTranslationContext context) {
      PValue input = context.getInput(transform);
      DataStream<WindowedValue<T>> inputDataStream = context.getInputDataStream(input);

      final WindowingStrategy<T, ? extends BoundedWindow> windowingStrategy =
          (WindowingStrategy<T, ? extends BoundedWindow>)
              context.getOutput(transform).getWindowingStrategy();

      final WindowFn<T, ? extends BoundedWindow> windowFn = windowingStrategy.getWindowFn();

      WindowedValue.WindowedValueCoder<T> outputStreamCoder = WindowedValue.getFullCoder(
          context.getInput(transform).getCoder(), windowingStrategy.getWindowFn().windowCoder());
      CoderTypeInformation<WindowedValue<T>> outputWindowedValueCoder =
          new CoderTypeInformation<>(outputStreamCoder);

      final FlinkParDoBoundWrapper<T, T> windowDoFnAssigner = new FlinkParDoBoundWrapper<>(
          context.getPipelineOptions(), windowingStrategy, createWindowAssigner(windowFn));

      SingleOutputStreamOperator<WindowedValue<T>> windowedStream =
          inputDataStream.flatMap(windowDoFnAssigner).returns(outputWindowedValueCoder);
      context.setOutputDataStream(context.getOutput(transform), windowedStream);
    }

    /** Builds an identity DoFn that only re-assigns each element's windows via {@code windowFn}. */
    private static <T, W extends BoundedWindow> DoFn<T, T> createWindowAssigner(final WindowFn<T, W> windowFn) {
      return new DoFn<T, T>() {

        @Override
        public void processElement(final ProcessContext c) throws Exception {
          Collection<W> windows = windowFn.assignWindows(
              windowFn.new AssignContext() {
                @Override
                public T element() {
                  return c.element();
                }

                @Override
                public Instant timestamp() {
                  return c.timestamp();
                }

                @Override
                public Collection<? extends BoundedWindow> windows() {
                  return c.windowingInternals().windows();
                }
              });

          c.windowingInternals().outputWindowedValue(
              c.element(), c.timestamp(), windows, c.pane());
        }
      };
    }
  }

  /**
   * Translates {@link GroupByKey} by keying the stream on the KV key and applying the
   * group-also-by-window wrapper to collect the per-key iterables.
   */
  public static class GroupByKeyTranslator<K, V> implements FlinkStreamingPipelineTranslator.StreamTransformTranslator<GroupByKey<K, V>> {

    @Override
    public void translateNode(GroupByKey<K, V> transform, FlinkStreamingTranslationContext context) {
      PValue input = context.getInput(transform);

      DataStream<WindowedValue<KV<K, V>>> inputDataStream = context.getInputDataStream(input);
      KvCoder<K, V> inputKvCoder = (KvCoder<K, V>) context.getInput(transform).getCoder();

      KeyedStream<WindowedValue<KV<K, V>>, K> groupByKStream = FlinkGroupByKeyWrapper
          .groupStreamByKey(inputDataStream, inputKvCoder);

      DataStream<WindowedValue<KV<K, Iterable<V>>>> groupedByKNWstream =
          FlinkGroupAlsoByWindowWrapper.createForIterable(context.getPipelineOptions(),
              context.getInput(transform), groupByKStream);

      context.setOutputDataStream(context.getOutput(transform), groupedByKNWstream);
    }
  }

  /**
   * Translates {@link Combine.PerKey} like a GroupByKey, but folds the grouped values
   * through the transform's {@link Combine.KeyedCombineFn} inside the window wrapper.
   */
  public static class CombinePerKeyTranslator<K, VIN, VACC, VOUT> implements FlinkStreamingPipelineTranslator.StreamTransformTranslator<Combine.PerKey<K, VIN, VOUT>> {

    @Override
    public void translateNode(Combine.PerKey<K, VIN, VOUT> transform, FlinkStreamingTranslationContext context) {
      PValue input = context.getInput(transform);

      DataStream<WindowedValue<KV<K, VIN>>> inputDataStream = context.getInputDataStream(input);
      KvCoder<K, VIN> inputKvCoder = (KvCoder<K, VIN>) context.getInput(transform).getCoder();
      KvCoder<K, VOUT> outputKvCoder = (KvCoder<K, VOUT>) context.getOutput(transform).getCoder();

      KeyedStream<WindowedValue<KV<K, VIN>>, K> groupByKStream = FlinkGroupByKeyWrapper
          .groupStreamByKey(inputDataStream, inputKvCoder);

      Combine.KeyedCombineFn<K, VIN, VACC, VOUT> combineFn = (Combine.KeyedCombineFn<K, VIN, VACC, VOUT>) transform.getFn();

      DataStream<WindowedValue<KV<K, VOUT>>> groupedByKNWstream =
          FlinkGroupAlsoByWindowWrapper.create(context.getPipelineOptions(),
              context.getInput(transform), groupByKStream, combineFn, outputKvCoder);

      context.setOutputDataStream(context.getOutput(transform), groupedByKNWstream);
    }
  }

  /** Translates {@link Flatten.FlattenPCollectionList} into a union of the input streams. */
  public static class FlattenPCollectionTranslator<T> implements FlinkStreamingPipelineTranslator.StreamTransformTranslator<Flatten.FlattenPCollectionList<T>> {

    @Override
    public void translateNode(Flatten.FlattenPCollectionList<T> transform, FlinkStreamingTranslationContext context) {
      List<PCollection<T>> allInputs = context.getInput(transform).getAll();
      DataStream<T> result = null;
      // NOTE(review): result stays null for an empty input list — TODO confirm callers
      // never flatten an empty PCollectionList.
      for (PCollection<T> collection : allInputs) {
        DataStream<T> current = context.getInputDataStream(collection);
        result = (result == null) ? current : result.union(current);
      }
      context.setOutputDataStream(context.getOutput(transform), result);
    }
  }

  /**
   * Translates {@link ParDo.BoundMulti}: the DoFn emits tagged {@link RawUnionValue}s into
   * one intermediate stream, which is then filtered per output tag to rebuild each
   * tagged output PCollection.
   */
  public static class ParDoBoundMultiStreamingTranslator<IN, OUT> implements FlinkStreamingPipelineTranslator.StreamTransformTranslator<ParDo.BoundMulti<IN, OUT>> {

    /** Union tag reserved for the main output. */
    private final int MAIN_TAG_INDEX = 0;

    @Override
    public void translateNode(ParDo.BoundMulti<IN, OUT> transform, FlinkStreamingTranslationContext context) {

      // we assume that the transformation does not change the windowing strategy.
      WindowingStrategy<?, ? extends BoundedWindow> windowingStrategy = context.getInput(transform).getWindowingStrategy();

      Map<TupleTag<?>, PCollection<?>> outputs = context.getOutput(transform).getAll();
      Map<TupleTag<?>, Integer> tagsToLabels = transformTupleTagsToLabels(
          transform.getMainOutputTag(), outputs.keySet());

      UnionCoder intermUnionCoder = getIntermUnionCoder(outputs.values());
      WindowedValue.WindowedValueCoder<RawUnionValue> outputStreamCoder = WindowedValue.getFullCoder(
          intermUnionCoder, windowingStrategy.getWindowFn().windowCoder());

      CoderTypeInformation<WindowedValue<RawUnionValue>> intermWindowedValueCoder =
          new CoderTypeInformation<>(outputStreamCoder);

      FlinkParDoBoundMultiWrapper<IN, OUT> doFnWrapper = new FlinkParDoBoundMultiWrapper<>(
          context.getPipelineOptions(), windowingStrategy, transform.getFn(),
          transform.getMainOutputTag(), tagsToLabels);

      DataStream<WindowedValue<IN>> inputDataStream = context.getInputDataStream(context.getInput(transform));
      SingleOutputStreamOperator<WindowedValue<RawUnionValue>> intermDataStream =
          inputDataStream.flatMap(doFnWrapper).returns(intermWindowedValueCoder);

      // Split the union stream back into one stream per tagged output.
      for (Map.Entry<TupleTag<?>, PCollection<?>> output : outputs.entrySet()) {
        final int outputTag = tagsToLabels.get(output.getKey());

        WindowedValue.WindowedValueCoder<?> coderForTag = WindowedValue.getFullCoder(
            output.getValue().getCoder(),
            windowingStrategy.getWindowFn().windowCoder());

        CoderTypeInformation<WindowedValue<?>> windowedValueCoder =
            new CoderTypeInformation(coderForTag);

        context.setOutputDataStream(output.getValue(),
            intermDataStream.filter(new FilterFunction<WindowedValue<RawUnionValue>>() {
              @Override
              public boolean filter(WindowedValue<RawUnionValue> value) throws Exception {
                return value.getValue().getUnionTag() == outputTag;
              }
            }).flatMap(new FlatMapFunction<WindowedValue<RawUnionValue>, WindowedValue<?>>() {
              @Override
              public void flatMap(WindowedValue<RawUnionValue> value, Collector<WindowedValue<?>> collector) throws Exception {
                collector.collect(WindowedValue.of(
                    value.getValue().getValue(),
                    value.getTimestamp(),
                    value.getWindows(),
                    value.getPane()));
              }
            }).returns(windowedValueCoder));
      }
    }

    /** Assigns union tag 0 to the main output and increasing tags to the side outputs. */
    private Map<TupleTag<?>, Integer> transformTupleTagsToLabels(TupleTag<?> mainTag, Set<TupleTag<?>> secondaryTags) {
      Map<TupleTag<?>, Integer> tagToLabelMap = Maps.newHashMap();
      tagToLabelMap.put(mainTag, MAIN_TAG_INDEX);
      int count = MAIN_TAG_INDEX + 1;
      for (TupleTag<?> tag : secondaryTags) {
        if (!tagToLabelMap.containsKey(tag)) {
          tagToLabelMap.put(tag, count++);
        }
      }
      return tagToLabelMap;
    }

    /** Builds the coder for the intermediate union stream from the tagged outputs' coders. */
    private UnionCoder getIntermUnionCoder(Collection<PCollection<?>> taggedCollections) {
      List<Coder<?>> outputCoders = Lists.newArrayList();
      for (PCollection<?> coll : taggedCollections) {
        outputCoders.add(coll.getCoder());
      }
      return UnionCoder.of(outputCoders);
    }
  }
}
The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.flink.translation;

import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform;
import com.google.cloud.dataflow.sdk.transforms.PTransform;
import com.google.cloud.dataflow.sdk.values.PInput;
import com.google.cloud.dataflow.sdk.values.POutput;
import com.google.cloud.dataflow.sdk.values.PValue;
import com.google.common.base.Preconditions;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

import java.util.HashMap;
import java.util.Map;

/**
 * Mutable translation state used while translating a <b>streaming</b> job; the streaming
 * counterpart of {@code FlinkBatchTranslationContext}.
 */
public class FlinkStreamingTranslationContext {

  private final StreamExecutionEnvironment env;
  private final PipelineOptions options;

  /**
   * Keeps a mapping between the output value of the PTransform (in Dataflow) and the
   * Flink Operator that produced it, after the translation of the corresponding PTransform
   * to its Flink equivalent.
   */
  private final Map<PValue, DataStream<?>> dataStreams;

  /** The transform currently being translated; set by the pipeline visitor. */
  private AppliedPTransform<?, ?, ?> currentTransform;

  public FlinkStreamingTranslationContext(StreamExecutionEnvironment env, PipelineOptions options) {
    this.env = Preconditions.checkNotNull(env);
    this.options = Preconditions.checkNotNull(options);
    this.dataStreams = new HashMap<>();
  }

  public StreamExecutionEnvironment getExecutionEnvironment() {
    return env;
  }

  public PipelineOptions getPipelineOptions() {
    return options;
  }

  /**
   * Returns the {@link DataStream} previously registered for {@code value}, or
   * {@code null} if it has not been translated yet.
   */
  @SuppressWarnings("unchecked")
  public <T> DataStream<T> getInputDataStream(PValue value) {
    return (DataStream<T>) dataStreams.get(value);
  }

  /**
   * Registers the {@link DataStream} producing {@code value}. The first registration
   * wins; later calls for the same value are ignored.
   */
  public void setOutputDataStream(PValue value, DataStream<?> set) {
    if (!dataStreams.containsKey(value)) {
      dataStreams.put(value, set);
    }
  }

  /**
   * Sets the AppliedPTransform which carries input/output.
   *
   * @param currentTransform the transform that is about to be translated
   */
  public void setCurrentTransform(AppliedPTransform<?, ?, ?> currentTransform) {
    this.currentTransform = currentTransform;
  }

  /** Typed input of the current transform; the cast is safe because the visitor set it. */
  @SuppressWarnings("unchecked")
  public <I extends PInput> I getInput(PTransform<I, ?> transform) {
    return (I) currentTransform.getInput();
  }

  /** Typed output of the current transform; the cast is safe because the visitor set it. */
  @SuppressWarnings("unchecked")
  public <O extends POutput> O getOutput(PTransform<?, O> transform) {
    return (O) currentTransform.getOutput();
  }
}
package org.apache.beam.runners.flink.translation.functions;

import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult;
import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResultSchema;
import com.google.cloud.dataflow.sdk.transforms.join.RawUnionValue;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.dataflow.sdk.values.TupleTag;
import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.util.Collector;

import java.util.ArrayList;
import java.util.List;

/**
 * Flink {@link CoGroupFunction} that merges the two keyed sides of a CoGroupByKey
 * into a single {@link CoGbkResult} per key. Every element is wrapped in a
 * {@link RawUnionValue} tagged with the union index of the {@link TupleTag} it
 * came from, as defined by the {@link CoGbkResultSchema}.
 */
public class FlinkCoGroupKeyedListAggregator<K,V1,V2> implements CoGroupFunction<KV<K,V1>, KV<K,V2>, KV<K, CoGbkResult>>{

  private final CoGbkResultSchema schema;
  private final TupleTag<?> tupleTag1;
  private final TupleTag<?> tupleTag2;

  public FlinkCoGroupKeyedListAggregator(CoGbkResultSchema schema, TupleTag<?> tupleTag1, TupleTag<?> tupleTag2) {
    this.schema = schema;
    this.tupleTag1 = tupleTag1;
    this.tupleTag2 = tupleTag2;
  }

  @Override
  public void coGroup(Iterable<KV<K,V1>> first, Iterable<KV<K,V2>> second,
                      Collector<KV<K, CoGbkResult>> out) throws Exception {
    List<RawUnionValue> union = new ArrayList<>();
    // Both sides share the same key; whichever side is non-empty supplies it.
    K key = appendTagged(first, schema.getIndex(tupleTag1), union, null);
    key = appendTagged(second, schema.getIndex(tupleTag2), union, key);
    out.collect(KV.of(key, new CoGbkResult(schema, union)));
  }

  /**
   * Appends every value of {@code source} to {@code sink}, tagged with
   * {@code unionIndex}. Returns the key of the last element seen, or
   * {@code fallbackKey} when {@code source} is empty.
   */
  private K appendTagged(Iterable<? extends KV<K, ?>> source, int unionIndex,
                         List<RawUnionValue> sink, K fallbackKey) {
    K key = fallbackKey;
    for (KV<K, ?> entry : source) {
      key = entry.getKey();
      sink.add(new RawUnionValue(unionIndex, entry.getValue()));
    }
    return key;
  }
}
package org.apache.beam.runners.flink.translation.functions;

import org.apache.beam.runners.flink.translation.types.VoidCoderTypeSerializer;
import com.google.cloud.dataflow.sdk.coders.Coder;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.util.Collector;

import java.io.ByteArrayInputStream;
import java.util.List;

/**
 * This is a hack for transforming a {@link com.google.cloud.dataflow.sdk.transforms.Create}
 * operation. Flink does not allow {@code null} in it's equivalent operation:
 * {@link org.apache.flink.api.java.ExecutionEnvironment#fromElements(Object[])}. Therefore
 * we use a DataSource with one dummy element and output the elements of the Create operation
 * inside this FlatMap.
 */
public class FlinkCreateFunction<IN, OUT> implements FlatMapFunction<IN, OUT> {

  /** The Create elements, pre-encoded with {@link #coder}. */
  private final List<byte[]> elements;
  /** Coder used to decode {@link #elements} back into their element type. */
  private final Coder<OUT> coder;

  public FlinkCreateFunction(List<byte[]> elements, Coder<OUT> coder) {
    this.elements = elements;
    this.coder = coder;
  }

  /**
   * Ignores the single dummy input element and emits every decoded Create
   * element instead. {@code null} elements are replaced by a placeholder value.
   *
   * @throws Exception if an element cannot be decoded by the coder
   */
  @Override
  @SuppressWarnings("unchecked")
  public void flatMap(IN value, Collector<OUT> out) throws Exception {

    for (byte[] element : elements) {
      ByteArrayInputStream bai = new ByteArrayInputStream(element);
      OUT outValue = coder.decode(bai, Coder.Context.OUTER);
      if (outValue == null) {
        // TODO Flink doesn't allow null values in records, emit a placeholder instead
        out.collect((OUT) VoidCoderTypeSerializer.VoidValue.INSTANCE);
      } else {
        out.collect(outValue);
      }
    }

    // Fix: the original called out.close() here, but the Collector's lifecycle
    // is owned by the Flink runtime. User functions must never close it:
    // doing so can break any subsequent invocation of this function.
  }
}
a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java new file mode 100644 index 0000000..fe77e64 --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.functions; + +import org.apache.beam.runners.flink.translation.wrappers.SerializableFnAggregatorWrapper; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.TimerInternals; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingInternals; +import com.google.cloud.dataflow.sdk.util.state.StateInternals; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.ImmutableList; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.util.Collector; +import org.joda.time.Instant; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +/** + * Encapsulates a {@link com.google.cloud.dataflow.sdk.transforms.DoFn} + * inside a Flink {@link org.apache.flink.api.common.functions.RichMapPartitionFunction}. 
/**
 * Encapsulates a {@link com.google.cloud.dataflow.sdk.transforms.DoFn}
 * inside a Flink {@link org.apache.flink.api.common.functions.RichMapPartitionFunction}.
 * Each Flink partition is processed as one Dataflow bundle.
 */
public class FlinkDoFnFunction<IN, OUT> extends RichMapPartitionFunction<IN, OUT> {

  // The user DoFn invoked for every input element; serialized with this function.
  private final DoFn<IN, OUT> doFn;
  // PipelineOptions is not Java-Serializable, so it is transient and is
  // round-tripped through Jackson JSON in writeObject/readObject below.
  private transient PipelineOptions options;

  public FlinkDoFnFunction(DoFn<IN, OUT> doFn, PipelineOptions options) {
    this.doFn = doFn;
    this.options = options;
  }

  // Custom serialization: append a JSON rendering of the (non-Serializable)
  // PipelineOptions after the default serialized form.
  private void writeObject(ObjectOutputStream out)
      throws IOException, ClassNotFoundException {
    out.defaultWriteObject();
    ObjectMapper mapper = new ObjectMapper();
    mapper.writeValue(out, options);
  }

  // Counterpart of writeObject: restore the options from the trailing JSON.
  private void readObject(ObjectInputStream in)
      throws IOException, ClassNotFoundException {
    in.defaultReadObject();
    ObjectMapper mapper = new ObjectMapper();
    options = mapper.readValue(in, PipelineOptions.class);
  }

  @Override
  public void mapPartition(Iterable<IN> values, Collector<OUT> out) throws Exception {
    // Bundle lifecycle: startBundle, processElement per value, finishBundle.
    ProcessContext context = new ProcessContext(doFn, out);
    this.doFn.startBundle(context);
    for (IN value : values) {
      context.inValue = value;
      doFn.processElement(context);
    }
    this.doFn.finishBundle(context);
  }

  /**
   * Minimal batch-mode ProcessContext: elements are treated as belonging to the
   * global window with no pane information, and state/timers are unsupported.
   */
  private class ProcessContext extends DoFn<IN, OUT>.ProcessContext {

    // The element currently being processed; set by mapPartition before each call.
    IN inValue;
    // Flink collector receiving the DoFn's main output.
    Collector<OUT> outCollector;

    public ProcessContext(DoFn<IN, OUT> fn, Collector<OUT> outCollector) {
      fn.super();
      super.setupDelegateAggregators();
      this.outCollector = outCollector;
    }

    @Override
    public IN element() {
      return this.inValue;
    }


    @Override
    public Instant timestamp() {
      // Batch execution carries no element timestamps; the current wall-clock
      // time is returned instead.
      return Instant.now();
    }

    @Override
    public BoundedWindow window() {
      return GlobalWindow.INSTANCE;
    }

    @Override
    public PaneInfo pane() {
      return PaneInfo.NO_FIRING;
    }

    @Override
    public WindowingInternals<IN, OUT> windowingInternals() {
      // Stubbed-out windowing internals: everything lives in the global window.
      return new WindowingInternals<IN, OUT>() {
        @Override
        public StateInternals stateInternals() {
          return null; // state is not supported in this translation
        }

        @Override
        public void outputWindowedValue(OUT output, Instant timestamp, Collection<? extends BoundedWindow> windows, PaneInfo pane) {
          // no-op: windowed output is not supported here
        }

        @Override
        public TimerInternals timerInternals() {
          return null; // timers are not supported in this translation
        }

        @Override
        public Collection<? extends BoundedWindow> windows() {
          return ImmutableList.of(GlobalWindow.INSTANCE);
        }

        @Override
        public PaneInfo pane() {
          return PaneInfo.NO_FIRING;
        }

        @Override
        public <T> void writePCollectionViewData(TupleTag<?> tag, Iterable<WindowedValue<T>> data, Coder<T> elemCoder) throws IOException {
          // no-op: side-input data is distributed via broadcast variables instead
        }

        @Override
        public <T> T sideInput(PCollectionView<T> view, BoundedWindow mainInputWindow) {
          throw new RuntimeException("sideInput() not implemented.");
        }
      };
    }

    @Override
    public PipelineOptions getPipelineOptions() {
      return options;
    }

    @Override
    public <T> T sideInput(PCollectionView<T> view) {
      // Side inputs arrive as a Flink broadcast variable keyed by the view's
      // internal tag id; re-wrap each element into the global window.
      List<T> sideInput = getRuntimeContext().getBroadcastVariable(view.getTagInternal().getId());
      List<WindowedValue<?>> windowedValueList = new ArrayList<>(sideInput.size());
      for (T input : sideInput) {
        windowedValueList.add(WindowedValue.of(input, Instant.now(), ImmutableList.of(GlobalWindow.INSTANCE), pane()));
      }
      return view.fromIterableInternal(windowedValueList);
    }

    @Override
    public void output(OUT output) {
      outCollector.collect(output);
    }

    @Override
    public void outputWithTimestamp(OUT output, Instant timestamp) {
      // not FLink's way, just output normally (the timestamp is dropped)
      output(output);
    }

    @Override
    public <T> void sideOutput(TupleTag<T> tag, T output) {
      // ignore the side output, this can happen when a user does not register
      // side outputs but then outputs using a freshly created TupleTag.
    }

    @Override
    public <T> void sideOutputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) {
      sideOutput(tag, output);
    }

    @Override
    protected <AggInputT, AggOutputT> Aggregator<AggInputT, AggOutputT> createAggregatorInternal(String name, Combine.CombineFn<AggInputT, ?, AggOutputT> combiner) {
      // Aggregators are bridged onto Flink accumulators of the same name.
      SerializableFnAggregatorWrapper<AggInputT, AggOutputT> wrapper = new SerializableFnAggregatorWrapper<>(combiner);
      getRuntimeContext().addAccumulator(name, wrapper);
      return wrapper;
    }


  }
}
package org.apache.beam.runners.flink.translation.functions;

import com.google.cloud.dataflow.sdk.values.KV;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.util.Collector;

import java.util.Iterator;

/**
 * Flink {@link org.apache.flink.api.common.functions.GroupReduceFunction} for executing a
 * {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey} operation. All
 * {@link KV} records of one group share the same key, so a single
 * {@code KV<K, Iterable<V>>} is emitted whose value side is a lazy, single-pass
 * view over the group's input iterator.
 */
public class FlinkKeyedListAggregationFunction<K,V> implements GroupReduceFunction<KV<K, V>, KV<K, Iterable<V>>> {

  @Override
  public void reduce(Iterable<KV<K, V>> values, Collector<KV<K, Iterable<V>>> out) throws Exception {
    Iterator<KV<K, V>> source = values.iterator();
    // Flink never calls reduce with an empty group, so the first record always
    // exists and supplies the key for the whole group.
    KV<K, V> head = source.next();
    out.collect(KV.of(head.getKey(), new ValueView<>(head, source)));
  }

  /**
   * Lazy one-shot view over a group's values: yields the already-consumed first
   * record's value, then streams the remaining values straight from the
   * underlying iterator. May only be iterated once.
   */
  private static class ValueView<K, V> implements Iterable<V>, Iterator<V> {

    private KV<K, V> head;
    private final Iterator<KV<K, V>> tail;

    ValueView(KV<K, V> head, Iterator<KV<K, V>> tail) {
      this.head = head;
      this.tail = tail;
    }

    @Override
    public Iterator<V> iterator() {
      return this;
    }

    @Override
    public boolean hasNext() {
      return head != null || tail.hasNext();
    }

    @Override
    public V next() {
      if (head == null) {
        return tail.next().getValue();
      }
      V value = head.getValue();
      head = null;
      return value;
    }

    @Override
    public void remove() {
      throw new UnsupportedOperationException("Cannot remove elements from input.");
    }
  }
}
