http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/EvaluationContext.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/EvaluationContext.java new file mode 100644 index 0000000..836987f --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/EvaluationContext.java @@ -0,0 +1,284 @@ +/* + * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved. + * + * Cloudera, Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"). You may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for + * the specific language governing permissions and limitations under the + * License. 
+ */

package org.apache.beam.runners.spark;

import static com.google.common.base.Preconditions.checkArgument;

import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.runners.AggregatorRetrievalException;
import com.google.cloud.dataflow.sdk.runners.AggregatorValues;
import com.google.cloud.dataflow.sdk.transforms.Aggregator;
import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform;
import com.google.cloud.dataflow.sdk.transforms.PTransform;
import com.google.cloud.dataflow.sdk.util.WindowedValue;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.dataflow.sdk.values.PCollectionView;
import com.google.cloud.dataflow.sdk.values.PInput;
import com.google.cloud.dataflow.sdk.values.POutput;
import com.google.cloud.dataflow.sdk.values.PValue;
import com.google.common.base.Function;
import com.google.common.collect.Iterables;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.spark.api.java.JavaRDDLike;
import org.apache.spark.api.java.JavaSparkContext;


/**
 * Holds the state used while translating and evaluating a pipeline on Spark: the Spark context,
 * the RDD (or not-yet-parallelized values) produced for each {@link PValue}, the values
 * registered for side-input views, and the transform currently being evaluated. The same object
 * also serves as the pipeline's {@link EvaluationResult}.
 */
public class EvaluationContext implements EvaluationResult {
  private final JavaSparkContext jsc;
  private final Pipeline pipeline;
  private final SparkRuntimeContext runtime;
  /** RDD or deferred values produced for each PValue, kept in registration order. */
  private final Map<PValue, RDDHolder<?>> pcollections = new LinkedHashMap<>();
  /** Holders registered via setRDD that no downstream transform has read yet. */
  private final Set<RDDHolder<?>> leafRdds = new LinkedHashSet<>();
  /** PValues whose RDD has been read at least once; any further read caches the RDD. */
  private final Set<PValue> multireads = new LinkedHashSet<>();
  /** Memoized results of {@link #get(PValue)}. */
  private final Map<PValue, Object> pobjects = new LinkedHashMap<>();
  /** Values registered for side-input views via setPView. */
  private final Map<PValue, Iterable<? extends WindowedValue<?>>> pview = new LinkedHashMap<>();
  protected AppliedPTransform<?, ?, ?> currentTransform;

  public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline) {
    this.jsc = jsc;
    this.pipeline = pipeline;
    this.runtime = new SparkRuntimeContext(jsc, pipeline);
  }

  /**
   * Holds an RDD or values for deferred conversion to an RDD if needed. PCollections are
   * sometimes created from a collection of objects (using RDD parallelize) and then
   * only used to create View objects; in which case they do not need to be
   * converted to bytes since they are not transferred across the network until they are
   * broadcast.
   */
  private class RDDHolder<T> {

    private Iterable<T> values;
    private Coder<T> coder;
    private JavaRDDLike<WindowedValue<T>, ?> rdd;

    RDDHolder(Iterable<T> values, Coder<T> coder) {
      this.values = values;
      this.coder = coder;
    }

    RDDHolder(JavaRDDLike<WindowedValue<T>, ?> rdd) {
      this.rdd = rdd;
    }

    /**
     * Returns the RDD. On first use, when only values are held, the RDD is built from them:
     * each value is wrapped with {@link WindowedValue#valueInEmptyWindows}, encoded with a
     * value-only windowed coder, parallelized, and decoded again by a map over the new RDD.
     */
    JavaRDDLike<WindowedValue<T>, ?> getRDD() {
      if (rdd == null) {
        Iterable<WindowedValue<T>> windowedValues = Iterables.transform(values,
            new Function<T, WindowedValue<T>>() {
              @Override
              public WindowedValue<T> apply(T t) {
                // TODO: this is wrong if T is a TimestampedValue
                return WindowedValue.valueInEmptyWindows(t);
              }
            });
        WindowedValue.ValueOnlyWindowedValueCoder<T> windowCoder =
            WindowedValue.getValueOnlyCoder(coder);
        rdd = jsc.parallelize(CoderHelpers.toByteArrays(windowedValues, windowCoder))
            .map(CoderHelpers.fromByteFunction(windowCoder));
      }
      return rdd;
    }

    /**
     * Returns the held values. On first use, when only an RDD is held, the elements are
     * unwindowed, encoded with the PCollection's coder, and collected as bytes. The returned
     * view decodes those bytes lazily each time it is iterated.
     */
    Iterable<T> getValues(PCollection<T> pcollection) {
      if (values == null) {
        coder = pcollection.getCoder();
        JavaRDDLike<byte[], ?> bytesRDD = rdd.map(WindowingHelpers.<T>unwindowFunction())
            .map(CoderHelpers.toByteFunction(coder));
        List<byte[]> clientBytes = bytesRDD.collect();
        values = Iterables.transform(clientBytes, new Function<byte[], T>() {
          @Override
          public T apply(byte[] bytes) {
            return CoderHelpers.fromByteArray(bytes, coder);
          }
        });
      }
      return values;
    }

    /** Returns the PCollection's values, each wrapped in empty windows. */
    Iterable<WindowedValue<T>> getWindowedValues(PCollection<T> pcollection) {
      return Iterables.transform(get(pcollection), new Function<T, WindowedValue<T>>() {
        @Override
        public WindowedValue<T> apply(T t) {
          return WindowedValue.valueInEmptyWindows(t); // TODO: not the right place?
        }
      });
    }
  }

  protected JavaSparkContext getSparkContext() {
    return jsc;
  }

  protected Pipeline getPipeline() {
    return pipeline;
  }

  protected SparkRuntimeContext getRuntimeContext() {
    return runtime;
  }

  protected void setCurrentTransform(AppliedPTransform<?, ?, ?> transform) {
    this.currentTransform = transform;
  }

  protected AppliedPTransform<?, ?, ?> getCurrentTransform() {
    return currentTransform;
  }

  /**
   * Returns the input of the transform currently being evaluated.
   *
   * @throws IllegalArgumentException if {@code transform} is not the current transform
   */
  protected <I extends PInput> I getInput(PTransform<I, ?> transform) {
    checkArgument(currentTransform != null && currentTransform.getTransform() == transform,
        "can only be called with current transform");
    @SuppressWarnings("unchecked")
    I input = (I) currentTransform.getInput();
    return input;
  }

  /**
   * Returns the output of the transform currently being evaluated.
   *
   * @throws IllegalArgumentException if {@code transform} is not the current transform
   */
  protected <O extends POutput> O getOutput(PTransform<?, O> transform) {
    checkArgument(currentTransform != null && currentTransform.getTransform() == transform,
        "can only be called with current transform");
    @SuppressWarnings("unchecked")
    O output = (O) currentTransform.getOutput();
    return output;
  }

  protected <T> void setOutputRDD(PTransform<?, ?> transform,
      JavaRDDLike<WindowedValue<T>, ?> rdd) {
    setRDD((PValue) getOutput(transform), rdd);
  }

  /**
   * Registers in-memory values as the output of {@code transform}. They are parallelized
   * only if an RDD is later requested. Unlike setRDD, the holder is not added to leafRdds,
   * so computeOutputs() does not force it.
   */
  protected <T> void setOutputRDDFromValues(PTransform<?, ?> transform, Iterable<T> values,
      Coder<T> coder) {
    pcollections.put((PValue) getOutput(transform), new RDDHolder<>(values, coder));
  }

  void setPView(PValue view, Iterable<? extends WindowedValue<?>> value) {
    pview.put(view, value);
  }

  protected boolean hasOutputRDD(PTransform<? extends PInput, ?> transform) {
    PValue pvalue = (PValue) getOutput(transform);
    return pcollections.containsKey(pvalue);
  }

  /**
   * Looks up the holder registered for {@code pvalue}.
   *
   * @throws IllegalStateException if no RDD or values were registered for {@code pvalue}
   */
  private RDDHolder<?> getRDDHolder(PValue pvalue) {
    RDDHolder<?> rddHolder = pcollections.get(pvalue);
    if (rddHolder == null) {
      throw new IllegalStateException("No RDD registered for PValue: " + pvalue);
    }
    return rddHolder;
  }

  /**
   * Returns the RDD for {@code pvalue}, recording the read. A value that is read is no longer
   * a leaf. On the second and later reads the RDD is marked as cached so it is not recomputed.
   *
   * @throws IllegalStateException if nothing was registered for {@code pvalue}
   */
  protected JavaRDDLike<?, ?> getRDD(PValue pvalue) {
    RDDHolder<?> rddHolder = getRDDHolder(pvalue);
    JavaRDDLike<?, ?> rdd = rddHolder.getRDD();
    leafRdds.remove(rddHolder);
    if (multireads.contains(pvalue)) {
      // Ensure the RDD is marked as cached
      rdd.rdd().cache();
    } else {
      multireads.add(pvalue);
    }
    return rdd;
  }

  /** Registers {@code rdd} as the value of {@code pvalue} and tracks it as a leaf. */
  protected <T> void setRDD(PValue pvalue, JavaRDDLike<WindowedValue<T>, ?> rdd) {
    try {
      rdd.rdd().setName(pvalue.getName());
    } catch (IllegalStateException e) {
      // name not set, ignore
    }
    RDDHolder<T> rddHolder = new RDDHolder<>(rdd);
    pcollections.put(pvalue, rddHolder);
    leafRdds.add(rddHolder);
  }

  JavaRDDLike<?, ?> getInputRDD(PTransform<? extends PInput, ?> transform) {
    return getRDD((PValue) getInput(transform));
  }


  <T> Iterable<? extends WindowedValue<?>> getPCollectionView(PCollectionView<T> view) {
    return pview.get(view);
  }

  /**
   * Computes the outputs for all RDDs that are leaves in the DAG and do not have any
   * actions (like saving to a file) registered on them (i.e. they are performed for side
   * effects).
   */
  protected void computeOutputs() {
    for (RDDHolder<?> rddHolder : leafRdds) {
      JavaRDDLike<?, ?> rdd = rddHolder.getRDD();
      rdd.rdd().cache(); // cache so that any subsequent get() is cheap
      rdd.count(); // force the RDD to be computed
    }
  }

  @Override
  public <T> T get(PValue value) {
    if (pobjects.containsKey(value)) {
      @SuppressWarnings("unchecked")
      T result = (T) pobjects.get(value);
      return result;
    }
    if (pcollections.containsKey(value)) {
      JavaRDDLike<?, ?> rdd = pcollections.get(value).getRDD();
      // NOTE(review): holders store JavaRDDLike<WindowedValue<T>, ?>, so the single collected
      // element is the WindowedValue itself, not its unwrapped value. Confirm callers expect
      // that before changing.
      @SuppressWarnings("unchecked")
      T res = (T) Iterables.getOnlyElement(rdd.collect());
      pobjects.put(value, res);
      return res;
    }
    throw new IllegalStateException("Cannot resolve un-known PObject: " + value);
  }

  @Override
  public <T> T getAggregatorValue(String named, Class<T> resultType) {
    return runtime.getAggregatorValue(named, resultType);
  }

  @Override
  public <T> AggregatorValues<T> getAggregatorValues(Aggregator<?, T> aggregator)
      throws AggregatorRetrievalException {
    return runtime.getAggregatorValues(aggregator);
  }

  /**
   * {@inheritDoc}
   *
   * @throws IllegalStateException if nothing was registered for {@code pcollection}
   */
  @Override
  public <T> Iterable<T> get(PCollection<T> pcollection) {
    @SuppressWarnings("unchecked")
    RDDHolder<T> rddHolder = (RDDHolder<T>) getRDDHolder(pcollection);
    return rddHolder.getValues(pcollection);
  }

  /**
   * Returns the values of {@code pcollection}, each wrapped in empty windows.
   *
   * @throws IllegalStateException if nothing was registered for {@code pcollection}
   */
  <T> Iterable<WindowedValue<T>> getWindowedValues(PCollection<T> pcollection) {
    @SuppressWarnings("unchecked")
    RDDHolder<T> rddHolder = (RDDHolder<T>) getRDDHolder(pcollection);
    return rddHolder.getWindowedValues(pcollection);
  }

  @Override
  public void close() {
    SparkContextFactory.stopSparkContext(jsc);
  }

  /** The runner is blocking. */
  @Override
  public State getState() {
    return State.DONE;
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/EvaluationResult.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/EvaluationResult.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/EvaluationResult.java new file mode 100644 index 0000000..4de97f6 --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/EvaluationResult.java @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved. + * + * Cloudera, Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"). You may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for + * the specific language governing permissions and limitations under the + * License. + */

package org.apache.beam.runners.spark;

import com.google.cloud.dataflow.sdk.PipelineResult;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.dataflow.sdk.values.PValue;

/**
 * Interface for retrieving the result(s) of running a pipeline. Allows us to translate between
 * {@code PObject<T>}s or {@code PCollection<T>}s and Ts or collections of Ts.
 */
public interface EvaluationResult extends PipelineResult {
  /**
   * Retrieves an iterable of results associated with the PCollection passed in.
   *
   * @param pcollection Collection we wish to translate.
   * @param <T> Type of elements contained in collection.
   * @return Natively typed result associated with collection.
   */
  <T> Iterable<T> get(PCollection<T> pcollection);

  /**
   * Retrieves an object of type {@code T} associated with the PValue passed in.
   *
   * @param pval PValue to retrieve associated data for.
   * @param <T> Type of object to return.
   * @return Native object.
   */
  <T> T get(PValue pval);

  /**
   * Retrieves the final value of the aggregator.
   *
   * @param aggName name of aggregator.
   * @param resultType Class of final result of aggregation.
   * @param <T> Type of final result of aggregation.
   * @return Result of aggregation associated with specified name.
   */
  <T> T getAggregatorValue(String aggName, Class<T> resultType);

  /**
   * Releases any runtime resources, including distributed-execution contexts currently held by
   * this EvaluationResult. Once close() has been called, {@link EvaluationResult#get(PCollection)}
   * might not work for subsequent calls.
   */
  void close();
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/MultiDoFnFunction.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/MultiDoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/MultiDoFnFunction.java new file mode 100644 index 0000000..968825b --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/MultiDoFnFunction.java @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved. + * + * Cloudera, Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"). You may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for + * the specific language governing permissions and limitations under the + * License.
+ */

package org.apache.beam.runners.spark;

import java.util.Iterator;
import java.util.Map;

import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.util.WindowedValue;
import com.google.cloud.dataflow.sdk.values.TupleTag;
import com.google.common.base.Function;
import com.google.common.collect.Iterators;
import com.google.common.collect.LinkedListMultimap;
import com.google.common.collect.Multimap;
import org.apache.beam.runners.spark.util.BroadcastHelper;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.joda.time.Instant;
import scala.Tuple2;

/**
 * DoFunctions ignore side outputs. MultiDoFunctions deal with side outputs by enriching the
 * underlying data with multiple TupleTags: every output, main or side, is emitted paired with
 * the {@link TupleTag} it was produced under.
 *
 * @param <I> Input type for DoFunction.
 * @param <O> Output type for DoFunction.
 */
class MultiDoFnFunction<I, O>
    implements PairFlatMapFunction<Iterator<WindowedValue<I>>, TupleTag<?>, WindowedValue<?>> {
  private final DoFn<I, O> mFunction;
  private final SparkRuntimeContext mRuntimeContext;
  private final TupleTag<O> mMainOutputTag;
  private final Map<TupleTag<?>, BroadcastHelper<?>> mSideInputs;

  MultiDoFnFunction(
      DoFn<I, O> fn,
      SparkRuntimeContext runtimeContext,
      TupleTag<O> mainOutputTag,
      Map<TupleTag<?>, BroadcastHelper<?>> sideInputs) {
    this.mFunction = fn;
    this.mRuntimeContext = runtimeContext;
    this.mMainOutputTag = mainOutputTag;
    this.mSideInputs = sideInputs;
  }

  /**
   * Processes the elements of {@code iter} (presumably one Spark partition; confirm at the call
   * site). It starts a bundle on the {@link DoFn}, sets up a fresh context, and returns the
   * tagged outputs produced by {@link SparkProcessContext#getOutputIterable}.
   */
  @Override
  public Iterable<Tuple2<TupleTag<?>, WindowedValue<?>>>
      call(Iterator<WindowedValue<I>> iter) throws Exception {
    ProcCtxt ctxt = new ProcCtxt(mFunction, mRuntimeContext, mSideInputs);
    mFunction.startBundle(ctxt);
    ctxt.setup();
    return ctxt.getOutputIterable(iter, mFunction);
  }

  /**
   * Process context that buffers every output together with its {@link TupleTag}, in insertion
   * order, until {@link #getOutputIterator()} exposes the buffer as {@link Tuple2}s.
   */
  private class ProcCtxt extends SparkProcessContext<I, O, Tuple2<TupleTag<?>, WindowedValue<?>>> {

    // Every method that mutates this buffer is synchronized on the context.
    private final Multimap<TupleTag<?>, WindowedValue<?>> outputs = LinkedListMultimap.create();

    ProcCtxt(DoFn<I, O> fn, SparkRuntimeContext runtimeContext, Map<TupleTag<?>,
        BroadcastHelper<?>> sideInputs) {
      super(fn, runtimeContext, sideInputs);
    }

    /** Emits {@code o} to the main output, reusing the windowing of the current element. */
    @Override
    public synchronized void output(O o) {
      outputs.put(mMainOutputTag, windowedValue.withValue(o));
    }

    /** Emits an already-windowed value to the main output. */
    @Override
    public synchronized void output(WindowedValue<O> o) {
      outputs.put(mMainOutputTag, o);
    }

    /** Emits {@code t} under {@code tag}, reusing the windowing of the current element. */
    @Override
    public synchronized <T> void sideOutput(TupleTag<T> tag, T t) {
      outputs.put(tag, windowedValue.withValue(t));
    }

    /**
     * Emits {@code t} under {@code tupleTag} with an explicit timestamp. It keeps the windows
     * and pane of the current element. Synchronized like the other output methods because it
     * mutates the same buffer.
     */
    @Override
    public synchronized <T> void sideOutputWithTimestamp(TupleTag<T> tupleTag, T t,
        Instant instant) {
      outputs.put(tupleTag, WindowedValue.of(t, instant,
          windowedValue.getWindows(), windowedValue.getPane()));
    }

    /** Discards all buffered outputs. Synchronized because it mutates the shared buffer. */
    @Override
    protected synchronized void clearOutput() {
      outputs.clear();
    }

    /** Returns a lazy view of the buffered outputs as (tag, value) tuples. */
    @Override
    protected Iterator<Tuple2<TupleTag<?>, WindowedValue<?>>> getOutputIterator() {
      return Iterators.transform(outputs.entries().iterator(),
          new Function<Map.Entry<TupleTag<?>, WindowedValue<?>>,
              Tuple2<TupleTag<?>, WindowedValue<?>>>() {
            @Override
            public Tuple2<TupleTag<?>, WindowedValue<?>> apply(Map.Entry<TupleTag<?>,
                WindowedValue<?>> input) {
              return new Tuple2<TupleTag<?>, WindowedValue<?>>(input.getKey(), input.getValue());
            }
          });
    }

  }
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkContextFactory.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkContextFactory.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkContextFactory.java new file mode 100644 index 0000000..10b7369 --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkContextFactory.java @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. + * + * Cloudera, Inc.
licenses this file to you under the Apache License, + * Version 2.0 (the "License"). You may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for + * the specific language governing permissions and limitations under the + * License. + */

package org.apache.beam.runners.spark;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.serializer.KryoSerializer;

/**
 * Creates the {@link JavaSparkContext} used to run pipelines. When the
 * {@link #TEST_REUSE_SPARK_CONTEXT} system property is {@code true}, a single context is shared
 * across pipelines and is never stopped by {@link #stopSparkContext}.
 */
final class SparkContextFactory {

  /**
   * If the property {@code dataflow.spark.test.reuseSparkContext} is set to
   * {@code true} then the Spark context will be reused for dataflow pipelines.
   * This property should only be enabled for tests.
   */
  static final String TEST_REUSE_SPARK_CONTEXT =
      "dataflow.spark.test.reuseSparkContext";
  /** Shared context, only populated when reuse is enabled. */
  private static JavaSparkContext sparkContext;
  /** Master URL the shared context was created with. */
  private static String sparkMaster;

  private SparkContextFactory() {
  }

  /**
   * Returns a context for {@code master}. Without reuse, a new context is created on every
   * call. With reuse, the first call creates the shared context and later calls return it.
   *
   * @throws IllegalArgumentException if reuse is enabled and {@code master} differs from the
   *     master of the existing shared context
   */
  static synchronized JavaSparkContext getSparkContext(String master, String appName) {
    if (!Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT)) {
      return createSparkContext(master, appName);
    }
    if (sparkContext != null) {
      if (!master.equals(sparkMaster)) {
        String message = String.format("Cannot reuse spark context "
            + "with different spark master URL. Existing: %s, requested: %s.",
            sparkMaster, master);
        throw new IllegalArgumentException(message);
      }
      return sparkContext;
    }
    sparkContext = createSparkContext(master, appName);
    sparkMaster = master;
    return sparkContext;
  }

  /** Stops {@code context}, unless the shared test context is being reused. */
  static synchronized void stopSparkContext(JavaSparkContext context) {
    boolean reuseEnabled = Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT);
    if (reuseEnabled) {
      // Keep the shared context alive so later pipelines can reuse it.
      return;
    }
    context.stop();
  }

  /** Builds a fresh context that uses Kryo serialization. */
  private static JavaSparkContext createSparkContext(String master, String appName) {
    SparkConf conf = new SparkConf()
        .setMaster(master)
        .setAppName(appName)
        .set("spark.serializer", KryoSerializer.class.getCanonicalName());
    return new JavaSparkContext(conf);
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineEvaluator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineEvaluator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineEvaluator.java new file mode 100644 index 0000000..913e5a1 --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineEvaluator.java @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved. + * + * Cloudera, Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"). You may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for + * the specific language governing permissions and limitations under the + * License.
+ */

package org.apache.beam.runners.spark;

import com.google.cloud.dataflow.sdk.runners.TransformTreeNode;
import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform;
import com.google.cloud.dataflow.sdk.transforms.PTransform;
import com.google.cloud.dataflow.sdk.values.PInput;
import com.google.cloud.dataflow.sdk.values.POutput;

/**
 * Pipeline {@link SparkPipelineRunner.Evaluator} for Spark. For each visited transform it looks
 * up the matching {@link TransformEvaluator} via the translator and runs it against the
 * {@link EvaluationContext}.
 */
public final class SparkPipelineEvaluator extends SparkPipelineRunner.Evaluator {

  private final EvaluationContext ctxt;

  public SparkPipelineEvaluator(EvaluationContext ctxt, SparkPipelineTranslator translator) {
    super(translator);
    this.ctxt = ctxt;
  }

  /**
   * Evaluates the transform at {@code node}. The transform is exposed as the context's current
   * transform only for the duration of the evaluation. The reset happens in a {@code finally}
   * block, so the context is not left pointing at a transform whose evaluation failed.
   */
  @Override
  protected <PT extends PTransform<? super PInput, POutput>> void doVisitTransform(
      TransformTreeNode node) {
    @SuppressWarnings("unchecked")
    PT transform = (PT) node.getTransform();
    @SuppressWarnings("unchecked")
    Class<PT> transformClass = (Class<PT>) (Class<?>) transform.getClass();
    @SuppressWarnings("unchecked")
    TransformEvaluator<PT> evaluator =
        (TransformEvaluator<PT>) translator.translate(transformClass);
    LOG.info("Evaluating {}", transform);
    AppliedPTransform<PInput, POutput, PT> appliedTransform =
        AppliedPTransform.of(node.getFullName(), node.getInput(), node.getOutput(), transform);
    ctxt.setCurrentTransform(appliedTransform);
    try {
      evaluator.evaluate(transform, ctxt);
    } finally {
      ctxt.setCurrentTransform(null);
    }
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java new file mode 100644 index 0000000..1a5093b --- /dev/null +++
b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved. + * + * Cloudera, Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"). You may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for + * the specific language governing permissions and limitations under the + * License. + */

package org.apache.beam.runners.spark;

import com.google.cloud.dataflow.sdk.options.ApplicationNameOptions;
import com.google.cloud.dataflow.sdk.options.Default;
import com.google.cloud.dataflow.sdk.options.Description;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.options.StreamingOptions;

/**
 * Pipeline options read by {@link SparkPipelineRunner}: the Spark master URL, plus Spark-specific
 * defaults for the inherited streaming flag and application name.
 */
public interface SparkPipelineOptions extends PipelineOptions, StreamingOptions,
    ApplicationNameOptions {
  /**
   * Returns the URL of the Spark master to connect to.
   *
   * @return the master URL; defaults to {@code local[1]}
   */
  @Description("The url of the spark master to connect to, (e.g. spark://host:port, local[4]).")
  @Default.String("local[1]")
  String getSparkMaster();

  /**
   * Sets the URL of the Spark master to connect to.
   *
   * @param master the master URL, e.g. {@code spark://host:port} or {@code local[4]}
   */
  void setSparkMaster(String master);

  /** Overridden to default to batch execution ({@code false}). */
  @Override
  @Default.Boolean(false)
  boolean isStreaming();

  /** Overridden to provide a Spark-specific default application name. */
  @Override
  @Default.String("spark dataflow pipeline job")
  String getAppName();
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptionsFactory.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptionsFactory.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptionsFactory.java new file mode 100644 index 0000000..7b44ee4 --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptionsFactory.java @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved. + * + * Cloudera, Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"). You may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for + * the specific language governing permissions and limitations under the + * License.
+ */

package org.apache.beam.runners.spark;

import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;

/**
 * Static factory for {@link SparkPipelineOptions}.
 */
public final class SparkPipelineOptionsFactory {
  private SparkPipelineOptionsFactory() {
  }

  /**
   * Creates a new {@link SparkPipelineOptions} instance via {@link PipelineOptionsFactory#as}.
   *
   * @return a new options instance
   */
  public static SparkPipelineOptions create() {
    return PipelineOptionsFactory.as(SparkPipelineOptions.class);
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptionsRegistrar.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptionsRegistrar.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptionsRegistrar.java new file mode 100644 index 0000000..9f7f8c1 --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptionsRegistrar.java @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved. + * + * Cloudera, Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"). You may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for + * the specific language governing permissions and limitations under the + * License. + */

package org.apache.beam.runners.spark;

import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsRegistrar;
import com.google.common.collect.ImmutableList;

/**
 * {@link PipelineOptionsRegistrar} that exposes {@link SparkPipelineOptions}.
 */
public class SparkPipelineOptionsRegistrar implements PipelineOptionsRegistrar {
  /** Returns an immutable list containing only {@link SparkPipelineOptions}. */
  @Override
  public Iterable<Class<? extends PipelineOptions>> getPipelineOptions() {
    return ImmutableList.<Class<? extends PipelineOptions>>of(SparkPipelineOptions.class);
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java new file mode 100644 index 0000000..429750d --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java @@ -0,0 +1,252 @@ +/* + * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved. + * + * Cloudera, Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"). You may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for + * the specific language governing permissions and limitations under the + * License.
+ */
+
+package org.apache.beam.runners.spark;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import com.google.cloud.dataflow.sdk.options.PipelineOptionsValidator;
+import com.google.cloud.dataflow.sdk.runners.PipelineRunner;
+import com.google.cloud.dataflow.sdk.runners.TransformTreeNode;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.values.PInput;
+import com.google.cloud.dataflow.sdk.values.POutput;
+import com.google.cloud.dataflow.sdk.values.PValue;
+
+import org.apache.beam.runners.spark.streaming.SparkStreamingPipelineOptions;
+import org.apache.beam.runners.spark.streaming.StreamingEvaluationContext;
+import org.apache.beam.runners.spark.streaming.StreamingTransformTranslator;
+import org.apache.beam.runners.spark.streaming.StreamingWindowPipelineDetector;
+
+import org.apache.spark.SparkException;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.streaming.Duration;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The SparkPipelineRunner translates operations defined on a pipeline to a representation
+ * executable by Spark, and then submits the job to Spark to be executed. If we wanted to run
+ * a dataflow pipeline with the default options of a single threaded spark instance in local mode,
+ * we would do the following:
+ *
+ * <pre>{@code
+ * Pipeline p = [logic for pipeline creation]
+ * EvaluationResult result = SparkPipelineRunner.create().run(p);
+ * }</pre>
+ *
+ * <p>To create a pipeline runner to run against a different spark cluster, with a custom master
+ * url we would do the following:
+ *
+ * <pre>{@code
+ * Pipeline p = [logic for pipeline creation]
+ * SparkPipelineOptions options = SparkPipelineOptionsFactory.create();
+ * options.setSparkMaster("spark://host:port");
+ * EvaluationResult result = SparkPipelineRunner.create(options).run(p);
+ * }</pre>
+ *
+ * <p>To create a Spark streaming pipeline runner use {@link SparkStreamingPipelineOptions}.
+ */
+public final class SparkPipelineRunner extends PipelineRunner<EvaluationResult> {
+
+  private static final Logger LOG = LoggerFactory.getLogger(SparkPipelineRunner.class);
+  /**
+   * Options used in this pipeline runner.
+   */
+  private final SparkPipelineOptions mOptions;
+
+  /**
+   * Creates and returns a new SparkPipelineRunner with default options. In particular, against a
+   * spark instance running in local mode.
+   *
+   * @return A pipeline runner with default options.
+   */
+  public static SparkPipelineRunner create() {
+    SparkPipelineOptions options = SparkPipelineOptionsFactory.create();
+    return new SparkPipelineRunner(options);
+  }
+
+  /**
+   * Creates and returns a new SparkPipelineRunner with specified options.
+   *
+   * @param options The SparkPipelineOptions to use when executing the job.
+   * @return A pipeline runner that will execute with specified options.
+   */
+  public static SparkPipelineRunner create(SparkPipelineOptions options) {
+    return new SparkPipelineRunner(options);
+  }
+
+  /**
+   * Creates and returns a new SparkPipelineRunner from generic pipeline options.
+   *
+   * @param options The PipelineOptions to use; they are validated as SparkPipelineOptions.
+   * @return A pipeline runner that will execute with specified options.
+   */
+  public static SparkPipelineRunner fromOptions(PipelineOptions options) {
+    SparkPipelineOptions sparkOptions =
+        PipelineOptionsValidator.validate(SparkPipelineOptions.class, options);
+    return new SparkPipelineRunner(sparkOptions);
+  }
+
+  /**
+   * Creates a runner that executes pipelines with the given options. Callers obtain instances via
+   * {@link #create()}, {@link #create(SparkPipelineOptions)} or {@link #fromOptions}.
+   */
+  private SparkPipelineRunner(SparkPipelineOptions options) {
+    mOptions = options;
+  }
+
+  /** Runs {@code pipeline} on Spark; streaming jobs are started but not awaited. */
+  @Override
+  public EvaluationResult run(Pipeline pipeline) {
+    // validate streaming configuration first, outside the catch-all that re-wraps errors
+    if (mOptions.isStreaming() && !(mOptions instanceof SparkStreamingPipelineOptions)) {
+      throw new RuntimeException("A streaming job must be configured with "
+          + SparkStreamingPipelineOptions.class.getSimpleName() + ", found "
+          + mOptions.getClass().getSimpleName());
+    }
+    try {
+      LOG.info("Executing pipeline using the SparkPipelineRunner.");
+      JavaSparkContext jsc = SparkContextFactory.getSparkContext(mOptions
+          .getSparkMaster(), mOptions.getAppName());
+
+      if (mOptions.isStreaming()) {
+        SparkPipelineTranslator translator =
+            new StreamingTransformTranslator.Translator(new TransformTranslator.Translator());
+        // if streaming - fixed window should be defined on all UNBOUNDED inputs
+        StreamingWindowPipelineDetector streamingWindowPipelineDetector =
+            new StreamingWindowPipelineDetector(translator);
+        pipeline.traverseTopologically(streamingWindowPipelineDetector);
+        if (!streamingWindowPipelineDetector.isWindowing()) {
+          throw new IllegalStateException("Spark streaming pipeline must be windowed!");
+        }
+
+        Duration batchInterval = streamingWindowPipelineDetector.getBatchDuration();
+        LOG.info("Setting Spark streaming batchInterval to {} msec", batchInterval.milliseconds());
+        EvaluationContext ctxt = createStreamingEvaluationContext(jsc, pipeline, batchInterval);
+
+        pipeline.traverseTopologically(new SparkPipelineEvaluator(ctxt, translator));
+        ctxt.computeOutputs();
+
+        LOG.info("Streaming pipeline construction complete. Starting execution..");
+        ((StreamingEvaluationContext) ctxt).getStreamingContext().start();
+
+        return ctxt;
+      } else {
+        EvaluationContext ctxt = new EvaluationContext(jsc, pipeline);
+        SparkPipelineTranslator translator = new TransformTranslator.Translator();
+        pipeline.traverseTopologically(new SparkPipelineEvaluator(ctxt, translator));
+        ctxt.computeOutputs();
+
+        LOG.info("Pipeline execution complete.");
+
+        return ctxt;
+      }
+    } catch (Exception e) {
+      // Scala doesn't declare checked exceptions in the bytecode, and the Java compiler
+      // won't let you catch something that is not declared, so we can't catch
+      // SparkException here. Instead we do an instanceof check.
+      // Then we find the cause by seeing if it's a user exception (wrapped by our
+      // SparkProcessException), or just use the SparkException cause.
+      if (e instanceof SparkException && e.getCause() != null) {
+        if (e.getCause() instanceof SparkProcessContext.SparkProcessException &&
+            e.getCause().getCause() != null) {
+          throw new RuntimeException(e.getCause().getCause());
+        } else {
+          throw new RuntimeException(e.getCause());
+        }
+      }
+      // otherwise just wrap in a RuntimeException
+      throw new RuntimeException(e);
+    }
+  }
+
+  /** Builds a streaming evaluation context over {@code jsc} with the given batch duration. */
+  private EvaluationContext
+      createStreamingEvaluationContext(JavaSparkContext jsc, Pipeline pipeline,
+      Duration batchDuration) {
+    SparkStreamingPipelineOptions streamingOptions = (SparkStreamingPipelineOptions) mOptions;
+    JavaStreamingContext jssc = new JavaStreamingContext(jsc, batchDuration);
+    return new StreamingEvaluationContext(jsc, pipeline, jssc, streamingOptions.getTimeout());
+  }
+
+  /** Visits primitives and directly-translatable composites via {@link #doVisitTransform}. */
+  public abstract static class Evaluator implements Pipeline.PipelineVisitor {
+    protected static final Logger LOG = LoggerFactory.getLogger(Evaluator.class);
+
+    protected final SparkPipelineTranslator translator;
+
+    protected Evaluator(SparkPipelineTranslator translator) {
+      this.translator = translator;
+    }
+
+    // Set upon entering a composite node that maps directly to a single TransformEvaluator.
+    private TransformTreeNode currentTranslatedCompositeNode;
+
+    /**
+     * Returns true while visiting the subtree of a composite node that maps directly to a single
+     * TransformEvaluator. Children of such a node are skipped; when the node itself is
+     * post-visited, it is handed to doVisitTransform and this flag is cleared.
+     */
+    private boolean inTranslatedCompositeNode() {
+      return currentTranslatedCompositeNode != null;
+    }
+
+    @Override
+    public void enterCompositeTransform(TransformTreeNode node) {
+      if (!inTranslatedCompositeNode() && node.getTransform() != null) {
+        @SuppressWarnings("unchecked")
+        Class<PTransform<?, ?>> transformClass =
+            (Class<PTransform<?, ?>>) node.getTransform().getClass();
+        if (translator.hasTranslation(transformClass)) {
+          LOG.info("Entering directly-translatable composite transform: '{}'", node.getFullName());
+          LOG.debug("Composite transform class: '{}'", transformClass);
+          currentTranslatedCompositeNode = node;
+        }
+      }
+    }
+
+    @Override
+    public void leaveCompositeTransform(TransformTreeNode node) {
+      // NB: We depend on enterCompositeTransform and leaveCompositeTransform providing 'node'
+      // objects for which Object.equals() returns true iff they are the same logical node
+      // within the tree.
+      if (inTranslatedCompositeNode() && node.equals(currentTranslatedCompositeNode)) {
+        LOG.info("Post-visiting directly-translatable composite transform: '{}'",
+            node.getFullName());
+        doVisitTransform(node);
+        currentTranslatedCompositeNode = null;
+      }
+    }
+
+    @Override
+    public void visitTransform(TransformTreeNode node) {
+      if (inTranslatedCompositeNode()) {
+        LOG.info("Skipping '{}'; already in composite transform.", node.getFullName());
+        return;
+      }
+      doVisitTransform(node);
+    }
+
+    protected abstract <PT extends PTransform<? super PInput, POutput>> void
+        doVisitTransform(TransformTreeNode node);
+
+    @Override
+    public void visitValue(PValue value, TransformTreeNode producer) {
+    }
+  }
+}
+
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunnerRegistrar.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunnerRegistrar.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunnerRegistrar.java
new file mode 100644
index 0000000..9a84370
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunnerRegistrar.java
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package org.apache.beam.runners.spark;
+
+import com.google.cloud.dataflow.sdk.runners.PipelineRunner;
+import com.google.cloud.dataflow.sdk.runners.PipelineRunnerRegistrar;
+import com.google.common.collect.ImmutableList;
+
+public class SparkPipelineRunnerRegistrar implements PipelineRunnerRegistrar {
+  @Override
+  public Iterable<Class<? extends PipelineRunner<?>>> getPipelineRunners() {
+    return ImmutableList.<Class<? extends PipelineRunner<?>>>of(SparkPipelineRunner.class);
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineTranslator.java
new file mode 100644
index 0000000..e45491a
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineTranslator.java
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+package org.apache.beam.runners.spark;
+
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+
+/**
+ * Translator to support translation between Dataflow transformations and Spark transformations.
+ */
+public interface SparkPipelineTranslator {
+
+  boolean hasTranslation(Class<? extends PTransform<?, ?>> clazz);
+
+  <PT extends PTransform<?, ?>> TransformEvaluator<PT> translate(Class<PT> clazz);
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkProcessContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkProcessContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkProcessContext.java
new file mode 100644
index 0000000..c634152
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkProcessContext.java
@@ -0,0 +1,257 @@
+/*
+ * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package org.apache.beam.runners.spark;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Map;
+
+import com.google.cloud.dataflow.sdk.coders.Coder;
+import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import com.google.cloud.dataflow.sdk.transforms.Aggregator;
+import com.google.cloud.dataflow.sdk.transforms.Combine;
+import com.google.cloud.dataflow.sdk.transforms.DoFn;
+import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow;
+import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo;
+import com.google.cloud.dataflow.sdk.util.TimerInternals;
+import com.google.cloud.dataflow.sdk.util.WindowedValue;
+import com.google.cloud.dataflow.sdk.util.WindowingInternals;
+import com.google.cloud.dataflow.sdk.util.state.StateInternals;
+import com.google.cloud.dataflow.sdk.values.PCollectionView;
+import com.google.cloud.dataflow.sdk.values.TupleTag;
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.Iterables;
+
+import org.apache.beam.runners.spark.util.BroadcastHelper;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Spark {@link DoFn.ProcessContext} that lazily applies a DoFn to an input iterator. */
+abstract class SparkProcessContext<I, O, V> extends DoFn<I, O>.ProcessContext {
+
+  private static final Logger LOG = LoggerFactory.getLogger(SparkProcessContext.class);
+
+  private final DoFn<I, O> fn;
+  private final SparkRuntimeContext mRuntimeContext;
+  private final Map<TupleTag<?>, BroadcastHelper<?>> mSideInputs;
+
+  protected WindowedValue<I> windowedValue; // the element currently being processed
+
+  SparkProcessContext(DoFn<I, O> fn,
+      SparkRuntimeContext runtime,
+      Map<TupleTag<?>, BroadcastHelper<?>> sideInputs) {
+    fn.super();
+    this.fn = fn;
+    this.mRuntimeContext = runtime;
+    this.mSideInputs = sideInputs;
+  }
+
+  void setup() {
+    setupDelegateAggregators();
+  }
+
+  @Override
+  public PipelineOptions getPipelineOptions() {
+    return mRuntimeContext.getPipelineOptions();
+  }
+
+  @Override
+  public <T> T sideInput(PCollectionView<T> view) {
+    @SuppressWarnings("unchecked")
+    BroadcastHelper<Iterable<WindowedValue<?>>> broadcastHelper =
+        (BroadcastHelper<Iterable<WindowedValue<?>>>) mSideInputs.get(view.getTagInternal());
+    Iterable<WindowedValue<?>> contents = broadcastHelper.getValue();
+    return view.fromIterableInternal(contents);
+  }
+
+  @Override
+  public abstract void output(O output);
+
+  public abstract void output(WindowedValue<O> output);
+
+  @Override
+  public <T> void sideOutput(TupleTag<T> tupleTag, T t) {
+    String message = "sideOutput is an unsupported operation for doFunctions, use a " +
+        "MultiDoFunction instead.";
+    LOG.warn(message);
+    throw new UnsupportedOperationException(message);
+  }
+
+  @Override
+  public <T> void sideOutputWithTimestamp(TupleTag<T> tupleTag, T t, Instant instant) {
+    String message =
+        "sideOutputWithTimestamp is an unsupported operation for doFunctions, use a " +
+            "MultiDoFunction instead.";
+    LOG.warn(message);
+    throw new UnsupportedOperationException(message);
+  }
+
+  @Override
+  public <AI, AO> Aggregator<AI, AO> createAggregatorInternal(
+      String named,
+      Combine.CombineFn<AI, ?, AO> combineFn) {
+    return mRuntimeContext.createAggregator(named, combineFn);
+  }
+
+  @Override
+  public I element() {
+    return windowedValue.getValue();
+  }
+
+  @Override
+  public void outputWithTimestamp(O output, Instant timestamp) {
+    output(WindowedValue.of(output, timestamp,
+        windowedValue.getWindows(), windowedValue.getPane()));
+  }
+
+  @Override
+  public Instant timestamp() {
+    return windowedValue.getTimestamp();
+  }
+
+  @Override
+  public BoundedWindow window() {
+    if (!(fn instanceof DoFn.RequiresWindowAccess)) {
+      throw new UnsupportedOperationException(
+          "window() is only available in the context of a DoFn marked as RequiresWindowAccess.");
+    }
+    return Iterables.getOnlyElement(windowedValue.getWindows());
+  }
+
+  @Override
+  public PaneInfo pane() {
+    return windowedValue.getPane();
+  }
+
+  @Override
+  public WindowingInternals<I, O> windowingInternals() {
+    return new WindowingInternals<I, O>() {
+      // stateInternals, timerInternals, writePCollectionViewData and sideInput are unsupported.
+      @Override
+      public Collection<? extends BoundedWindow> windows() {
+        return windowedValue.getWindows();
+      }
+
+      @Override
+      public void outputWindowedValue(O output, Instant timestamp, Collection<?
+          extends BoundedWindow> windows, PaneInfo paneInfo) {
+        output(WindowedValue.of(output, timestamp, windows, paneInfo));
+      }
+
+      @Override
+      public StateInternals stateInternals() {
+        throw new UnsupportedOperationException(
+            "WindowingInternals#stateInternals() is not yet supported.");
+      }
+
+      @Override
+      public TimerInternals timerInternals() {
+        throw new UnsupportedOperationException(
+            "WindowingInternals#timerInternals() is not yet supported.");
+      }
+
+      @Override
+      public PaneInfo pane() {
+        return windowedValue.getPane();
+      }
+
+      @Override
+      public <T> void writePCollectionViewData(TupleTag<?> tag,
+          Iterable<WindowedValue<T>> data, Coder<T> elemCoder) throws IOException {
+        throw new UnsupportedOperationException(
+            "WindowingInternals#writePCollectionViewData() is not yet supported.");
+      }
+
+      @Override
+      public <T> T sideInput(PCollectionView<T> view, BoundedWindow mainInputWindow) {
+        throw new UnsupportedOperationException(
+            "WindowingInternals#sideInput() is not yet supported.");
+      }
+    };
+  }
+
+  protected abstract void clearOutput();
+  protected abstract Iterator<V> getOutputIterator();
+
+  protected Iterable<V> getOutputIterable(final Iterator<WindowedValue<I>> iter,
+      final DoFn<I, O> doFn) {
+    return new Iterable<V>() {
+      @Override
+      public Iterator<V> iterator() {
+        return new ProcCtxtIterator(iter, doFn);
+      }
+    };
+  }
+
+  private class ProcCtxtIterator extends AbstractIterator<V> {
+
+    private final Iterator<WindowedValue<I>> inputIterator;
+    private final DoFn<I, O> doFn;
+    private Iterator<V> outputIterator;
+    private boolean calledFinish; // true once finishBundle has been called
+
+    ProcCtxtIterator(Iterator<WindowedValue<I>> iterator, DoFn<I, O> doFn) {
+      this.inputIterator = iterator;
+      this.doFn = doFn;
+      this.outputIterator = getOutputIterator();
+    }
+
+    @Override
+    protected V computeNext() {
+      // Process each element from the (input) iterator, which produces zero, one or more
+      // output elements (of type V) in the output iterator. Note that the output
+      // collection (and iterator) is reset between each call to processElement, so the
+      // collection only holds the output values for each call to processElement, rather
+      // than for the whole partition (which would use too much memory).
+      while (true) {
+        if (outputIterator.hasNext()) {
+          return outputIterator.next();
+        } else if (inputIterator.hasNext()) {
+          clearOutput();
+          windowedValue = inputIterator.next();
+          try {
+            doFn.processElement(SparkProcessContext.this);
+          } catch (Exception e) {
+            throw new SparkProcessException(e);
+          }
+          outputIterator = getOutputIterator();
+        } else {
+          // no more input to consume, but finishBundle can produce more output
+          if (!calledFinish) {
+            clearOutput();
+            try {
+              calledFinish = true;
+              doFn.finishBundle(SparkProcessContext.this);
+            } catch (Exception e) {
+              throw new SparkProcessException(e);
+            }
+            outputIterator = getOutputIterator();
+            continue; // try to consume outputIterator from start of loop
+          }
+          return endOfData();
+        }
+      }
+    }
+  }
+
+  static class SparkProcessException extends RuntimeException { // wraps DoFn exceptions
+    SparkProcessException(Throwable t) {
+      super(t);
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRuntimeContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRuntimeContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRuntimeContext.java
new file mode 100644
index 0000000..da48ad7
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRuntimeContext.java
@@ -0,0 +1,214 @@
+/*
+ * Copyright (c) 2014,
Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package org.apache.beam.runners.spark;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException;
+import com.google.cloud.dataflow.sdk.coders.Coder;
+import com.google.cloud.dataflow.sdk.coders.CoderRegistry;
+import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import com.google.cloud.dataflow.sdk.runners.AggregatorValues;
+import com.google.cloud.dataflow.sdk.transforms.Aggregator;
+import com.google.cloud.dataflow.sdk.transforms.Combine;
+import com.google.cloud.dataflow.sdk.transforms.Max;
+import com.google.cloud.dataflow.sdk.transforms.Min;
+import com.google.cloud.dataflow.sdk.transforms.Sum;
+import com.google.cloud.dataflow.sdk.values.TypeDescriptor;
+import com.google.common.collect.ImmutableList;
+
+import org.apache.beam.runners.spark.aggregators.AggAccumParam;
+import org.apache.beam.runners.spark.aggregators.NamedAggregators;
+
+import org.apache.spark.Accumulator;
+import org.apache.spark.api.java.JavaSparkContext;
+
+/**
+ * Serializable runtime state for a Spark pipeline: the accumulator backing its named aggregators
+ * and the pipeline options, held as a JSON string and deserialized on demand.
+ */
+public class SparkRuntimeContext implements Serializable {
+  /** An accumulator that is a map from names to aggregators. */
+  private final Accumulator<NamedAggregators> accum;
+
+  /** The pipeline options serialized as JSON; see {@link #getPipelineOptions()}. */
+  private final String serializedPipelineOptions;
+
+  /** Map of names to dataflow aggregators created through {@link #createAggregator}. */
+  private final Map<String, Aggregator<?, ?>> aggregators = new HashMap<>();
+  private transient CoderRegistry coderRegistry; // created lazily by getCoderRegistry()
+
+  SparkRuntimeContext(JavaSparkContext jsc, Pipeline pipeline) {
+    this.accum = jsc.accumulator(new NamedAggregators(), new AggAccumParam());
+    this.serializedPipelineOptions = serializePipelineOptions(pipeline.getOptions());
+  }
+
+  private static String serializePipelineOptions(PipelineOptions pipelineOptions) {
+    try {
+      return new ObjectMapper().writeValueAsString(pipelineOptions);
+    } catch (JsonProcessingException e) {
+      throw new IllegalStateException("Failed to serialize the pipeline options.", e);
+    }
+  }
+
+  private static PipelineOptions deserializePipelineOptions(String serializedPipelineOptions) {
+    try {
+      return new ObjectMapper().readValue(serializedPipelineOptions, PipelineOptions.class);
+    } catch (IOException e) {
+      throw new IllegalStateException("Failed to deserialize the pipeline options.", e);
+    }
+  }
+
+  /**
+   * Retrieves the value of the named aggregator from the accumulator.
+   *
+   * @param aggregatorName Name of the aggregator to retrieve the value of.
+   * @param typeClass      Type class of value to be retrieved.
+   * @param <T>            Type of object to be returned.
+   * @return The value of the aggregator.
+   */
+  public <T> T getAggregatorValue(String aggregatorName, Class<T> typeClass) {
+    return accum.value().getValue(aggregatorName, typeClass);
+  }
+
+  /** Returns {@code aggregator}'s value as single-element AggregatorValues (no per-step values). */
+  public <T> AggregatorValues<T> getAggregatorValues(Aggregator<?, T> aggregator) {
+    @SuppressWarnings("unchecked")
+    Class<T> aggValueClass = (Class<T>) aggregator.getCombineFn().getOutputType().getRawType();
+    final T aggregatorValue = getAggregatorValue(aggregator.getName(), aggValueClass);
+    return new AggregatorValues<T>() {
+      @Override
+      public Collection<T> getValues() {
+        return ImmutableList.of(aggregatorValue);
+      }
+
+      @Override
+      public Map<String, T> getValuesAtSteps() {
+        throw new UnsupportedOperationException("getValuesAtSteps is not supported.");
+      }
+    };
+  }
+
+  /** Deserializes the pipeline options; each call returns a newly deserialized instance. */
+  public synchronized PipelineOptions getPipelineOptions() {
+    return deserializePipelineOptions(serializedPipelineOptions);
+  }
+
+  /**
+   * Returns the aggregator registered under {@code named}, creating and registering it if absent.
+   *
+   * @param named     Name of aggregator.
+   * @param combineFn Combine function used in aggregation.
+   * @param <IN>      Type of inputs to aggregator.
+   * @param <INTER>   Intermediate data type of the combine function.
+   * @param <OUT>     Type of aggregator outputs.
+   * @return The aggregator associated with {@code named}.
+   */
+  public synchronized <IN, INTER, OUT> Aggregator<IN, OUT> createAggregator(
+      String named,
+      Combine.CombineFn<? super IN, INTER, OUT> combineFn) {
+    @SuppressWarnings("unchecked")
+    Aggregator<IN, OUT> aggregator = (Aggregator<IN, OUT>) aggregators.get(named);
+    if (aggregator == null) {
+      @SuppressWarnings("unchecked")
+      NamedAggregators.CombineFunctionState<IN, INTER, OUT> state =
+          new NamedAggregators.CombineFunctionState<>(
+              (Combine.CombineFn<IN, INTER, OUT>) combineFn,
+              (Coder<IN>) getCoder(combineFn),
+              this);
+      accum.add(new NamedAggregators(named, state));
+      aggregator = new SparkAggregator<>(named, state);
+      aggregators.put(named, aggregator);
+    }
+    return aggregator;
+  }
+
+  /** Returns the lazily created coder registry with the standard coders registered. */
+  public CoderRegistry getCoderRegistry() {
+    if (coderRegistry == null) {
+      coderRegistry = new CoderRegistry();
+      coderRegistry.registerStandardCoders();
+    }
+    return coderRegistry;
+  }
+
+  /** Returns the default coder for the input type of a supported Sum, Min or Max combiner. */
+  private Coder<?> getCoder(Combine.CombineFn<?, ?, ?> combiner) {
+    try {
+      if (combiner.getClass() == Sum.SumIntegerFn.class) {
+        return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Integer.class));
+      } else if (combiner.getClass() == Sum.SumLongFn.class) {
+        return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Long.class));
+      } else if (combiner.getClass() == Sum.SumDoubleFn.class) {
+        return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Double.class));
+      } else if (combiner.getClass() == Min.MinIntegerFn.class) {
+        return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Integer.class));
+      } else if (combiner.getClass() == Min.MinLongFn.class) {
+        return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Long.class));
+      } else if (combiner.getClass() == Min.MinDoubleFn.class) {
+        return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Double.class));
+      } else if (combiner.getClass() == Max.MaxIntegerFn.class) {
+        return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Integer.class));
+      } else if (combiner.getClass() == Max.MaxLongFn.class) {
+        return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Long.class));
+      } else if (combiner.getClass() == Max.MaxDoubleFn.class) {
+        return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Double.class));
+      } else {
+        throw new IllegalArgumentException("unsupported combiner in Aggregator: "
+            + combiner.getClass().getName());
+      }
+    } catch (CannotProvideCoderException e) {
+      throw new IllegalStateException("Could not determine default coder for combiner", e);
+    }
+  }
+
+  /**
+   * An {@link Aggregator} that forwards each added value to its {@link NamedAggregators.State}.
+   * @param <IN> Type of element fed in to aggregator.
+   * @param <OUT> Type of the aggregator's output.
+   */
+  private static class SparkAggregator<IN, OUT> implements Aggregator<IN, OUT>, Serializable {
+    private final String name;
+    private final NamedAggregators.State<IN, ?, OUT> state;
+
+    SparkAggregator(String name, NamedAggregators.State<IN, ?, OUT> state) {
+      this.name = name;
+      this.state = state;
+    }
+
+    @Override
+    public String getName() {
+      return name;
+    }
+
+    @Override
+    public void addValue(IN elem) {
+      state.update(elem);
+    }
+
+    @Override
+    public Combine.CombineFn<IN, ?, OUT> getCombineFn() {
+      return state.getCombineFn();
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/TransformEvaluator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TransformEvaluator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TransformEvaluator.java
new file mode 100644
index 0000000..8aaceeb
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TransformEvaluator.java
@@ -0,0 +1,24 @@
+/*
+ * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License.
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+package org.apache.beam.runners.spark;
+
+import java.io.Serializable;
+
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+
+/** Evaluates a transform of type {@code PT} against an {@link EvaluationContext}. */
+public interface TransformEvaluator<PT extends PTransform<?, ?>> extends Serializable {
+  void evaluate(PT transform, EvaluationContext context);
+}
