http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java new file mode 100644 index 0000000..a1ddd44 --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java @@ -0,0 +1,288 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.runners.spark.translation; + +import static com.google.common.base.Preconditions.checkArgument; + +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.runners.AggregatorRetrievalException; +import com.google.cloud.dataflow.sdk.runners.AggregatorValues; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.base.Function; +import com.google.common.collect.Iterables; +import org.apache.beam.runners.spark.EvaluationResult; +import org.apache.beam.runners.spark.coders.CoderHelpers; +import org.apache.spark.api.java.JavaRDDLike; +import org.apache.spark.api.java.JavaSparkContext; + + +/** + * Evaluation context allows us to define how pipeline instructions. + */ +public class EvaluationContext implements EvaluationResult { + private final JavaSparkContext jsc; + private final Pipeline pipeline; + private final SparkRuntimeContext runtime; + private final Map<PValue, RDDHolder<?>> pcollections = new LinkedHashMap<>(); + private final Set<RDDHolder<?>> leafRdds = new LinkedHashSet<>(); + private final Set<PValue> multireads = new LinkedHashSet<>(); + private final Map<PValue, Object> pobjects = new LinkedHashMap<>(); + private final Map<PValue, Iterable<? 
extends WindowedValue<?>>> pview = new LinkedHashMap<>(); + protected AppliedPTransform<?, ?, ?> currentTransform; + + public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline) { + this.jsc = jsc; + this.pipeline = pipeline; + this.runtime = new SparkRuntimeContext(jsc, pipeline); + } + + /** + * Holds an RDD or values for deferred conversion to an RDD if needed. PCollections are + * sometimes created from a collection of objects (using RDD parallelize) and then + * only used to create View objects; in which case they do not need to be + * converted to bytes since they are not transferred across the network until they are + * broadcast. + */ + private class RDDHolder<T> { + + private Iterable<T> values; + private Coder<T> coder; + private JavaRDDLike<WindowedValue<T>, ?> rdd; + + RDDHolder(Iterable<T> values, Coder<T> coder) { + this.values = values; + this.coder = coder; + } + + RDDHolder(JavaRDDLike<WindowedValue<T>, ?> rdd) { + this.rdd = rdd; + } + + JavaRDDLike<WindowedValue<T>, ?> getRDD() { + if (rdd == null) { + Iterable<WindowedValue<T>> windowedValues = Iterables.transform(values, + new Function<T, WindowedValue<T>>() { + @Override + public WindowedValue<T> apply(T t) { + // TODO: this is wrong if T is a TimestampedValue + return WindowedValue.valueInEmptyWindows(t); + } + }); + WindowedValue.ValueOnlyWindowedValueCoder<T> windowCoder = + WindowedValue.getValueOnlyCoder(coder); + rdd = jsc.parallelize(CoderHelpers.toByteArrays(windowedValues, windowCoder)) + .map(CoderHelpers.fromByteFunction(windowCoder)); + } + return rdd; + } + + Iterable<T> getValues(PCollection<T> pcollection) { + if (values == null) { + coder = pcollection.getCoder(); + JavaRDDLike<byte[], ?> bytesRDD = rdd.map(WindowingHelpers.<T>unwindowFunction()) + .map(CoderHelpers.toByteFunction(coder)); + List<byte[]> clientBytes = bytesRDD.collect(); + values = Iterables.transform(clientBytes, new Function<byte[], T>() { + @Override + public T apply(byte[] bytes) { + return 
CoderHelpers.fromByteArray(bytes, coder); + } + }); + } + return values; + } + + Iterable<WindowedValue<T>> getWindowedValues(PCollection<T> pcollection) { + return Iterables.transform(get(pcollection), new Function<T, WindowedValue<T>>() { + @Override + public WindowedValue<T> apply(T t) { + return WindowedValue.valueInEmptyWindows(t); // TODO: not the right place? + } + }); + } + } + + protected JavaSparkContext getSparkContext() { + return jsc; + } + + protected Pipeline getPipeline() { + return pipeline; + } + + protected SparkRuntimeContext getRuntimeContext() { + return runtime; + } + + protected void setCurrentTransform(AppliedPTransform<?, ?, ?> transform) { + this.currentTransform = transform; + } + + protected AppliedPTransform<?, ?, ?> getCurrentTransform() { + return currentTransform; + } + + protected <I extends PInput> I getInput(PTransform<I, ?> transform) { + checkArgument(currentTransform != null && currentTransform.getTransform() == transform, + "can only be called with current transform"); + @SuppressWarnings("unchecked") + I input = (I) currentTransform.getInput(); + return input; + } + + protected <O extends POutput> O getOutput(PTransform<?, O> transform) { + checkArgument(currentTransform != null && currentTransform.getTransform() == transform, + "can only be called with current transform"); + @SuppressWarnings("unchecked") + O output = (O) currentTransform.getOutput(); + return output; + } + + protected <T> void setOutputRDD(PTransform<?, ?> transform, + JavaRDDLike<WindowedValue<T>, ?> rdd) { + setRDD((PValue) getOutput(transform), rdd); + } + + protected <T> void setOutputRDDFromValues(PTransform<?, ?> transform, Iterable<T> values, + Coder<T> coder) { + pcollections.put((PValue) getOutput(transform), new RDDHolder<>(values, coder)); + } + + void setPView(PValue view, Iterable<? extends WindowedValue<?>> value) { + pview.put(view, value); + } + + protected boolean hasOutputRDD(PTransform<? 
extends PInput, ?> transform) { + PValue pvalue = (PValue) getOutput(transform); + return pcollections.containsKey(pvalue); + } + + protected JavaRDDLike<?, ?> getRDD(PValue pvalue) { + RDDHolder<?> rddHolder = pcollections.get(pvalue); + JavaRDDLike<?, ?> rdd = rddHolder.getRDD(); + leafRdds.remove(rddHolder); + if (multireads.contains(pvalue)) { + // Ensure the RDD is marked as cached + rdd.rdd().cache(); + } else { + multireads.add(pvalue); + } + return rdd; + } + + protected <T> void setRDD(PValue pvalue, JavaRDDLike<WindowedValue<T>, ?> rdd) { + try { + rdd.rdd().setName(pvalue.getName()); + } catch (IllegalStateException e) { + // name not set, ignore + } + RDDHolder<T> rddHolder = new RDDHolder<>(rdd); + pcollections.put(pvalue, rddHolder); + leafRdds.add(rddHolder); + } + + JavaRDDLike<?, ?> getInputRDD(PTransform<? extends PInput, ?> transform) { + return getRDD((PValue) getInput(transform)); + } + + + <T> Iterable<? extends WindowedValue<?>> getPCollectionView(PCollectionView<T> view) { + return pview.get(view); + } + + /** + * Computes the outputs for all RDDs that are leaves in the DAG and do not have any + * actions (like saving to a file) registered on them (i.e. they are performed for side + * effects). 
+ */ + public void computeOutputs() { + for (RDDHolder<?> rddHolder : leafRdds) { + JavaRDDLike<?, ?> rdd = rddHolder.getRDD(); + rdd.rdd().cache(); // cache so that any subsequent get() is cheap + rdd.count(); // force the RDD to be computed + } + } + + @Override + public <T> T get(PValue value) { + if (pobjects.containsKey(value)) { + @SuppressWarnings("unchecked") + T result = (T) pobjects.get(value); + return result; + } + if (pcollections.containsKey(value)) { + JavaRDDLike<?, ?> rdd = pcollections.get(value).getRDD(); + @SuppressWarnings("unchecked") + T res = (T) Iterables.getOnlyElement(rdd.collect()); + pobjects.put(value, res); + return res; + } + throw new IllegalStateException("Cannot resolve un-known PObject: " + value); + } + + @Override + public <T> T getAggregatorValue(String named, Class<T> resultType) { + return runtime.getAggregatorValue(named, resultType); + } + + @Override + public <T> AggregatorValues<T> getAggregatorValues(Aggregator<?, T> aggregator) + throws AggregatorRetrievalException { + return runtime.getAggregatorValues(aggregator); + } + + @Override + public <T> Iterable<T> get(PCollection<T> pcollection) { + @SuppressWarnings("unchecked") + RDDHolder<T> rddHolder = (RDDHolder<T>) pcollections.get(pcollection); + return rddHolder.getValues(pcollection); + } + + <T> Iterable<WindowedValue<T>> getWindowedValues(PCollection<T> pcollection) { + @SuppressWarnings("unchecked") + RDDHolder<T> rddHolder = (RDDHolder<T>) pcollections.get(pcollection); + return rddHolder.getWindowedValues(pcollection); + } + + @Override + public void close() { + SparkContextFactory.stopSparkContext(jsc); + } + + /** The runner is blocking. */ + @Override + public State getState() { + return State.DONE; + } +}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java new file mode 100644 index 0000000..cecf962 --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.runners.spark.translation; + +import java.util.Iterator; +import java.util.Map; + +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.base.Function; +import com.google.common.collect.Iterators; +import com.google.common.collect.LinkedListMultimap; +import com.google.common.collect.Multimap; +import org.apache.beam.runners.spark.util.BroadcastHelper; +import org.apache.spark.api.java.function.PairFlatMapFunction; +import org.joda.time.Instant; +import scala.Tuple2; + +/** + * DoFunctions ignore side outputs. MultiDoFunctions deal with side outputs by enriching the + * underlying data with multiple TupleTags. + * + * @param <I> Input type for DoFunction. + * @param <O> Output type for DoFunction. + */ +class MultiDoFnFunction<I, O> + implements PairFlatMapFunction<Iterator<WindowedValue<I>>, TupleTag<?>, WindowedValue<?>> { + private final DoFn<I, O> mFunction; + private final SparkRuntimeContext mRuntimeContext; + private final TupleTag<O> mMainOutputTag; + private final Map<TupleTag<?>, BroadcastHelper<?>> mSideInputs; + + MultiDoFnFunction( + DoFn<I, O> fn, + SparkRuntimeContext runtimeContext, + TupleTag<O> mainOutputTag, + Map<TupleTag<?>, BroadcastHelper<?>> sideInputs) { + this.mFunction = fn; + this.mRuntimeContext = runtimeContext; + this.mMainOutputTag = mainOutputTag; + this.mSideInputs = sideInputs; + } + + @Override + public Iterable<Tuple2<TupleTag<?>, WindowedValue<?>>> + call(Iterator<WindowedValue<I>> iter) throws Exception { + ProcCtxt ctxt = new ProcCtxt(mFunction, mRuntimeContext, mSideInputs); + mFunction.startBundle(ctxt); + ctxt.setup(); + return ctxt.getOutputIterable(iter, mFunction); + } + + private class ProcCtxt extends SparkProcessContext<I, O, Tuple2<TupleTag<?>, WindowedValue<?>>> { + + private final Multimap<TupleTag<?>, WindowedValue<?>> outputs = 
LinkedListMultimap.create(); + + ProcCtxt(DoFn<I, O> fn, SparkRuntimeContext runtimeContext, Map<TupleTag<?>, + BroadcastHelper<?>> sideInputs) { + super(fn, runtimeContext, sideInputs); + } + + @Override + public synchronized void output(O o) { + outputs.put(mMainOutputTag, windowedValue.withValue(o)); + } + + @Override + public synchronized void output(WindowedValue<O> o) { + outputs.put(mMainOutputTag, o); + } + + @Override + public synchronized <T> void sideOutput(TupleTag<T> tag, T t) { + outputs.put(tag, windowedValue.withValue(t)); + } + + @Override + public <T> void sideOutputWithTimestamp(TupleTag<T> tupleTag, T t, Instant instant) { + outputs.put(tupleTag, WindowedValue.of(t, instant, + windowedValue.getWindows(), windowedValue.getPane())); + } + + @Override + protected void clearOutput() { + outputs.clear(); + } + + @Override + protected Iterator<Tuple2<TupleTag<?>, WindowedValue<?>>> getOutputIterator() { + return Iterators.transform(outputs.entries().iterator(), + new Function<Map.Entry<TupleTag<?>, WindowedValue<?>>, + Tuple2<TupleTag<?>, WindowedValue<?>>>() { + @Override + public Tuple2<TupleTag<?>, WindowedValue<?>> apply(Map.Entry<TupleTag<?>, + WindowedValue<?>> input) { + return new Tuple2<TupleTag<?>, WindowedValue<?>>(input.getKey(), input.getValue()); + } + }); + } + + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java new file mode 100644 index 0000000..2bc8a7b --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one 
+ * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark.translation; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.serializer.KryoSerializer; + +public final class SparkContextFactory { + + /** + * If the property {@code dataflow.spark.test.reuseSparkContext} is set to + * {@code true} then the Spark context will be reused for dataflow pipelines. + * This property should only be enabled for tests. + */ + static final String TEST_REUSE_SPARK_CONTEXT = + "dataflow.spark.test.reuseSparkContext"; + private static JavaSparkContext sparkContext; + private static String sparkMaster; + + private SparkContextFactory() { + } + + public static synchronized JavaSparkContext getSparkContext(String master, String appName) { + if (Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT)) { + if (sparkContext == null) { + sparkContext = createSparkContext(master, appName); + sparkMaster = master; + } else if (!master.equals(sparkMaster)) { + throw new IllegalArgumentException(String.format("Cannot reuse spark context " + + "with different spark master URL. 
Existing: %s, requested: %s.", + sparkMaster, master)); + } + return sparkContext; + } else { + return createSparkContext(master, appName); + } + } + + static synchronized void stopSparkContext(JavaSparkContext context) { + if (!Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT)) { + context.stop(); + } + } + + private static JavaSparkContext createSparkContext(String master, String appName) { + SparkConf conf = new SparkConf(); + conf.setMaster(master); + conf.setAppName(appName); + conf.set("spark.serializer", KryoSerializer.class.getCanonicalName()); + return new JavaSparkContext(conf); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineEvaluator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineEvaluator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineEvaluator.java new file mode 100644 index 0000000..0186c8c --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineEvaluator.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark.translation; + +import com.google.cloud.dataflow.sdk.runners.TransformTreeNode; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import org.apache.beam.runners.spark.SparkPipelineRunner; + +/** + * Pipeline {@link SparkPipelineRunner.Evaluator} for Spark. + */ +public final class SparkPipelineEvaluator extends SparkPipelineRunner.Evaluator { + + private final EvaluationContext ctxt; + + public SparkPipelineEvaluator(EvaluationContext ctxt, SparkPipelineTranslator translator) { + super(translator); + this.ctxt = ctxt; + } + + @Override + protected <PT extends PTransform<? super PInput, POutput>> void doVisitTransform(TransformTreeNode + node) { + @SuppressWarnings("unchecked") + PT transform = (PT) node.getTransform(); + @SuppressWarnings("unchecked") + Class<PT> transformClass = (Class<PT>) (Class<?>) transform.getClass(); + @SuppressWarnings("unchecked") TransformEvaluator<PT> evaluator = + (TransformEvaluator<PT>) translator.translate(transformClass); + LOG.info("Evaluating {}", transform); + AppliedPTransform<PInput, POutput, PT> appliedTransform = + AppliedPTransform.of(node.getFullName(), node.getInput(), node.getOutput(), transform); + ctxt.setCurrentTransform(appliedTransform); + evaluator.evaluate(transform, ctxt); + ctxt.setCurrentTransform(null); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineOptionsFactory.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineOptionsFactory.java 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineOptionsFactory.java new file mode 100644 index 0000000..2b6804e --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineOptionsFactory.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.runners.spark.translation; + +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import org.apache.beam.runners.spark.SparkPipelineOptions; + +public final class SparkPipelineOptionsFactory { + private SparkPipelineOptionsFactory() { + } + + public static SparkPipelineOptions create() { + return PipelineOptionsFactory.as(SparkPipelineOptions.class); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineOptionsRegistrar.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineOptionsRegistrar.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineOptionsRegistrar.java new file mode 100644 index 0000000..9775b3e --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineOptionsRegistrar.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.runners.spark.translation; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsRegistrar; +import com.google.common.collect.ImmutableList; +import org.apache.beam.runners.spark.SparkPipelineOptions; + +public class SparkPipelineOptionsRegistrar implements PipelineOptionsRegistrar { + @Override + public Iterable<Class<? extends PipelineOptions>> getPipelineOptions() { + return ImmutableList.<Class<? extends PipelineOptions>>of(SparkPipelineOptions.class); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineRunnerRegistrar.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineRunnerRegistrar.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineRunnerRegistrar.java new file mode 100644 index 0000000..e44d999 --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineRunnerRegistrar.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark.translation; + +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.runners.PipelineRunnerRegistrar; +import com.google.common.collect.ImmutableList; +import org.apache.beam.runners.spark.SparkPipelineRunner; + +public class SparkPipelineRunnerRegistrar implements PipelineRunnerRegistrar { + @Override + public Iterable<Class<? extends PipelineRunner<?>>> getPipelineRunners() { + return ImmutableList.<Class<? extends PipelineRunner<?>>>of(SparkPipelineRunner.class); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineTranslator.java new file mode 100644 index 0000000..ac1c685 --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineTranslator.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.translation; + +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +/** + * Translator to support translation between Dataflow transformations and Spark transformations. + */ +public interface SparkPipelineTranslator { + + boolean hasTranslation(Class<? extends PTransform<?, ?>> clazz); + + <PT extends PTransform<?, ?>> TransformEvaluator<PT> translate(Class<PT> clazz); +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java new file mode 100644 index 0000000..bfcdd80 --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.beam.runners.spark.translation;

import java.io.IOException;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;

import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.transforms.Aggregator;
import com.google.cloud.dataflow.sdk.transforms.Combine;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow;
import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo;
import com.google.cloud.dataflow.sdk.util.TimerInternals;
import com.google.cloud.dataflow.sdk.util.WindowedValue;
import com.google.cloud.dataflow.sdk.util.WindowingInternals;
import com.google.cloud.dataflow.sdk.util.state.InMemoryStateInternals;
import com.google.cloud.dataflow.sdk.util.state.StateInternals;
import com.google.cloud.dataflow.sdk.values.PCollectionView;
import com.google.cloud.dataflow.sdk.values.TupleTag;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.Iterables;

import org.apache.beam.runners.spark.util.BroadcastHelper;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link DoFn.ProcessContext} that drives a {@link DoFn} over the elements of a Spark
 * partition and exposes what the DoFn emits through an iterator of values of type {@code V}.
 *
 * <p>This class supplies the current element, its timestamp, window and pane, the broadcast side
 * inputs and the aggregators to the DoFn. Subclasses decide how emitted values are collected, via
 * {@link #output(WindowedValue)}, {@link #clearOutput()} and {@link #getOutputIterator()}.
 *
 * @param <I> input element type of the DoFn
 * @param <O> output element type of the DoFn
 * @param <V> type of the values returned by the output iterator
 */
public abstract class SparkProcessContext<I, O, V> extends DoFn<I, O>.ProcessContext {

  private static final Logger LOG = LoggerFactory.getLogger(SparkProcessContext.class);

  private final DoFn<I, O> fn;
  private final SparkRuntimeContext mRuntimeContext;
  private final Map<TupleTag<?>, BroadcastHelper<?>> mSideInputs;

  /** The element currently being processed; reassigned before each call to processElement. */
  protected WindowedValue<I> windowedValue;

  SparkProcessContext(DoFn<I, O> fn,
      SparkRuntimeContext runtime,
      Map<TupleTag<?>, BroadcastHelper<?>> sideInputs) {
    // ProcessContext is an inner class of DoFn, so it needs fn as its enclosing instance.
    fn.super();
    this.fn = fn;
    this.mRuntimeContext = runtime;
    this.mSideInputs = sideInputs;
  }

  /**
   * Calls the inherited {@code setupDelegateAggregators()}; presumably expected to run once
   * before elements are processed (NOTE(review): confirm against callers).
   */
  void setup() {
    setupDelegateAggregators();
  }

  /** Returns the pipeline options held by the runtime context. */
  @Override
  public PipelineOptions getPipelineOptions() {
    return mRuntimeContext.getPipelineOptions();
  }

  /**
   * Returns the value of a side input, materialized by {@code view} from the broadcast contents
   * registered under the view's tag.
   *
   * @throws IllegalArgumentException if no side input was broadcast for {@code view}
   */
  @Override
  public <T> T sideInput(PCollectionView<T> view) {
    @SuppressWarnings("unchecked")
    BroadcastHelper<Iterable<WindowedValue<?>>> broadcastHelper =
        (BroadcastHelper<Iterable<WindowedValue<?>>>) mSideInputs.get(view.getTagInternal());
    if (broadcastHelper == null) {
      // Without this check an unregistered view fails later with an uninformative NPE.
      throw new IllegalArgumentException(
          "calling sideInput() with unknown view: " + view.getTagInternal());
    }
    Iterable<WindowedValue<?>> contents = broadcastHelper.getValue();
    return view.fromIterableInternal(contents);
  }

  /** Emits an output value; how it is collected is up to the subclass. */
  @Override
  public abstract void output(O output);

  /** Emits an already-windowed output value; how it is collected is up to the subclass. */
  public abstract void output(WindowedValue<O> output);

  /** Always fails: single-output DoFns cannot emit side outputs here. */
  @Override
  public <T> void sideOutput(TupleTag<T> tupleTag, T t) {
    String message = "sideOutput is an unsupported operation for doFunctions, use a "
        + "MultiDoFunction instead.";
    LOG.warn(message);
    throw new UnsupportedOperationException(message);
  }

  /** Always fails: single-output DoFns cannot emit side outputs here. */
  @Override
  public <T> void sideOutputWithTimestamp(TupleTag<T> tupleTag, T t, Instant instant) {
    String message =
        "sideOutputWithTimestamp is an unsupported operation for doFunctions, use a "
            + "MultiDoFunction instead.";
    LOG.warn(message);
    throw new UnsupportedOperationException(message);
  }

  /** Delegates aggregator creation to the shared {@link SparkRuntimeContext}. */
  @Override
  public <AI, AO> Aggregator<AI, AO> createAggregatorInternal(
      String named,
      Combine.CombineFn<AI, ?, AO> combineFn) {
    return mRuntimeContext.createAggregator(named, combineFn);
  }

  /** Returns the value of the current element. */
  @Override
  public I element() {
    return windowedValue.getValue();
  }

  /** Emits {@code output} at {@code timestamp}, keeping the current element's windows and pane. */
  @Override
  public void outputWithTimestamp(O output, Instant timestamp) {
    output(WindowedValue.of(output, timestamp,
        windowedValue.getWindows(), windowedValue.getPane()));
  }

  /** Returns the timestamp of the current element. */
  @Override
  public Instant timestamp() {
    return windowedValue.getTimestamp();
  }

  /**
   * Returns the window of the current element. Guava's {@code getOnlyElement} throws if the
   * element is not in exactly one window.
   *
   * @throws UnsupportedOperationException if the DoFn does not implement
   *     {@link DoFn.RequiresWindowAccess}
   */
  @Override
  public BoundedWindow window() {
    if (!(fn instanceof DoFn.RequiresWindowAccess)) {
      throw new UnsupportedOperationException(
          "window() is only available in the context of a DoFn marked as RequiresWindowAccess.");
    }
    return Iterables.getOnlyElement(windowedValue.getWindows());
  }

  /** Returns the pane of the current element. */
  @Override
  public PaneInfo pane() {
    return windowedValue.getPane();
  }

  /**
   * Returns windowing internals bound to the current element. Timers, side inputs with an
   * explicit window and view-data writes are unsupported; state is only a placeholder.
   */
  @Override
  public WindowingInternals<I, O> windowingInternals() {
    return new WindowingInternals<I, O>() {

      @Override
      public Collection<? extends BoundedWindow> windows() {
        return windowedValue.getWindows();
      }

      @Override
      public void outputWindowedValue(O output, Instant timestamp,
          Collection<? extends BoundedWindow> windows, PaneInfo paneInfo) {
        output(WindowedValue.of(output, timestamp, windows, paneInfo));
      }

      @Override
      public StateInternals stateInternals() {
        //TODO: implement state internals.
        // This is a temporary placeholder to get the TfIdfTest
        // working for the initial Beam code drop.
        return InMemoryStateInternals.forKey("DUMMY");
      }

      @Override
      public TimerInternals timerInternals() {
        throw new UnsupportedOperationException(
            "WindowingInternals#timerInternals() is not yet supported.");
      }

      @Override
      public PaneInfo pane() {
        return windowedValue.getPane();
      }

      @Override
      public <T> void writePCollectionViewData(TupleTag<?> tag,
          Iterable<WindowedValue<T>> data, Coder<T> elemCoder) throws IOException {
        throw new UnsupportedOperationException(
            "WindowingInternals#writePCollectionViewData() is not yet supported.");
      }

      @Override
      public <T> T sideInput(PCollectionView<T> view, BoundedWindow mainInputWindow) {
        throw new UnsupportedOperationException(
            "WindowingInternals#sideInput() is not yet supported.");
      }
    };
  }

  /** Discards any output collected so far; called before each processElement/finishBundle. */
  protected abstract void clearOutput();

  /** Returns an iterator over the output collected since the last {@link #clearOutput()}. */
  protected abstract Iterator<V> getOutputIterator();

  /**
   * Returns a lazy view of the outputs of {@code doFn} applied to every element of {@code iter},
   * followed by whatever {@code finishBundle} emits.
   *
   * <p>Every {@code iterator()} call shares the same input iterator, so the returned Iterable is
   * effectively single-use.
   */
  protected Iterable<V> getOutputIterable(final Iterator<WindowedValue<I>> iter,
      final DoFn<I, O> doFn) {
    return new Iterable<V>() {
      @Override
      public Iterator<V> iterator() {
        return new ProcCtxtIterator(iter, doFn);
      }
    };
  }

  /** Pulls input elements on demand and yields the outputs produced for each of them. */
  private class ProcCtxtIterator extends AbstractIterator<V> {

    private final Iterator<WindowedValue<I>> inputIterator;
    private final DoFn<I, O> doFn;
    private Iterator<V> outputIterator;
    private boolean calledFinish;

    ProcCtxtIterator(Iterator<WindowedValue<I>> iterator, DoFn<I, O> doFn) {
      this.inputIterator = iterator;
      this.doFn = doFn;
      this.outputIterator = getOutputIterator();
    }

    @Override
    protected V computeNext() {
      // Process each element from the (input) iterator, which produces zero, one or more
      // output elements (of type V) in the output iterator. Note that the output
      // collection (and iterator) is reset between each call to processElement, so the
      // collection only holds the output values for each call to processElement, rather
      // than for the whole partition (which would use too much memory).
      while (true) {
        if (outputIterator.hasNext()) {
          return outputIterator.next();
        } else if (inputIterator.hasNext()) {
          clearOutput();
          windowedValue = inputIterator.next();
          try {
            doFn.processElement(SparkProcessContext.this);
          } catch (Exception e) {
            throw new SparkProcessException(e);
          }
          outputIterator = getOutputIterator();
        } else {
          // no more input to consume, but finishBundle can produce more output
          if (!calledFinish) {
            clearOutput();
            try {
              calledFinish = true;
              doFn.finishBundle(SparkProcessContext.this);
            } catch (Exception e) {
              throw new SparkProcessException(e);
            }
            outputIterator = getOutputIterator();
            continue; // try to consume outputIterator from start of loop
          }
          return endOfData();
        }
      }
    }
  }

  /** Unchecked wrapper for exceptions thrown by the DoFn while the output is iterated. */
  public static class SparkProcessException extends RuntimeException {
    SparkProcessException(Throwable t) {
      super(t);
    }
  }

}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java new file mode 100644 index 0000000..bf618c4 --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.beam.runners.spark.translation;

import java.io.IOException;
import java.io.Serializable;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException;
import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.CoderRegistry;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.runners.AggregatorValues;
import com.google.cloud.dataflow.sdk.transforms.Aggregator;
import com.google.cloud.dataflow.sdk.transforms.Combine;
import com.google.cloud.dataflow.sdk.transforms.Max;
import com.google.cloud.dataflow.sdk.transforms.Min;
import com.google.cloud.dataflow.sdk.transforms.Sum;
import com.google.cloud.dataflow.sdk.values.TypeDescriptor;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import org.apache.beam.runners.spark.aggregators.AggAccumParam;
import org.apache.beam.runners.spark.aggregators.NamedAggregators;

import org.apache.spark.Accumulator;
import org.apache.spark.api.java.JavaSparkContext;


/**
 * The SparkRuntimeContext allows us to define useful features on the client side before our
 * data flow program is launched.
 *
 * <p>The accumulator, the JSON form of the pipeline options and the aggregator map are part of
 * the serialized state; the {@link CoderRegistry} field is transient and rebuilt lazily.
 */
public class SparkRuntimeContext implements Serializable {
  /**
   * Jackson mapper shared by all pipeline-options round trips. {@link ObjectMapper} is thread-safe
   * once configured, so a single instance replaces the previous per-call allocation.
   */
  private static final ObjectMapper OPTIONS_MAPPER = new ObjectMapper();

  /**
   * Output value type of every combiner that may back an aggregator. The aggregator's coder is
   * the registry's default coder for that type. Lookup is by exact class, as before.
   */
  private static final Map<Class<?>, Class<?>> COMBINER_VALUE_TYPES =
      ImmutableMap.<Class<?>, Class<?>>builder()
          .put(Sum.SumIntegerFn.class, Integer.class)
          .put(Sum.SumLongFn.class, Long.class)
          .put(Sum.SumDoubleFn.class, Double.class)
          .put(Min.MinIntegerFn.class, Integer.class)
          .put(Min.MinLongFn.class, Long.class)
          .put(Min.MinDoubleFn.class, Double.class)
          .put(Max.MaxIntegerFn.class, Integer.class)
          .put(Max.MaxLongFn.class, Long.class)
          .put(Max.MaxDoubleFn.class, Double.class)
          .build();

  /**
   * An accumulator that is a map from names to aggregators.
   */
  private final Accumulator<NamedAggregators> accum;

  /** Pipeline options as JSON, so they can be shipped with this context. */
  private final String serializedPipelineOptions;

  /**
   * Map of names to Dataflow aggregators, so that each name yields a single aggregator instance.
   */
  private final Map<String, Aggregator<?, ?>> aggregators = new HashMap<>();
  private transient CoderRegistry coderRegistry;

  SparkRuntimeContext(JavaSparkContext jsc, Pipeline pipeline) {
    this.accum = jsc.accumulator(new NamedAggregators(), new AggAccumParam());
    this.serializedPipelineOptions = serializePipelineOptions(pipeline.getOptions());
  }

  private static String serializePipelineOptions(PipelineOptions pipelineOptions) {
    try {
      return OPTIONS_MAPPER.writeValueAsString(pipelineOptions);
    } catch (JsonProcessingException e) {
      throw new IllegalStateException("Failed to serialize the pipeline options.", e);
    }
  }

  private static PipelineOptions deserializePipelineOptions(String serializedPipelineOptions) {
    try {
      return OPTIONS_MAPPER.readValue(serializedPipelineOptions, PipelineOptions.class);
    } catch (IOException e) {
      throw new IllegalStateException("Failed to deserialize the pipeline options.", e);
    }
  }

  /**
   * Retrieves corresponding value of an aggregator.
   *
   * @param aggregatorName Name of the aggregator to retrieve the value of.
   * @param typeClass Type class of value to be retrieved.
   * @param <T> Type of object to be returned.
   * @return The value of the aggregator.
   */
  public <T> T getAggregatorValue(String aggregatorName, Class<T> typeClass) {
    return accum.value().getValue(aggregatorName, typeClass);
  }

  /**
   * Returns the current value of {@code aggregator}. The value is looked up by the aggregator's
   * name and by the raw output type of its combine function.
   *
   * @param aggregator the aggregator to read
   * @param <T> the aggregator's output type
   * @return a single-element view of the value; {@code getValuesAtSteps()} is unsupported
   */
  public <T> AggregatorValues<T> getAggregatorValues(Aggregator<?, T> aggregator) {
    @SuppressWarnings("unchecked")
    Class<T> aggValueClass = (Class<T>) aggregator.getCombineFn().getOutputType().getRawType();
    final T aggregatorValue = getAggregatorValue(aggregator.getName(), aggValueClass);
    return new AggregatorValues<T>() {
      @Override
      public Collection<T> getValues() {
        return ImmutableList.of(aggregatorValue);
      }

      @Override
      public Map<String, T> getValuesAtSteps() {
        throw new UnsupportedOperationException("getValuesAtSteps is not supported.");
      }
    };
  }

  /** Returns a freshly deserialized copy of the pipeline options on every call. */
  public synchronized PipelineOptions getPipelineOptions() {
    return deserializePipelineOptions(serializedPipelineOptions);
  }

  /**
   * Creates an aggregator and associates it with the specified name. If an aggregator is already
   * registered under {@code named}, that instance is returned and {@code combineFn} is ignored.
   *
   * @param named Name of aggregator.
   * @param combineFn Combine function used in aggregation; must be one of the Sum/Min/Max
   *     integer, long or double combiners.
   * @param <IN> Type of inputs to aggregator.
   * @param <INTER> Intermediate data type
   * @param <OUT> Type of aggregator outputs.
   * @return Specified aggregator
   * @throws IllegalArgumentException if a new aggregator is needed and {@code combineFn} is not a
   *     supported combiner
   */
  public synchronized <IN, INTER, OUT> Aggregator<IN, OUT> createAggregator(
      String named,
      Combine.CombineFn<? super IN, INTER, OUT> combineFn) {
    @SuppressWarnings("unchecked")
    Aggregator<IN, OUT> aggregator = (Aggregator<IN, OUT>) aggregators.get(named);
    if (aggregator == null) {
      @SuppressWarnings("unchecked")
      NamedAggregators.CombineFunctionState<IN, INTER, OUT> state =
          new NamedAggregators.CombineFunctionState<>(
              (Combine.CombineFn<IN, INTER, OUT>) combineFn,
              (Coder<IN>) getCoder(combineFn),
              this);
      accum.add(new NamedAggregators(named, state));
      aggregator = new SparkAggregator<>(named, state);
      aggregators.put(named, aggregator);
    }
    return aggregator;
  }

  /**
   * Returns the coder registry, creating it with the standard coders on first use. The field is
   * transient, so the registry is rebuilt after this context is deserialized.
   */
  public CoderRegistry getCoderRegistry() {
    if (coderRegistry == null) {
      coderRegistry = new CoderRegistry();
      coderRegistry.registerStandardCoders();
    }
    return coderRegistry;
  }

  /**
   * Returns the default coder for the value type of a supported Sum/Min/Max combiner.
   *
   * @throws IllegalArgumentException if the combiner's class is not supported
   * @throws IllegalStateException if the registry cannot provide a coder for the value type
   */
  private Coder<?> getCoder(Combine.CombineFn<?, ?, ?> combiner) {
    Class<?> valueType = COMBINER_VALUE_TYPES.get(combiner.getClass());
    if (valueType == null) {
      throw new IllegalArgumentException("unsupported combiner in Aggregator: "
          + combiner.getClass().getName());
    }
    try {
      return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(valueType));
    } catch (CannotProvideCoderException e) {
      throw new IllegalStateException("Could not determine default coder for combiner", e);
    }
  }

  /**
   * An {@link Aggregator} that forwards every added value to its {@link NamedAggregators.State}.
   *
   * @param <IN> Type of element fed in to aggregator.
   * @param <OUT> Type of the aggregated output.
   */
  private static class SparkAggregator<IN, OUT> implements Aggregator<IN, OUT>, Serializable {
    private final String name;
    private final NamedAggregators.State<IN, ?, OUT> state;

    SparkAggregator(String name, NamedAggregators.State<IN, ?, OUT> state) {
      this.name = name;
      this.state = state;
    }

    @Override
    public String getName() {
      return name;
    }

    @Override
    public void addValue(IN elem) {
      state.update(elem);
    }

    @Override
    public Combine.CombineFn<IN, ?, OUT> getCombineFn() {
      return state.getCombineFn();
    }
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformEvaluator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformEvaluator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformEvaluator.java new file mode 100644 index 0000000..d8481bf --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformEvaluator.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.beam.runners.spark.translation;

import java.io.Serializable;

import com.google.cloud.dataflow.sdk.transforms.PTransform;

/**
 * Translates one kind of {@link PTransform} into Spark operations within an
 * {@link EvaluationContext}.
 *
 * @param <PT> the type of transform this evaluator handles
 */
public interface TransformEvaluator<PT extends PTransform<?, ?>> extends Serializable {
  /**
   * Evaluates {@code transform}, reading its inputs from and recording its outputs in
   * {@code context}.
   *
   * @param transform the transform to evaluate
   * @param context the evaluation context of the pipeline being translated
   */
  void evaluate(PT transform, EvaluationContext context);
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java new file mode 100644 index 0000000..0bd047c --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java @@ -0,0 +1,808 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark.translation; + +import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputDirectory; +import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputFilePrefix; +import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputFileTemplate; +import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.replaceShardCount; + +import java.io.IOException; +import java.lang.reflect.Field; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import com.google.api.client.util.Maps; +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.io.AvroIO; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows; +import 
com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.cloud.dataflow.sdk.util.AssignWindowsDoFn; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.PCollectionTuple; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.ImmutableMap; + +import org.apache.avro.mapred.AvroKey; +import org.apache.avro.mapreduce.AvroJob; +import org.apache.avro.mapreduce.AvroKeyInputFormat; +import org.apache.beam.runners.spark.coders.CoderHelpers; +import org.apache.beam.runners.spark.io.hadoop.HadoopIO; +import org.apache.beam.runners.spark.io.hadoop.ShardNameTemplateHelper; +import org.apache.beam.runners.spark.io.hadoop.TemplatedAvroKeyOutputFormat; +import org.apache.beam.runners.spark.io.hadoop.TemplatedTextOutputFormat; +import org.apache.beam.runners.spark.util.BroadcastHelper; +import org.apache.beam.runners.spark.util.ByteArray; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.NullWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaRDDLike; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import scala.Tuple2; + +/** + * Supports translation between a DataFlow transform, and Spark's operations on RDDs. 
+ */ +public final class TransformTranslator { + + private TransformTranslator() { + } + + public static class FieldGetter { + private final Map<String, Field> fields; + + public FieldGetter(Class<?> clazz) { + this.fields = Maps.newHashMap(); + for (Field f : clazz.getDeclaredFields()) { + f.setAccessible(true); + this.fields.put(f.getName(), f); + } + } + + public <T> T get(String fieldname, Object value) { + try { + @SuppressWarnings("unchecked") + T fieldValue = (T) fields.get(fieldname).get(value); + return fieldValue; + } catch (IllegalAccessException e) { + throw new IllegalStateException(e); + } + } + } + + private static <T> TransformEvaluator<Flatten.FlattenPCollectionList<T>> flattenPColl() { + return new TransformEvaluator<Flatten.FlattenPCollectionList<T>>() { + @SuppressWarnings("unchecked") + @Override + public void evaluate(Flatten.FlattenPCollectionList<T> transform, EvaluationContext context) { + PCollectionList<T> pcs = context.getInput(transform); + JavaRDD<WindowedValue<T>>[] rdds = new JavaRDD[pcs.size()]; + for (int i = 0; i < rdds.length; i++) { + rdds[i] = (JavaRDD<WindowedValue<T>>) context.getRDD(pcs.get(i)); + } + JavaRDD<WindowedValue<T>> rdd = context.getSparkContext().union(rdds); + context.setOutputRDD(transform, rdd); + } + }; + } + + private static <K, V> TransformEvaluator<GroupByKey.GroupByKeyOnly<K, V>> gbk() { + return new TransformEvaluator<GroupByKey.GroupByKeyOnly<K, V>>() { + @Override + public void evaluate(GroupByKey.GroupByKeyOnly<K, V> transform, EvaluationContext context) { + @SuppressWarnings("unchecked") + JavaRDDLike<WindowedValue<KV<K, V>>, ?> inRDD = + (JavaRDDLike<WindowedValue<KV<K, V>>, ?>) context.getInputRDD(transform); + @SuppressWarnings("unchecked") + KvCoder<K, V> coder = (KvCoder<K, V>) context.getInput(transform).getCoder(); + Coder<K> keyCoder = coder.getKeyCoder(); + Coder<V> valueCoder = coder.getValueCoder(); + + // Use coders to convert objects in the PCollection to byte arrays, so they + // can be 
transferred over the network for the shuffle. + JavaRDDLike<WindowedValue<KV<K, Iterable<V>>>, ?> outRDD = fromPair( + toPair(inRDD.map(WindowingHelpers.<KV<K, V>>unwindowFunction())) + .mapToPair(CoderHelpers.toByteFunction(keyCoder, valueCoder)) + .groupByKey() + .mapToPair(CoderHelpers.fromByteFunctionIterable(keyCoder, valueCoder))) + // empty windows are OK here, see GroupByKey#evaluateHelper in the SDK + .map(WindowingHelpers.<KV<K, Iterable<V>>>windowFunction()); + context.setOutputRDD(transform, outRDD); + } + }; + } + + private static final FieldGetter GROUPED_FG = new FieldGetter(Combine.GroupedValues.class); + + private static <K, VI, VO> TransformEvaluator<Combine.GroupedValues<K, VI, VO>> grouped() { + return new TransformEvaluator<Combine.GroupedValues<K, VI, VO>>() { + @Override + public void evaluate(Combine.GroupedValues<K, VI, VO> transform, EvaluationContext context) { + Combine.KeyedCombineFn<K, VI, ?, VO> keyed = GROUPED_FG.get("fn", transform); + @SuppressWarnings("unchecked") + JavaRDDLike<WindowedValue<KV<K, Iterable<VI>>>, ?> inRDD = + (JavaRDDLike<WindowedValue<KV<K, Iterable<VI>>>, ?>) context.getInputRDD(transform); + context.setOutputRDD(transform, + inRDD.map(new KVFunction<>(keyed))); + } + }; + } + + private static final FieldGetter COMBINE_GLOBALLY_FG = new FieldGetter(Combine.Globally.class); + + private static <I, A, O> TransformEvaluator<Combine.Globally<I, O>> combineGlobally() { + return new TransformEvaluator<Combine.Globally<I, O>>() { + + @Override + public void evaluate(Combine.Globally<I, O> transform, EvaluationContext context) { + final Combine.CombineFn<I, A, O> globally = COMBINE_GLOBALLY_FG.get("fn", transform); + + @SuppressWarnings("unchecked") + JavaRDDLike<WindowedValue<I>, ?> inRdd = + (JavaRDDLike<WindowedValue<I>, ?>) context.getInputRDD(transform); + + final Coder<I> iCoder = context.getInput(transform).getCoder(); + final Coder<A> aCoder; + try { + aCoder = globally.getAccumulatorCoder( + 
context.getPipeline().getCoderRegistry(), iCoder); + } catch (CannotProvideCoderException e) { + throw new IllegalStateException("Could not determine coder for accumulator", e); + } + + // Use coders to convert objects in the PCollection to byte arrays, so they + // can be transferred over the network for the shuffle. + JavaRDD<byte[]> inRddBytes = inRdd + .map(WindowingHelpers.<I>unwindowFunction()) + .map(CoderHelpers.toByteFunction(iCoder)); + + /*A*/ byte[] acc = inRddBytes.aggregate( + CoderHelpers.toByteArray(globally.createAccumulator(), aCoder), + new Function2</*A*/ byte[], /*I*/ byte[], /*A*/ byte[]>() { + @Override + public /*A*/ byte[] call(/*A*/ byte[] ab, /*I*/ byte[] ib) throws Exception { + A a = CoderHelpers.fromByteArray(ab, aCoder); + I i = CoderHelpers.fromByteArray(ib, iCoder); + return CoderHelpers.toByteArray(globally.addInput(a, i), aCoder); + } + }, + new Function2</*A*/ byte[], /*A*/ byte[], /*A*/ byte[]>() { + @Override + public /*A*/ byte[] call(/*A*/ byte[] a1b, /*A*/ byte[] a2b) throws Exception { + A a1 = CoderHelpers.fromByteArray(a1b, aCoder); + A a2 = CoderHelpers.fromByteArray(a2b, aCoder); + // don't use Guava's ImmutableList.of as values may be null + List<A> accumulators = Collections.unmodifiableList(Arrays.asList(a1, a2)); + A merged = globally.mergeAccumulators(accumulators); + return CoderHelpers.toByteArray(merged, aCoder); + } + } + ); + O output = globally.extractOutput(CoderHelpers.fromByteArray(acc, aCoder)); + + Coder<O> coder = context.getOutput(transform).getCoder(); + JavaRDD<byte[]> outRdd = context.getSparkContext().parallelize( + // don't use Guava's ImmutableList.of as output may be null + CoderHelpers.toByteArrays(Collections.singleton(output), coder)); + context.setOutputRDD(transform, outRdd.map(CoderHelpers.fromByteFunction(coder)) + .map(WindowingHelpers.<O>windowFunction())); + } + }; + } + + private static final FieldGetter COMBINE_PERKEY_FG = new FieldGetter(Combine.PerKey.class); + + private static <K, 
VI, VA, VO> TransformEvaluator<Combine.PerKey<K, VI, VO>> combinePerKey() { + return new TransformEvaluator<Combine.PerKey<K, VI, VO>>() { + @Override + public void evaluate(Combine.PerKey<K, VI, VO> transform, EvaluationContext context) { + final Combine.KeyedCombineFn<K, VI, VA, VO> keyed = + COMBINE_PERKEY_FG.get("fn", transform); + @SuppressWarnings("unchecked") + JavaRDDLike<WindowedValue<KV<K, VI>>, ?> inRdd = + (JavaRDDLike<WindowedValue<KV<K, VI>>, ?>) context.getInputRDD(transform); + + @SuppressWarnings("unchecked") + KvCoder<K, VI> inputCoder = (KvCoder<K, VI>) context.getInput(transform).getCoder(); + Coder<K> keyCoder = inputCoder.getKeyCoder(); + Coder<VI> viCoder = inputCoder.getValueCoder(); + Coder<VA> vaCoder; + try { + vaCoder = keyed.getAccumulatorCoder( + context.getPipeline().getCoderRegistry(), keyCoder, viCoder); + } catch (CannotProvideCoderException e) { + throw new IllegalStateException("Could not determine coder for accumulator", e); + } + Coder<KV<K, VI>> kviCoder = KvCoder.of(keyCoder, viCoder); + Coder<KV<K, VA>> kvaCoder = KvCoder.of(keyCoder, vaCoder); + + // We need to duplicate K as both the key of the JavaPairRDD as well as inside the value, + // since the functions passed to combineByKey don't receive the associated key of each + // value, and we need to map back into methods in Combine.KeyedCombineFn, which each + // require the key in addition to the VI's and VA's being merged/accumulated. Once Spark + // provides a way to include keys in the arguments of combine/merge functions, we won't + // need to duplicate the keys anymore. 
+ + // Key has to bw windowed in order to group by window as well + JavaPairRDD<WindowedValue<K>, WindowedValue<KV<K, VI>>> inRddDuplicatedKeyPair = + inRdd.mapToPair( + new PairFunction<WindowedValue<KV<K, VI>>, WindowedValue<K>, + WindowedValue<KV<K, VI>>>() { + @Override + public Tuple2<WindowedValue<K>, + WindowedValue<KV<K, VI>>> call(WindowedValue<KV<K, VI>> kv) { + WindowedValue<K> wk = WindowedValue.of(kv.getValue().getKey(), + kv.getTimestamp(), kv.getWindows(), kv.getPane()); + return new Tuple2<>(wk, kv); + } + }); + //-- windowed coders + final WindowedValue.FullWindowedValueCoder<K> wkCoder = + WindowedValue.FullWindowedValueCoder.of(keyCoder, + context.getInput(transform).getWindowingStrategy().getWindowFn().windowCoder()); + final WindowedValue.FullWindowedValueCoder<KV<K, VI>> wkviCoder = + WindowedValue.FullWindowedValueCoder.of(kviCoder, + context.getInput(transform).getWindowingStrategy().getWindowFn().windowCoder()); + final WindowedValue.FullWindowedValueCoder<KV<K, VA>> wkvaCoder = + WindowedValue.FullWindowedValueCoder.of(kvaCoder, + context.getInput(transform).getWindowingStrategy().getWindowFn().windowCoder()); + + // Use coders to convert objects in the PCollection to byte arrays, so they + // can be transferred over the network for the shuffle. + JavaPairRDD<ByteArray, byte[]> inRddDuplicatedKeyPairBytes = inRddDuplicatedKeyPair + .mapToPair(CoderHelpers.toByteFunction(wkCoder, wkviCoder)); + + // The output of combineByKey will be "VA" (accumulator) types rather than "VO" (final + // output types) since Combine.CombineFn only provides ways to merge VAs, and no way + // to merge VOs. 
+ JavaPairRDD</*K*/ ByteArray, /*KV<K, VA>*/ byte[]> accumulatedBytes = + inRddDuplicatedKeyPairBytes.combineByKey( + new Function</*KV<K, VI>*/ byte[], /*KV<K, VA>*/ byte[]>() { + @Override + public /*KV<K, VA>*/ byte[] call(/*KV<K, VI>*/ byte[] input) { + WindowedValue<KV<K, VI>> wkvi = CoderHelpers.fromByteArray(input, wkviCoder); + VA va = keyed.createAccumulator(wkvi.getValue().getKey()); + va = keyed.addInput(wkvi.getValue().getKey(), va, wkvi.getValue().getValue()); + WindowedValue<KV<K, VA>> wkva = + WindowedValue.of(KV.of(wkvi.getValue().getKey(), va), wkvi.getTimestamp(), + wkvi.getWindows(), wkvi.getPane()); + return CoderHelpers.toByteArray(wkva, wkvaCoder); + } + }, + new Function2</*KV<K, VA>*/ byte[], /*KV<K, VI>*/ byte[], /*KV<K, VA>*/ byte[]>() { + @Override + public /*KV<K, VA>*/ byte[] call(/*KV<K, VA>*/ byte[] acc, + /*KV<K, VI>*/ byte[] input) { + WindowedValue<KV<K, VA>> wkva = CoderHelpers.fromByteArray(acc, wkvaCoder); + WindowedValue<KV<K, VI>> wkvi = CoderHelpers.fromByteArray(input, wkviCoder); + VA va = keyed.addInput(wkva.getValue().getKey(), wkva.getValue().getValue(), + wkvi.getValue().getValue()); + wkva = WindowedValue.of(KV.of(wkva.getValue().getKey(), va), wkva.getTimestamp(), + wkva.getWindows(), wkva.getPane()); + return CoderHelpers.toByteArray(wkva, wkvaCoder); + } + }, + new Function2</*KV<K, VA>*/ byte[], /*KV<K, VA>*/ byte[], /*KV<K, VA>*/ byte[]>() { + @Override + public /*KV<K, VA>*/ byte[] call(/*KV<K, VA>*/ byte[] acc1, + /*KV<K, VA>*/ byte[] acc2) { + WindowedValue<KV<K, VA>> wkva1 = CoderHelpers.fromByteArray(acc1, wkvaCoder); + WindowedValue<KV<K, VA>> wkva2 = CoderHelpers.fromByteArray(acc2, wkvaCoder); + VA va = keyed.mergeAccumulators(wkva1.getValue().getKey(), + // don't use Guava's ImmutableList.of as values may be null + Collections.unmodifiableList(Arrays.asList(wkva1.getValue().getValue(), + wkva2.getValue().getValue()))); + WindowedValue<KV<K, VA>> wkva = WindowedValue.of(KV.of(wkva1.getValue().getKey(), + 
va), wkva1.getTimestamp(), wkva1.getWindows(), wkva1.getPane()); + return CoderHelpers.toByteArray(wkva, wkvaCoder); + } + }); + + JavaPairRDD<WindowedValue<K>, WindowedValue<VO>> extracted = accumulatedBytes + .mapToPair(CoderHelpers.fromByteFunction(wkCoder, wkvaCoder)) + .mapValues( + new Function<WindowedValue<KV<K, VA>>, WindowedValue<VO>>() { + @Override + public WindowedValue<VO> call(WindowedValue<KV<K, VA>> acc) { + return WindowedValue.of(keyed.extractOutput(acc.getValue().getKey(), + acc.getValue().getValue()), acc.getTimestamp(), + acc.getWindows(), acc.getPane()); + } + }); + + context.setOutputRDD(transform, + fromPair(extracted) + .map(new Function<KV<WindowedValue<K>, WindowedValue<VO>>, WindowedValue<KV<K, VO>>>() { + @Override + public WindowedValue<KV<K, VO>> call(KV<WindowedValue<K>, WindowedValue<VO>> kwvo) + throws Exception { + WindowedValue<VO> wvo = kwvo.getValue(); + KV<K, VO> kvo = KV.of(kwvo.getKey().getValue(), wvo.getValue()); + return WindowedValue.of(kvo, wvo.getTimestamp(), wvo.getWindows(), wvo.getPane()); + } + })); + } + }; + } + + private static final class KVFunction<K, VI, VO> + implements Function<WindowedValue<KV<K, Iterable<VI>>>, WindowedValue<KV<K, VO>>> { + private final Combine.KeyedCombineFn<K, VI, ?, VO> keyed; + + KVFunction(Combine.KeyedCombineFn<K, VI, ?, VO> keyed) { + this.keyed = keyed; + } + + @Override + public WindowedValue<KV<K, VO>> call(WindowedValue<KV<K, Iterable<VI>>> windowedKv) + throws Exception { + KV<K, Iterable<VI>> kv = windowedKv.getValue(); + return WindowedValue.of(KV.of(kv.getKey(), keyed.apply(kv.getKey(), kv.getValue())), + windowedKv.getTimestamp(), windowedKv.getWindows(), windowedKv.getPane()); + } + } + + private static <K, V> JavaPairRDD<K, V> toPair(JavaRDDLike<KV<K, V>, ?> rdd) { + return rdd.mapToPair(new PairFunction<KV<K, V>, K, V>() { + @Override + public Tuple2<K, V> call(KV<K, V> kv) { + return new Tuple2<>(kv.getKey(), kv.getValue()); + } + }); + } + + private static <K, V> 
JavaRDDLike<KV<K, V>, ?> fromPair(JavaPairRDD<K, V> rdd) { + return rdd.map(new Function<Tuple2<K, V>, KV<K, V>>() { + @Override + public KV<K, V> call(Tuple2<K, V> t2) { + return KV.of(t2._1(), t2._2()); + } + }); + } + + private static <I, O> TransformEvaluator<ParDo.Bound<I, O>> parDo() { + return new TransformEvaluator<ParDo.Bound<I, O>>() { + @Override + public void evaluate(ParDo.Bound<I, O> transform, EvaluationContext context) { + DoFnFunction<I, O> dofn = + new DoFnFunction<>(transform.getFn(), + context.getRuntimeContext(), + getSideInputs(transform.getSideInputs(), context)); + @SuppressWarnings("unchecked") + JavaRDDLike<WindowedValue<I>, ?> inRDD = + (JavaRDDLike<WindowedValue<I>, ?>) context.getInputRDD(transform); + context.setOutputRDD(transform, inRDD.mapPartitions(dofn)); + } + }; + } + + private static final FieldGetter MULTIDO_FG = new FieldGetter(ParDo.BoundMulti.class); + + private static <I, O> TransformEvaluator<ParDo.BoundMulti<I, O>> multiDo() { + return new TransformEvaluator<ParDo.BoundMulti<I, O>>() { + @Override + public void evaluate(ParDo.BoundMulti<I, O> transform, EvaluationContext context) { + TupleTag<O> mainOutputTag = MULTIDO_FG.get("mainOutputTag", transform); + MultiDoFnFunction<I, O> multifn = new MultiDoFnFunction<>( + transform.getFn(), + context.getRuntimeContext(), + mainOutputTag, + getSideInputs(transform.getSideInputs(), context)); + + @SuppressWarnings("unchecked") + JavaRDDLike<WindowedValue<I>, ?> inRDD = + (JavaRDDLike<WindowedValue<I>, ?>) context.getInputRDD(transform); + JavaPairRDD<TupleTag<?>, WindowedValue<?>> all = inRDD + .mapPartitionsToPair(multifn) + .cache(); + + PCollectionTuple pct = context.getOutput(transform); + for (Map.Entry<TupleTag<?>, PCollection<?>> e : pct.getAll().entrySet()) { + @SuppressWarnings("unchecked") + JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered = + all.filter(new TupleTagFilter(e.getKey())); + @SuppressWarnings("unchecked") + // Object is the best we can do since 
different outputs can have different tags + JavaRDD<WindowedValue<Object>> values = + (JavaRDD<WindowedValue<Object>>) (JavaRDD<?>) filtered.values(); + context.setRDD(e.getValue(), values); + } + } + }; + } + + + private static <T> TransformEvaluator<TextIO.Read.Bound<T>> readText() { + return new TransformEvaluator<TextIO.Read.Bound<T>>() { + @Override + public void evaluate(TextIO.Read.Bound<T> transform, EvaluationContext context) { + String pattern = transform.getFilepattern(); + JavaRDD<WindowedValue<String>> rdd = context.getSparkContext().textFile(pattern) + .map(WindowingHelpers.<String>windowFunction()); + context.setOutputRDD(transform, rdd); + } + }; + } + + private static <T> TransformEvaluator<TextIO.Write.Bound<T>> writeText() { + return new TransformEvaluator<TextIO.Write.Bound<T>>() { + @Override + public void evaluate(TextIO.Write.Bound<T> transform, EvaluationContext context) { + @SuppressWarnings("unchecked") + JavaPairRDD<T, Void> last = + ((JavaRDDLike<WindowedValue<T>, ?>) context.getInputRDD(transform)) + .map(WindowingHelpers.<T>unwindowFunction()) + .mapToPair(new PairFunction<T, T, + Void>() { + @Override + public Tuple2<T, Void> call(T t) throws Exception { + return new Tuple2<>(t, null); + } + }); + ShardTemplateInformation shardTemplateInfo = + new ShardTemplateInformation(transform.getNumShards(), + transform.getShardTemplate(), transform.getFilenamePrefix(), + transform.getFilenameSuffix()); + writeHadoopFile(last, new Configuration(), shardTemplateInfo, Text.class, + NullWritable.class, TemplatedTextOutputFormat.class); + } + }; + } + + private static <T> TransformEvaluator<AvroIO.Read.Bound<T>> readAvro() { + return new TransformEvaluator<AvroIO.Read.Bound<T>>() { + @Override + public void evaluate(AvroIO.Read.Bound<T> transform, EvaluationContext context) { + String pattern = transform.getFilepattern(); + JavaSparkContext jsc = context.getSparkContext(); + @SuppressWarnings("unchecked") + JavaRDD<AvroKey<T>> avroFile = 
(JavaRDD<AvroKey<T>>) (JavaRDD<?>) + jsc.newAPIHadoopFile(pattern, + AvroKeyInputFormat.class, + AvroKey.class, NullWritable.class, + new Configuration()).keys(); + JavaRDD<WindowedValue<T>> rdd = avroFile.map( + new Function<AvroKey<T>, T>() { + @Override + public T call(AvroKey<T> key) { + return key.datum(); + } + }).map(WindowingHelpers.<T>windowFunction()); + context.setOutputRDD(transform, rdd); + } + }; + } + + private static <T> TransformEvaluator<AvroIO.Write.Bound<T>> writeAvro() { + return new TransformEvaluator<AvroIO.Write.Bound<T>>() { + @Override + public void evaluate(AvroIO.Write.Bound<T> transform, EvaluationContext context) { + Job job; + try { + job = Job.getInstance(); + } catch (IOException e) { + throw new IllegalStateException(e); + } + AvroJob.setOutputKeySchema(job, transform.getSchema()); + @SuppressWarnings("unchecked") + JavaPairRDD<AvroKey<T>, NullWritable> last = + ((JavaRDDLike<WindowedValue<T>, ?>) context.getInputRDD(transform)) + .map(WindowingHelpers.<T>unwindowFunction()) + .mapToPair(new PairFunction<T, AvroKey<T>, NullWritable>() { + @Override + public Tuple2<AvroKey<T>, NullWritable> call(T t) throws Exception { + return new Tuple2<>(new AvroKey<>(t), NullWritable.get()); + } + }); + ShardTemplateInformation shardTemplateInfo = + new ShardTemplateInformation(transform.getNumShards(), + transform.getShardTemplate(), transform.getFilenamePrefix(), + transform.getFilenameSuffix()); + writeHadoopFile(last, job.getConfiguration(), shardTemplateInfo, + AvroKey.class, NullWritable.class, TemplatedAvroKeyOutputFormat.class); + } + }; + } + + private static <K, V> TransformEvaluator<HadoopIO.Read.Bound<K, V>> readHadoop() { + return new TransformEvaluator<HadoopIO.Read.Bound<K, V>>() { + @Override + public void evaluate(HadoopIO.Read.Bound<K, V> transform, EvaluationContext context) { + String pattern = transform.getFilepattern(); + JavaSparkContext jsc = context.getSparkContext(); + @SuppressWarnings ("unchecked") + JavaPairRDD<K, V> 
file = jsc.newAPIHadoopFile(pattern, + transform.getFormatClass(), + transform.getKeyClass(), transform.getValueClass(), + new Configuration()); + JavaRDD<WindowedValue<KV<K, V>>> rdd = + file.map(new Function<Tuple2<K, V>, KV<K, V>>() { + @Override + public KV<K, V> call(Tuple2<K, V> t2) throws Exception { + return KV.of(t2._1(), t2._2()); + } + }).map(WindowingHelpers.<KV<K, V>>windowFunction()); + context.setOutputRDD(transform, rdd); + } + }; + } + + private static <K, V> TransformEvaluator<HadoopIO.Write.Bound<K, V>> writeHadoop() { + return new TransformEvaluator<HadoopIO.Write.Bound<K, V>>() { + @Override + public void evaluate(HadoopIO.Write.Bound<K, V> transform, EvaluationContext context) { + @SuppressWarnings("unchecked") + JavaPairRDD<K, V> last = ((JavaRDDLike<WindowedValue<KV<K, V>>, ?>) context + .getInputRDD(transform)) + .map(WindowingHelpers.<KV<K, V>>unwindowFunction()) + .mapToPair(new PairFunction<KV<K, V>, K, V>() { + @Override + public Tuple2<K, V> call(KV<K, V> t) throws Exception { + return new Tuple2<>(t.getKey(), t.getValue()); + } + }); + ShardTemplateInformation shardTemplateInfo = + new ShardTemplateInformation(transform.getNumShards(), + transform.getShardTemplate(), transform.getFilenamePrefix(), + transform.getFilenameSuffix()); + Configuration conf = new Configuration(); + for (Map.Entry<String, String> e : transform.getConfigurationProperties().entrySet()) { + conf.set(e.getKey(), e.getValue()); + } + writeHadoopFile(last, conf, shardTemplateInfo, + transform.getKeyClass(), transform.getValueClass(), transform.getFormatClass()); + } + }; + } + + private static final class ShardTemplateInformation { + private final int numShards; + private final String shardTemplate; + private final String filenamePrefix; + private final String filenameSuffix; + + private ShardTemplateInformation(int numShards, String shardTemplate, String + filenamePrefix, String filenameSuffix) { + this.numShards = numShards; + this.shardTemplate = shardTemplate; 
+ this.filenamePrefix = filenamePrefix; + this.filenameSuffix = filenameSuffix; + } + + int getNumShards() { + return numShards; + } + + String getShardTemplate() { + return shardTemplate; + } + + String getFilenamePrefix() { + return filenamePrefix; + } + + String getFilenameSuffix() { + return filenameSuffix; + } + } + + private static <K, V> void writeHadoopFile(JavaPairRDD<K, V> rdd, Configuration conf, + ShardTemplateInformation shardTemplateInfo, Class<?> keyClass, Class<?> valueClass, + Class<? extends FileOutputFormat> formatClass) { + int numShards = shardTemplateInfo.getNumShards(); + String shardTemplate = shardTemplateInfo.getShardTemplate(); + String filenamePrefix = shardTemplateInfo.getFilenamePrefix(); + String filenameSuffix = shardTemplateInfo.getFilenameSuffix(); + if (numShards != 0) { + // number of shards was set explicitly, so repartition + rdd = rdd.repartition(numShards); + } + int actualNumShards = rdd.partitions().size(); + String template = replaceShardCount(shardTemplate, actualNumShards); + String outputDir = getOutputDirectory(filenamePrefix, template); + String filePrefix = getOutputFilePrefix(filenamePrefix, template); + String fileTemplate = getOutputFileTemplate(filenamePrefix, template); + + conf.set(ShardNameTemplateHelper.OUTPUT_FILE_PREFIX, filePrefix); + conf.set(ShardNameTemplateHelper.OUTPUT_FILE_TEMPLATE, fileTemplate); + conf.set(ShardNameTemplateHelper.OUTPUT_FILE_SUFFIX, filenameSuffix); + rdd.saveAsNewAPIHadoopFile(outputDir, keyClass, valueClass, formatClass, conf); + } + + private static final FieldGetter WINDOW_FG = new FieldGetter(Window.Bound.class); + + private static <T, W extends BoundedWindow> TransformEvaluator<Window.Bound<T>> window() { + return new TransformEvaluator<Window.Bound<T>>() { + @Override + public void evaluate(Window.Bound<T> transform, EvaluationContext context) { + @SuppressWarnings("unchecked") + JavaRDDLike<WindowedValue<T>, ?> inRDD = + (JavaRDDLike<WindowedValue<T>, ?>) 
context.getInputRDD(transform); + WindowFn<? super T, W> windowFn = WINDOW_FG.get("windowFn", transform); + if (windowFn instanceof GlobalWindows) { + context.setOutputRDD(transform, inRDD); + } else { + @SuppressWarnings("unchecked") + DoFn<T, T> addWindowsDoFn = new AssignWindowsDoFn<>(windowFn); + DoFnFunction<T, T> dofn = + new DoFnFunction<>(addWindowsDoFn, context.getRuntimeContext(), null); + context.setOutputRDD(transform, inRDD.mapPartitions(dofn)); + } + } + }; + } + + private static <T> TransformEvaluator<Create.Values<T>> create() { + return new TransformEvaluator<Create.Values<T>>() { + @Override + public void evaluate(Create.Values<T> transform, EvaluationContext context) { + Iterable<T> elems = transform.getElements(); + // Use a coder to convert the objects in the PCollection to byte arrays, so they + // can be transferred over the network. + Coder<T> coder = context.getOutput(transform).getCoder(); + context.setOutputRDDFromValues(transform, elems, coder); + } + }; + } + + private static <T> TransformEvaluator<View.AsSingleton<T>> viewAsSingleton() { + return new TransformEvaluator<View.AsSingleton<T>>() { + @Override + public void evaluate(View.AsSingleton<T> transform, EvaluationContext context) { + Iterable<? extends WindowedValue<?>> iter = + context.getWindowedValues(context.getInput(transform)); + context.setPView(context.getOutput(transform), iter); + } + }; + } + + private static <T> TransformEvaluator<View.AsIterable<T>> viewAsIter() { + return new TransformEvaluator<View.AsIterable<T>>() { + @Override + public void evaluate(View.AsIterable<T> transform, EvaluationContext context) { + Iterable<? 
extends WindowedValue<?>> iter = + context.getWindowedValues(context.getInput(transform)); + context.setPView(context.getOutput(transform), iter); + } + }; + } + + private static <R, W> TransformEvaluator<View.CreatePCollectionView<R, W>> createPCollView() { + return new TransformEvaluator<View.CreatePCollectionView<R, W>>() { + @Override + public void evaluate(View.CreatePCollectionView<R, W> transform, EvaluationContext context) { + Iterable<? extends WindowedValue<?>> iter = + context.getWindowedValues(context.getInput(transform)); + context.setPView(context.getOutput(transform), iter); + } + }; + } + + private static final class TupleTagFilter<V> + implements Function<Tuple2<TupleTag<V>, WindowedValue<?>>, Boolean> { + + private final TupleTag<V> tag; + + private TupleTagFilter(TupleTag<V> tag) { + this.tag = tag; + } + + @Override + public Boolean call(Tuple2<TupleTag<V>, WindowedValue<?>> input) { + return tag.equals(input._1()); + } + } + + private static Map<TupleTag<?>, BroadcastHelper<?>> getSideInputs( + List<PCollectionView<?>> views, + EvaluationContext context) { + if (views == null) { + return ImmutableMap.of(); + } else { + Map<TupleTag<?>, BroadcastHelper<?>> sideInputs = Maps.newHashMap(); + for (PCollectionView<?> view : views) { + Iterable<? extends WindowedValue<?>> collectionView = context.getPCollectionView(view); + Coder<Iterable<WindowedValue<?>>> coderInternal = view.getCoderInternal(); + @SuppressWarnings("unchecked") + BroadcastHelper<?> helper = + BroadcastHelper.create((Iterable<WindowedValue<?>>) collectionView, coderInternal); + //broadcast side inputs + helper.broadcast(context.getSparkContext()); + sideInputs.put(view.getTagInternal(), helper); + } + return sideInputs; + } + } + + private static final Map<Class<? 
extends PTransform>, TransformEvaluator<?>> EVALUATORS = Maps + .newHashMap(); + + static { + EVALUATORS.put(TextIO.Read.Bound.class, readText()); + EVALUATORS.put(TextIO.Write.Bound.class, writeText()); + EVALUATORS.put(AvroIO.Read.Bound.class, readAvro()); + EVALUATORS.put(AvroIO.Write.Bound.class, writeAvro()); + EVALUATORS.put(HadoopIO.Read.Bound.class, readHadoop()); + EVALUATORS.put(HadoopIO.Write.Bound.class, writeHadoop()); + EVALUATORS.put(ParDo.Bound.class, parDo()); + EVALUATORS.put(ParDo.BoundMulti.class, multiDo()); + EVALUATORS.put(GroupByKey.GroupByKeyOnly.class, gbk()); + EVALUATORS.put(Combine.GroupedValues.class, grouped()); + EVALUATORS.put(Combine.Globally.class, combineGlobally()); + EVALUATORS.put(Combine.PerKey.class, combinePerKey()); + EVALUATORS.put(Flatten.FlattenPCollectionList.class, flattenPColl()); + EVALUATORS.put(Create.Values.class, create()); + EVALUATORS.put(View.AsSingleton.class, viewAsSingleton()); + EVALUATORS.put(View.AsIterable.class, viewAsIter()); + EVALUATORS.put(View.CreatePCollectionView.class, createPCollView()); + EVALUATORS.put(Window.Bound.class, window()); + } + + public static <PT extends PTransform<?, ?>> TransformEvaluator<PT> + getTransformEvaluator(Class<PT> clazz) { + @SuppressWarnings("unchecked") + TransformEvaluator<PT> transform = (TransformEvaluator<PT>) EVALUATORS.get(clazz); + if (transform == null) { + throw new IllegalStateException("No TransformEvaluator registered for " + clazz); + } + return transform; + } + + /** + * Translator matches Dataflow transformation with the appropriate evaluator. + */ + public static class Translator implements SparkPipelineTranslator { + + @Override + public boolean hasTranslation(Class<? extends PTransform<?, ?>> clazz) { + return EVALUATORS.containsKey(clazz); + } + + @Override + public <PT extends PTransform<?, ?>> TransformEvaluator<PT> translate(Class<PT> clazz) { + return getTransformEvaluator(clazz); + } + } +}
