http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/TransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TransformTranslator.java new file mode 100644 index 0000000..e64f89a --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TransformTranslator.java @@ -0,0 +1,805 @@ +/* + * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved. + * + * Cloudera, Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"). You may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for + * the specific language governing permissions and limitations under the + * License. 
+ */ + +package org.apache.beam.runners.spark; + +import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputDirectory; +import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputFilePrefix; +import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputFileTemplate; +import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.replaceShardCount; + +import java.io.IOException; +import java.lang.reflect.Field; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import com.google.api.client.util.Maps; +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.io.AvroIO; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.cloud.dataflow.sdk.util.AssignWindowsDoFn; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import 
com.google.cloud.dataflow.sdk.values.PCollectionTuple; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.ImmutableMap; + +import org.apache.avro.mapred.AvroKey; +import org.apache.avro.mapreduce.AvroJob; +import org.apache.avro.mapreduce.AvroKeyInputFormat; +import org.apache.beam.runners.spark.coders.CoderHelpers; +import org.apache.beam.runners.spark.io.hadoop.HadoopIO; +import org.apache.beam.runners.spark.io.hadoop.ShardNameTemplateHelper; +import org.apache.beam.runners.spark.io.hadoop.TemplatedAvroKeyOutputFormat; +import org.apache.beam.runners.spark.io.hadoop.TemplatedTextOutputFormat; +import org.apache.beam.runners.spark.util.BroadcastHelper; +import org.apache.beam.runners.spark.util.ByteArray; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.NullWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaRDDLike; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import scala.Tuple2; + +/** + * Supports translation between a DataFlow transform, and Spark's operations on RDDs. 
+ */ +public final class TransformTranslator { + + private TransformTranslator() { + } + + public static class FieldGetter { + private final Map<String, Field> fields; + + public FieldGetter(Class<?> clazz) { + this.fields = Maps.newHashMap(); + for (Field f : clazz.getDeclaredFields()) { + f.setAccessible(true); + this.fields.put(f.getName(), f); + } + } + + public <T> T get(String fieldname, Object value) { + try { + @SuppressWarnings("unchecked") + T fieldValue = (T) fields.get(fieldname).get(value); + return fieldValue; + } catch (IllegalAccessException e) { + throw new IllegalStateException(e); + } + } + } + + private static <T> TransformEvaluator<Flatten.FlattenPCollectionList<T>> flattenPColl() { + return new TransformEvaluator<Flatten.FlattenPCollectionList<T>>() { + @SuppressWarnings("unchecked") + @Override + public void evaluate(Flatten.FlattenPCollectionList<T> transform, EvaluationContext context) { + PCollectionList<T> pcs = context.getInput(transform); + JavaRDD<WindowedValue<T>>[] rdds = new JavaRDD[pcs.size()]; + for (int i = 0; i < rdds.length; i++) { + rdds[i] = (JavaRDD<WindowedValue<T>>) context.getRDD(pcs.get(i)); + } + JavaRDD<WindowedValue<T>> rdd = context.getSparkContext().union(rdds); + context.setOutputRDD(transform, rdd); + } + }; + } + + private static <K, V> TransformEvaluator<GroupByKey.GroupByKeyOnly<K, V>> gbk() { + return new TransformEvaluator<GroupByKey.GroupByKeyOnly<K, V>>() { + @Override + public void evaluate(GroupByKey.GroupByKeyOnly<K, V> transform, EvaluationContext context) { + @SuppressWarnings("unchecked") + JavaRDDLike<WindowedValue<KV<K, V>>, ?> inRDD = + (JavaRDDLike<WindowedValue<KV<K, V>>, ?>) context.getInputRDD(transform); + @SuppressWarnings("unchecked") + KvCoder<K, V> coder = (KvCoder<K, V>) context.getInput(transform).getCoder(); + Coder<K> keyCoder = coder.getKeyCoder(); + Coder<V> valueCoder = coder.getValueCoder(); + + // Use coders to convert objects in the PCollection to byte arrays, so they + // can be 
transferred over the network for the shuffle. + JavaRDDLike<WindowedValue<KV<K, Iterable<V>>>, ?> outRDD = fromPair( + toPair(inRDD.map(WindowingHelpers.<KV<K, V>>unwindowFunction())) + .mapToPair(CoderHelpers.toByteFunction(keyCoder, valueCoder)) + .groupByKey() + .mapToPair(CoderHelpers.fromByteFunctionIterable(keyCoder, valueCoder))) + // empty windows are OK here, see GroupByKey#evaluateHelper in the SDK + .map(WindowingHelpers.<KV<K, Iterable<V>>>windowFunction()); + context.setOutputRDD(transform, outRDD); + } + }; + } + + private static final FieldGetter GROUPED_FG = new FieldGetter(Combine.GroupedValues.class); + + private static <K, VI, VO> TransformEvaluator<Combine.GroupedValues<K, VI, VO>> grouped() { + return new TransformEvaluator<Combine.GroupedValues<K, VI, VO>>() { + @Override + public void evaluate(Combine.GroupedValues<K, VI, VO> transform, EvaluationContext context) { + Combine.KeyedCombineFn<K, VI, ?, VO> keyed = GROUPED_FG.get("fn", transform); + @SuppressWarnings("unchecked") + JavaRDDLike<WindowedValue<KV<K, Iterable<VI>>>, ?> inRDD = + (JavaRDDLike<WindowedValue<KV<K, Iterable<VI>>>, ?>) context.getInputRDD(transform); + context.setOutputRDD(transform, + inRDD.map(new KVFunction<>(keyed))); + } + }; + } + + private static final FieldGetter COMBINE_GLOBALLY_FG = new FieldGetter(Combine.Globally.class); + + private static <I, A, O> TransformEvaluator<Combine.Globally<I, O>> combineGlobally() { + return new TransformEvaluator<Combine.Globally<I, O>>() { + + @Override + public void evaluate(Combine.Globally<I, O> transform, EvaluationContext context) { + final Combine.CombineFn<I, A, O> globally = COMBINE_GLOBALLY_FG.get("fn", transform); + + @SuppressWarnings("unchecked") + JavaRDDLike<WindowedValue<I>, ?> inRdd = + (JavaRDDLike<WindowedValue<I>, ?>) context.getInputRDD(transform); + + final Coder<I> iCoder = context.getInput(transform).getCoder(); + final Coder<A> aCoder; + try { + aCoder = globally.getAccumulatorCoder( + 
context.getPipeline().getCoderRegistry(), iCoder); + } catch (CannotProvideCoderException e) { + throw new IllegalStateException("Could not determine coder for accumulator", e); + } + + // Use coders to convert objects in the PCollection to byte arrays, so they + // can be transferred over the network for the shuffle. + JavaRDD<byte[]> inRddBytes = inRdd + .map(WindowingHelpers.<I>unwindowFunction()) + .map(CoderHelpers.toByteFunction(iCoder)); + + /*A*/ byte[] acc = inRddBytes.aggregate( + CoderHelpers.toByteArray(globally.createAccumulator(), aCoder), + new Function2</*A*/ byte[], /*I*/ byte[], /*A*/ byte[]>() { + @Override + public /*A*/ byte[] call(/*A*/ byte[] ab, /*I*/ byte[] ib) throws Exception { + A a = CoderHelpers.fromByteArray(ab, aCoder); + I i = CoderHelpers.fromByteArray(ib, iCoder); + return CoderHelpers.toByteArray(globally.addInput(a, i), aCoder); + } + }, + new Function2</*A*/ byte[], /*A*/ byte[], /*A*/ byte[]>() { + @Override + public /*A*/ byte[] call(/*A*/ byte[] a1b, /*A*/ byte[] a2b) throws Exception { + A a1 = CoderHelpers.fromByteArray(a1b, aCoder); + A a2 = CoderHelpers.fromByteArray(a2b, aCoder); + // don't use Guava's ImmutableList.of as values may be null + List<A> accumulators = Collections.unmodifiableList(Arrays.asList(a1, a2)); + A merged = globally.mergeAccumulators(accumulators); + return CoderHelpers.toByteArray(merged, aCoder); + } + } + ); + O output = globally.extractOutput(CoderHelpers.fromByteArray(acc, aCoder)); + + Coder<O> coder = context.getOutput(transform).getCoder(); + JavaRDD<byte[]> outRdd = context.getSparkContext().parallelize( + // don't use Guava's ImmutableList.of as output may be null + CoderHelpers.toByteArrays(Collections.singleton(output), coder)); + context.setOutputRDD(transform, outRdd.map(CoderHelpers.fromByteFunction(coder)) + .map(WindowingHelpers.<O>windowFunction())); + } + }; + } + + private static final FieldGetter COMBINE_PERKEY_FG = new FieldGetter(Combine.PerKey.class); + + private static <K, 
VI, VA, VO> TransformEvaluator<Combine.PerKey<K, VI, VO>> combinePerKey() { + return new TransformEvaluator<Combine.PerKey<K, VI, VO>>() { + @Override + public void evaluate(Combine.PerKey<K, VI, VO> transform, EvaluationContext context) { + final Combine.KeyedCombineFn<K, VI, VA, VO> keyed = + COMBINE_PERKEY_FG.get("fn", transform); + @SuppressWarnings("unchecked") + JavaRDDLike<WindowedValue<KV<K, VI>>, ?> inRdd = + (JavaRDDLike<WindowedValue<KV<K, VI>>, ?>) context.getInputRDD(transform); + + @SuppressWarnings("unchecked") + KvCoder<K, VI> inputCoder = (KvCoder<K, VI>) context.getInput(transform).getCoder(); + Coder<K> keyCoder = inputCoder.getKeyCoder(); + Coder<VI> viCoder = inputCoder.getValueCoder(); + Coder<VA> vaCoder; + try { + vaCoder = keyed.getAccumulatorCoder( + context.getPipeline().getCoderRegistry(), keyCoder, viCoder); + } catch (CannotProvideCoderException e) { + throw new IllegalStateException("Could not determine coder for accumulator", e); + } + Coder<KV<K, VI>> kviCoder = KvCoder.of(keyCoder, viCoder); + Coder<KV<K, VA>> kvaCoder = KvCoder.of(keyCoder, vaCoder); + + // We need to duplicate K as both the key of the JavaPairRDD as well as inside the value, + // since the functions passed to combineByKey don't receive the associated key of each + // value, and we need to map back into methods in Combine.KeyedCombineFn, which each + // require the key in addition to the VI's and VA's being merged/accumulated. Once Spark + // provides a way to include keys in the arguments of combine/merge functions, we won't + // need to duplicate the keys anymore. 
+ + // Key has to bw windowed in order to group by window as well + JavaPairRDD<WindowedValue<K>, WindowedValue<KV<K, VI>>> inRddDuplicatedKeyPair = + inRdd.mapToPair( + new PairFunction<WindowedValue<KV<K, VI>>, WindowedValue<K>, + WindowedValue<KV<K, VI>>>() { + @Override + public Tuple2<WindowedValue<K>, + WindowedValue<KV<K, VI>>> call(WindowedValue<KV<K, VI>> kv) { + WindowedValue<K> wk = WindowedValue.of(kv.getValue().getKey(), + kv.getTimestamp(), kv.getWindows(), kv.getPane()); + return new Tuple2<>(wk, kv); + } + }); + //-- windowed coders + final WindowedValue.FullWindowedValueCoder<K> wkCoder = + WindowedValue.FullWindowedValueCoder.of(keyCoder, + context.getInput(transform).getWindowingStrategy().getWindowFn().windowCoder()); + final WindowedValue.FullWindowedValueCoder<KV<K, VI>> wkviCoder = + WindowedValue.FullWindowedValueCoder.of(kviCoder, + context.getInput(transform).getWindowingStrategy().getWindowFn().windowCoder()); + final WindowedValue.FullWindowedValueCoder<KV<K, VA>> wkvaCoder = + WindowedValue.FullWindowedValueCoder.of(kvaCoder, + context.getInput(transform).getWindowingStrategy().getWindowFn().windowCoder()); + + // Use coders to convert objects in the PCollection to byte arrays, so they + // can be transferred over the network for the shuffle. + JavaPairRDD<ByteArray, byte[]> inRddDuplicatedKeyPairBytes = inRddDuplicatedKeyPair + .mapToPair(CoderHelpers.toByteFunction(wkCoder, wkviCoder)); + + // The output of combineByKey will be "VA" (accumulator) types rather than "VO" (final + // output types) since Combine.CombineFn only provides ways to merge VAs, and no way + // to merge VOs. 
+ JavaPairRDD</*K*/ ByteArray, /*KV<K, VA>*/ byte[]> accumulatedBytes = + inRddDuplicatedKeyPairBytes.combineByKey( + new Function</*KV<K, VI>*/ byte[], /*KV<K, VA>*/ byte[]>() { + @Override + public /*KV<K, VA>*/ byte[] call(/*KV<K, VI>*/ byte[] input) { + WindowedValue<KV<K, VI>> wkvi = CoderHelpers.fromByteArray(input, wkviCoder); + VA va = keyed.createAccumulator(wkvi.getValue().getKey()); + va = keyed.addInput(wkvi.getValue().getKey(), va, wkvi.getValue().getValue()); + WindowedValue<KV<K, VA>> wkva = + WindowedValue.of(KV.of(wkvi.getValue().getKey(), va), wkvi.getTimestamp(), + wkvi.getWindows(), wkvi.getPane()); + return CoderHelpers.toByteArray(wkva, wkvaCoder); + } + }, + new Function2</*KV<K, VA>*/ byte[], /*KV<K, VI>*/ byte[], /*KV<K, VA>*/ byte[]>() { + @Override + public /*KV<K, VA>*/ byte[] call(/*KV<K, VA>*/ byte[] acc, + /*KV<K, VI>*/ byte[] input) { + WindowedValue<KV<K, VA>> wkva = CoderHelpers.fromByteArray(acc, wkvaCoder); + WindowedValue<KV<K, VI>> wkvi = CoderHelpers.fromByteArray(input, wkviCoder); + VA va = keyed.addInput(wkva.getValue().getKey(), wkva.getValue().getValue(), + wkvi.getValue().getValue()); + wkva = WindowedValue.of(KV.of(wkva.getValue().getKey(), va), wkva.getTimestamp(), + wkva.getWindows(), wkva.getPane()); + return CoderHelpers.toByteArray(wkva, wkvaCoder); + } + }, + new Function2</*KV<K, VA>*/ byte[], /*KV<K, VA>*/ byte[], /*KV<K, VA>*/ byte[]>() { + @Override + public /*KV<K, VA>*/ byte[] call(/*KV<K, VA>*/ byte[] acc1, + /*KV<K, VA>*/ byte[] acc2) { + WindowedValue<KV<K, VA>> wkva1 = CoderHelpers.fromByteArray(acc1, wkvaCoder); + WindowedValue<KV<K, VA>> wkva2 = CoderHelpers.fromByteArray(acc2, wkvaCoder); + VA va = keyed.mergeAccumulators(wkva1.getValue().getKey(), + // don't use Guava's ImmutableList.of as values may be null + Collections.unmodifiableList(Arrays.asList(wkva1.getValue().getValue(), + wkva2.getValue().getValue()))); + WindowedValue<KV<K, VA>> wkva = WindowedValue.of(KV.of(wkva1.getValue().getKey(), + 
va), wkva1.getTimestamp(), wkva1.getWindows(), wkva1.getPane()); + return CoderHelpers.toByteArray(wkva, wkvaCoder); + } + }); + + JavaPairRDD<WindowedValue<K>, WindowedValue<VO>> extracted = accumulatedBytes + .mapToPair(CoderHelpers.fromByteFunction(wkCoder, wkvaCoder)) + .mapValues( + new Function<WindowedValue<KV<K, VA>>, WindowedValue<VO>>() { + @Override + public WindowedValue<VO> call(WindowedValue<KV<K, VA>> acc) { + return WindowedValue.of(keyed.extractOutput(acc.getValue().getKey(), + acc.getValue().getValue()), acc.getTimestamp(), + acc.getWindows(), acc.getPane()); + } + }); + + context.setOutputRDD(transform, + fromPair(extracted) + .map(new Function<KV<WindowedValue<K>, WindowedValue<VO>>, WindowedValue<KV<K, VO>>>() { + @Override + public WindowedValue<KV<K, VO>> call(KV<WindowedValue<K>, WindowedValue<VO>> kwvo) + throws Exception { + WindowedValue<VO> wvo = kwvo.getValue(); + KV<K, VO> kvo = KV.of(kwvo.getKey().getValue(), wvo.getValue()); + return WindowedValue.of(kvo, wvo.getTimestamp(), wvo.getWindows(), wvo.getPane()); + } + })); + } + }; + } + + private static final class KVFunction<K, VI, VO> + implements Function<WindowedValue<KV<K, Iterable<VI>>>, WindowedValue<KV<K, VO>>> { + private final Combine.KeyedCombineFn<K, VI, ?, VO> keyed; + + KVFunction(Combine.KeyedCombineFn<K, VI, ?, VO> keyed) { + this.keyed = keyed; + } + + @Override + public WindowedValue<KV<K, VO>> call(WindowedValue<KV<K, Iterable<VI>>> windowedKv) + throws Exception { + KV<K, Iterable<VI>> kv = windowedKv.getValue(); + return WindowedValue.of(KV.of(kv.getKey(), keyed.apply(kv.getKey(), kv.getValue())), + windowedKv.getTimestamp(), windowedKv.getWindows(), windowedKv.getPane()); + } + } + + private static <K, V> JavaPairRDD<K, V> toPair(JavaRDDLike<KV<K, V>, ?> rdd) { + return rdd.mapToPair(new PairFunction<KV<K, V>, K, V>() { + @Override + public Tuple2<K, V> call(KV<K, V> kv) { + return new Tuple2<>(kv.getKey(), kv.getValue()); + } + }); + } + + private static <K, V> 
JavaRDDLike<KV<K, V>, ?> fromPair(JavaPairRDD<K, V> rdd) { + return rdd.map(new Function<Tuple2<K, V>, KV<K, V>>() { + @Override + public KV<K, V> call(Tuple2<K, V> t2) { + return KV.of(t2._1(), t2._2()); + } + }); + } + + private static <I, O> TransformEvaluator<ParDo.Bound<I, O>> parDo() { + return new TransformEvaluator<ParDo.Bound<I, O>>() { + @Override + public void evaluate(ParDo.Bound<I, O> transform, EvaluationContext context) { + DoFnFunction<I, O> dofn = + new DoFnFunction<>(transform.getFn(), + context.getRuntimeContext(), + getSideInputs(transform.getSideInputs(), context)); + @SuppressWarnings("unchecked") + JavaRDDLike<WindowedValue<I>, ?> inRDD = + (JavaRDDLike<WindowedValue<I>, ?>) context.getInputRDD(transform); + context.setOutputRDD(transform, inRDD.mapPartitions(dofn)); + } + }; + } + + private static final FieldGetter MULTIDO_FG = new FieldGetter(ParDo.BoundMulti.class); + + private static <I, O> TransformEvaluator<ParDo.BoundMulti<I, O>> multiDo() { + return new TransformEvaluator<ParDo.BoundMulti<I, O>>() { + @Override + public void evaluate(ParDo.BoundMulti<I, O> transform, EvaluationContext context) { + TupleTag<O> mainOutputTag = MULTIDO_FG.get("mainOutputTag", transform); + MultiDoFnFunction<I, O> multifn = new MultiDoFnFunction<>( + transform.getFn(), + context.getRuntimeContext(), + mainOutputTag, + getSideInputs(transform.getSideInputs(), context)); + + @SuppressWarnings("unchecked") + JavaRDDLike<WindowedValue<I>, ?> inRDD = + (JavaRDDLike<WindowedValue<I>, ?>) context.getInputRDD(transform); + JavaPairRDD<TupleTag<?>, WindowedValue<?>> all = inRDD + .mapPartitionsToPair(multifn) + .cache(); + + PCollectionTuple pct = context.getOutput(transform); + for (Map.Entry<TupleTag<?>, PCollection<?>> e : pct.getAll().entrySet()) { + @SuppressWarnings("unchecked") + JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered = + all.filter(new TupleTagFilter(e.getKey())); + @SuppressWarnings("unchecked") + // Object is the best we can do since 
different outputs can have different tags + JavaRDD<WindowedValue<Object>> values = + (JavaRDD<WindowedValue<Object>>) (JavaRDD<?>) filtered.values(); + context.setRDD(e.getValue(), values); + } + } + }; + } + + + private static <T> TransformEvaluator<TextIO.Read.Bound<T>> readText() { + return new TransformEvaluator<TextIO.Read.Bound<T>>() { + @Override + public void evaluate(TextIO.Read.Bound<T> transform, EvaluationContext context) { + String pattern = transform.getFilepattern(); + JavaRDD<WindowedValue<String>> rdd = context.getSparkContext().textFile(pattern) + .map(WindowingHelpers.<String>windowFunction()); + context.setOutputRDD(transform, rdd); + } + }; + } + + private static <T> TransformEvaluator<TextIO.Write.Bound<T>> writeText() { + return new TransformEvaluator<TextIO.Write.Bound<T>>() { + @Override + public void evaluate(TextIO.Write.Bound<T> transform, EvaluationContext context) { + @SuppressWarnings("unchecked") + JavaPairRDD<T, Void> last = + ((JavaRDDLike<WindowedValue<T>, ?>) context.getInputRDD(transform)) + .map(WindowingHelpers.<T>unwindowFunction()) + .mapToPair(new PairFunction<T, T, + Void>() { + @Override + public Tuple2<T, Void> call(T t) throws Exception { + return new Tuple2<>(t, null); + } + }); + ShardTemplateInformation shardTemplateInfo = + new ShardTemplateInformation(transform.getNumShards(), + transform.getShardTemplate(), transform.getFilenamePrefix(), + transform.getFilenameSuffix()); + writeHadoopFile(last, new Configuration(), shardTemplateInfo, Text.class, + NullWritable.class, TemplatedTextOutputFormat.class); + } + }; + } + + private static <T> TransformEvaluator<AvroIO.Read.Bound<T>> readAvro() { + return new TransformEvaluator<AvroIO.Read.Bound<T>>() { + @Override + public void evaluate(AvroIO.Read.Bound<T> transform, EvaluationContext context) { + String pattern = transform.getFilepattern(); + JavaSparkContext jsc = context.getSparkContext(); + @SuppressWarnings("unchecked") + JavaRDD<AvroKey<T>> avroFile = 
(JavaRDD<AvroKey<T>>) (JavaRDD<?>) + jsc.newAPIHadoopFile(pattern, + AvroKeyInputFormat.class, + AvroKey.class, NullWritable.class, + new Configuration()).keys(); + JavaRDD<WindowedValue<T>> rdd = avroFile.map( + new Function<AvroKey<T>, T>() { + @Override + public T call(AvroKey<T> key) { + return key.datum(); + } + }).map(WindowingHelpers.<T>windowFunction()); + context.setOutputRDD(transform, rdd); + } + }; + } + + private static <T> TransformEvaluator<AvroIO.Write.Bound<T>> writeAvro() { + return new TransformEvaluator<AvroIO.Write.Bound<T>>() { + @Override + public void evaluate(AvroIO.Write.Bound<T> transform, EvaluationContext context) { + Job job; + try { + job = Job.getInstance(); + } catch (IOException e) { + throw new IllegalStateException(e); + } + AvroJob.setOutputKeySchema(job, transform.getSchema()); + @SuppressWarnings("unchecked") + JavaPairRDD<AvroKey<T>, NullWritable> last = + ((JavaRDDLike<WindowedValue<T>, ?>) context.getInputRDD(transform)) + .map(WindowingHelpers.<T>unwindowFunction()) + .mapToPair(new PairFunction<T, AvroKey<T>, NullWritable>() { + @Override + public Tuple2<AvroKey<T>, NullWritable> call(T t) throws Exception { + return new Tuple2<>(new AvroKey<>(t), NullWritable.get()); + } + }); + ShardTemplateInformation shardTemplateInfo = + new ShardTemplateInformation(transform.getNumShards(), + transform.getShardTemplate(), transform.getFilenamePrefix(), + transform.getFilenameSuffix()); + writeHadoopFile(last, job.getConfiguration(), shardTemplateInfo, + AvroKey.class, NullWritable.class, TemplatedAvroKeyOutputFormat.class); + } + }; + } + + private static <K, V> TransformEvaluator<HadoopIO.Read.Bound<K, V>> readHadoop() { + return new TransformEvaluator<HadoopIO.Read.Bound<K, V>>() { + @Override + public void evaluate(HadoopIO.Read.Bound<K, V> transform, EvaluationContext context) { + String pattern = transform.getFilepattern(); + JavaSparkContext jsc = context.getSparkContext(); + @SuppressWarnings ("unchecked") + JavaPairRDD<K, V> 
file = jsc.newAPIHadoopFile(pattern, + transform.getFormatClass(), + transform.getKeyClass(), transform.getValueClass(), + new Configuration()); + JavaRDD<WindowedValue<KV<K, V>>> rdd = + file.map(new Function<Tuple2<K, V>, KV<K, V>>() { + @Override + public KV<K, V> call(Tuple2<K, V> t2) throws Exception { + return KV.of(t2._1(), t2._2()); + } + }).map(WindowingHelpers.<KV<K, V>>windowFunction()); + context.setOutputRDD(transform, rdd); + } + }; + } + + private static <K, V> TransformEvaluator<HadoopIO.Write.Bound<K, V>> writeHadoop() { + return new TransformEvaluator<HadoopIO.Write.Bound<K, V>>() { + @Override + public void evaluate(HadoopIO.Write.Bound<K, V> transform, EvaluationContext context) { + @SuppressWarnings("unchecked") + JavaPairRDD<K, V> last = ((JavaRDDLike<WindowedValue<KV<K, V>>, ?>) context + .getInputRDD(transform)) + .map(WindowingHelpers.<KV<K, V>>unwindowFunction()) + .mapToPair(new PairFunction<KV<K, V>, K, V>() { + @Override + public Tuple2<K, V> call(KV<K, V> t) throws Exception { + return new Tuple2<>(t.getKey(), t.getValue()); + } + }); + ShardTemplateInformation shardTemplateInfo = + new ShardTemplateInformation(transform.getNumShards(), + transform.getShardTemplate(), transform.getFilenamePrefix(), + transform.getFilenameSuffix()); + Configuration conf = new Configuration(); + for (Map.Entry<String, String> e : transform.getConfigurationProperties().entrySet()) { + conf.set(e.getKey(), e.getValue()); + } + writeHadoopFile(last, conf, shardTemplateInfo, + transform.getKeyClass(), transform.getValueClass(), transform.getFormatClass()); + } + }; + } + + private static final class ShardTemplateInformation { + private final int numShards; + private final String shardTemplate; + private final String filenamePrefix; + private final String filenameSuffix; + + private ShardTemplateInformation(int numShards, String shardTemplate, String + filenamePrefix, String filenameSuffix) { + this.numShards = numShards; + this.shardTemplate = shardTemplate; 
+ this.filenamePrefix = filenamePrefix; + this.filenameSuffix = filenameSuffix; + } + + int getNumShards() { + return numShards; + } + + String getShardTemplate() { + return shardTemplate; + } + + String getFilenamePrefix() { + return filenamePrefix; + } + + String getFilenameSuffix() { + return filenameSuffix; + } + } + + private static <K, V> void writeHadoopFile(JavaPairRDD<K, V> rdd, Configuration conf, + ShardTemplateInformation shardTemplateInfo, Class<?> keyClass, Class<?> valueClass, + Class<? extends FileOutputFormat> formatClass) { + int numShards = shardTemplateInfo.getNumShards(); + String shardTemplate = shardTemplateInfo.getShardTemplate(); + String filenamePrefix = shardTemplateInfo.getFilenamePrefix(); + String filenameSuffix = shardTemplateInfo.getFilenameSuffix(); + if (numShards != 0) { + // number of shards was set explicitly, so repartition + rdd = rdd.repartition(numShards); + } + int actualNumShards = rdd.partitions().size(); + String template = replaceShardCount(shardTemplate, actualNumShards); + String outputDir = getOutputDirectory(filenamePrefix, template); + String filePrefix = getOutputFilePrefix(filenamePrefix, template); + String fileTemplate = getOutputFileTemplate(filenamePrefix, template); + + conf.set(ShardNameTemplateHelper.OUTPUT_FILE_PREFIX, filePrefix); + conf.set(ShardNameTemplateHelper.OUTPUT_FILE_TEMPLATE, fileTemplate); + conf.set(ShardNameTemplateHelper.OUTPUT_FILE_SUFFIX, filenameSuffix); + rdd.saveAsNewAPIHadoopFile(outputDir, keyClass, valueClass, formatClass, conf); + } + + private static final FieldGetter WINDOW_FG = new FieldGetter(Window.Bound.class); + + private static <T, W extends BoundedWindow> TransformEvaluator<Window.Bound<T>> window() { + return new TransformEvaluator<Window.Bound<T>>() { + @Override + public void evaluate(Window.Bound<T> transform, EvaluationContext context) { + @SuppressWarnings("unchecked") + JavaRDDLike<WindowedValue<T>, ?> inRDD = + (JavaRDDLike<WindowedValue<T>, ?>) 
context.getInputRDD(transform); + WindowFn<? super T, W> windowFn = WINDOW_FG.get("windowFn", transform); + if (windowFn instanceof GlobalWindows) { + context.setOutputRDD(transform, inRDD); + } else { + @SuppressWarnings("unchecked") + DoFn<T, T> addWindowsDoFn = new AssignWindowsDoFn<>(windowFn); + DoFnFunction<T, T> dofn = + new DoFnFunction<>(addWindowsDoFn, context.getRuntimeContext(), null); + context.setOutputRDD(transform, inRDD.mapPartitions(dofn)); + } + } + }; + } + + private static <T> TransformEvaluator<Create.Values<T>> create() { + return new TransformEvaluator<Create.Values<T>>() { + @Override + public void evaluate(Create.Values<T> transform, EvaluationContext context) { + Iterable<T> elems = transform.getElements(); + // Use a coder to convert the objects in the PCollection to byte arrays, so they + // can be transferred over the network. + Coder<T> coder = context.getOutput(transform).getCoder(); + context.setOutputRDDFromValues(transform, elems, coder); + } + }; + } + + private static <T> TransformEvaluator<View.AsSingleton<T>> viewAsSingleton() { + return new TransformEvaluator<View.AsSingleton<T>>() { + @Override + public void evaluate(View.AsSingleton<T> transform, EvaluationContext context) { + Iterable<? extends WindowedValue<?>> iter = + context.getWindowedValues(context.getInput(transform)); + context.setPView(context.getOutput(transform), iter); + } + }; + } + + private static <T> TransformEvaluator<View.AsIterable<T>> viewAsIter() { + return new TransformEvaluator<View.AsIterable<T>>() { + @Override + public void evaluate(View.AsIterable<T> transform, EvaluationContext context) { + Iterable<? 
extends WindowedValue<?>> iter = + context.getWindowedValues(context.getInput(transform)); + context.setPView(context.getOutput(transform), iter); + } + }; + } + + private static <R, W> TransformEvaluator<View.CreatePCollectionView<R, W>> createPCollView() { + return new TransformEvaluator<View.CreatePCollectionView<R, W>>() { + @Override + public void evaluate(View.CreatePCollectionView<R, W> transform, EvaluationContext context) { + Iterable<? extends WindowedValue<?>> iter = + context.getWindowedValues(context.getInput(transform)); + context.setPView(context.getOutput(transform), iter); + } + }; + } + + private static final class TupleTagFilter<V> + implements Function<Tuple2<TupleTag<V>, WindowedValue<?>>, Boolean> { + + private final TupleTag<V> tag; + + private TupleTagFilter(TupleTag<V> tag) { + this.tag = tag; + } + + @Override + public Boolean call(Tuple2<TupleTag<V>, WindowedValue<?>> input) { + return tag.equals(input._1()); + } + } + + private static Map<TupleTag<?>, BroadcastHelper<?>> getSideInputs( + List<PCollectionView<?>> views, + EvaluationContext context) { + if (views == null) { + return ImmutableMap.of(); + } else { + Map<TupleTag<?>, BroadcastHelper<?>> sideInputs = Maps.newHashMap(); + for (PCollectionView<?> view : views) { + Iterable<? extends WindowedValue<?>> collectionView = context.getPCollectionView(view); + Coder<Iterable<WindowedValue<?>>> coderInternal = view.getCoderInternal(); + @SuppressWarnings("unchecked") + BroadcastHelper<?> helper = + BroadcastHelper.create((Iterable<WindowedValue<?>>) collectionView, coderInternal); + //broadcast side inputs + helper.broadcast(context.getSparkContext()); + sideInputs.put(view.getTagInternal(), helper); + } + return sideInputs; + } + } + + private static final Map<Class<? 
extends PTransform>, TransformEvaluator<?>> EVALUATORS = Maps + .newHashMap(); + + static { + EVALUATORS.put(TextIO.Read.Bound.class, readText()); + EVALUATORS.put(TextIO.Write.Bound.class, writeText()); + EVALUATORS.put(AvroIO.Read.Bound.class, readAvro()); + EVALUATORS.put(AvroIO.Write.Bound.class, writeAvro()); + EVALUATORS.put(HadoopIO.Read.Bound.class, readHadoop()); + EVALUATORS.put(HadoopIO.Write.Bound.class, writeHadoop()); + EVALUATORS.put(ParDo.Bound.class, parDo()); + EVALUATORS.put(ParDo.BoundMulti.class, multiDo()); + EVALUATORS.put(GroupByKey.GroupByKeyOnly.class, gbk()); + EVALUATORS.put(Combine.GroupedValues.class, grouped()); + EVALUATORS.put(Combine.Globally.class, combineGlobally()); + EVALUATORS.put(Combine.PerKey.class, combinePerKey()); + EVALUATORS.put(Flatten.FlattenPCollectionList.class, flattenPColl()); + EVALUATORS.put(Create.Values.class, create()); + EVALUATORS.put(View.AsSingleton.class, viewAsSingleton()); + EVALUATORS.put(View.AsIterable.class, viewAsIter()); + EVALUATORS.put(View.CreatePCollectionView.class, createPCollView()); + EVALUATORS.put(Window.Bound.class, window()); + } + + public static <PT extends PTransform<?, ?>> TransformEvaluator<PT> + getTransformEvaluator(Class<PT> clazz) { + @SuppressWarnings("unchecked") + TransformEvaluator<PT> transform = (TransformEvaluator<PT>) EVALUATORS.get(clazz); + if (transform == null) { + throw new IllegalStateException("No TransformEvaluator registered for " + clazz); + } + return transform; + } + + /** + * Translator matches Dataflow transformation with the appropriate evaluator. + */ + public static class Translator implements SparkPipelineTranslator { + + @Override + public boolean hasTranslation(Class<? extends PTransform<?, ?>> clazz) { + return EVALUATORS.containsKey(clazz); + } + + @Override + public <PT extends PTransform<?, ?>> TransformEvaluator<PT> translate(Class<PT> clazz) { + return getTransformEvaluator(clazz); + } + } +}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/WindowingHelpers.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/WindowingHelpers.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/WindowingHelpers.java new file mode 100644 index 0000000..6b904f7 --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/WindowingHelpers.java @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. + * + * Cloudera, Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"). You may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for + * the specific language governing permissions and limitations under the + * License. + */ + +package org.apache.beam.runners.spark; + +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import org.apache.spark.api.java.function.Function; + +/** + * Helper functions for working with windows. + */ +public final class WindowingHelpers { + private WindowingHelpers() { + } + + /** + * A function for converting a value to a {@link WindowedValue}. The resulting + * {@link WindowedValue} will be in no windows, and will have the default timestamp + * and pane. + * + * @param <T> The type of the object. + * @return A function that accepts an object and returns its {@link WindowedValue}. 
+ */ + public static <T> Function<T, WindowedValue<T>> windowFunction() { + return new Function<T, WindowedValue<T>>() { + @Override + public WindowedValue<T> call(T t) { + return WindowedValue.valueInEmptyWindows(t); + } + }; + } + + /** + * A function for extracting the value from a {@link WindowedValue}. + * + * @param <T> The type of the object. + * @return A function that accepts a {@link WindowedValue} and returns its value. + */ + public static <T> Function<WindowedValue<T>, T> unwindowFunction() { + return new Function<WindowedValue<T>, T>() { + @Override + public T call(WindowedValue<T> t) { + return t.getValue(); + } + }; + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/AggAccumParam.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/AggAccumParam.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/AggAccumParam.java new file mode 100644 index 0000000..a82dbbe --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/AggAccumParam.java @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved. + * + * Cloudera, Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"). You may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for + * the specific language governing permissions and limitations under the + * License. 
+ */ + +package org.apache.beam.runners.spark.aggregators; + +import org.apache.spark.AccumulatorParam; + +public class AggAccumParam implements AccumulatorParam<NamedAggregators> { + @Override + public NamedAggregators addAccumulator(NamedAggregators current, NamedAggregators added) { + return current.merge(added); + } + + @Override + public NamedAggregators addInPlace(NamedAggregators current, NamedAggregators added) { + return addAccumulator(current, added); + } + + @Override + public NamedAggregators zero(NamedAggregators initialValue) { + return new NamedAggregators(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/NamedAggregators.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/NamedAggregators.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/NamedAggregators.java new file mode 100644 index 0000000..2747703 --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/NamedAggregators.java @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved. + * + * Cloudera, Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"). You may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for + * the specific language governing permissions and limitations under the + * License. 
+ */ + +package org.apache.beam.runners.spark.aggregators; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.util.Map; +import java.util.TreeMap; + +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.common.collect.ImmutableList; + +import org.apache.beam.runners.spark.SparkRuntimeContext; + +/** + * This class wraps a map of named aggregators. Spark expects that all accumulators be declared + * before a job is launched. Dataflow allows aggregators to be used and incremented on the fly. + * We create a map of named aggregators and instantiate in the the spark context before the job + * is launched. We can then add aggregators on the fly in Spark. + */ +public class NamedAggregators implements Serializable { + /** + * Map from aggregator name to current state. + */ + private final Map<String, State<?, ?, ?>> mNamedAggregators = new TreeMap<>(); + + /** + * Constructs a new NamedAggregators instance. + */ + public NamedAggregators() { + } + + /** + * Constructs a new named aggregators instance that contains a mapping from the specified + * `named` to the associated initial state. + * + * @param name Name of aggregator. + * @param state Associated State. + */ + public NamedAggregators(String name, State<?, ?, ?> state) { + this.mNamedAggregators.put(name, state); + } + + /** + * @param name Name of aggregator to retrieve. + * @param typeClass Type class to cast the value to. + * @param <T> Type to be returned. + * @return the value of the aggregator associated with the specified name + */ + public <T> T getValue(String name, Class<T> typeClass) { + return typeClass.cast(mNamedAggregators.get(name).render()); + } + + /** + * Merges another NamedAggregators instance with this instance. 
+ * + * @param other The other instance of named aggregators ot merge. + * @return This instance of Named aggregators with associated states updated to reflect the + * other instance's aggregators. + */ + public NamedAggregators merge(NamedAggregators other) { + for (Map.Entry<String, State<?, ?, ?>> e : other.mNamedAggregators.entrySet()) { + String key = e.getKey(); + State<?, ?, ?> otherValue = e.getValue(); + State<?, ?, ?> value = mNamedAggregators.get(key); + if (value == null) { + mNamedAggregators.put(key, otherValue); + } else { + mNamedAggregators.put(key, merge(value, otherValue)); + } + } + return this; + } + + /** + * Helper method to merge States whose generic types aren't provably the same, + * so require some casting. + */ + @SuppressWarnings("unchecked") + private static <A, B, C> State<A, B, C> merge(State<?, ?, ?> s1, State<?, ?, ?> s2) { + return ((State<A, B, C>) s1).merge((State<A, B, C>) s2); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + for (Map.Entry<String, State<?, ?, ?>> e : mNamedAggregators.entrySet()) { + sb.append(e.getKey()).append(": ").append(e.getValue().render()); + } + return sb.toString(); + } + + /** + * @param <IN> Input data type + * @param <INTER> Intermediate data type (useful for averages) + * @param <OUT> Output data type + */ + public interface State<IN, INTER, OUT> extends Serializable { + /** + * @param element new element to update state + */ + void update(IN element); + + State<IN, INTER, OUT> merge(State<IN, INTER, OUT> other); + + INTER current(); + + OUT render(); + + Combine.CombineFn<IN, INTER, OUT> getCombineFn(); + } + + /** + * => combineFunction in data flow. 
+ */ + public static class CombineFunctionState<IN, INTER, OUT> implements State<IN, INTER, OUT> { + + private Combine.CombineFn<IN, INTER, OUT> combineFn; + private Coder<IN> inCoder; + private SparkRuntimeContext ctxt; + private transient INTER state; + + public CombineFunctionState( + Combine.CombineFn<IN, INTER, OUT> combineFn, + Coder<IN> inCoder, + SparkRuntimeContext ctxt) { + this.combineFn = combineFn; + this.inCoder = inCoder; + this.ctxt = ctxt; + this.state = combineFn.createAccumulator(); + } + + @Override + public void update(IN element) { + combineFn.addInput(state, element); + } + + @Override + public State<IN, INTER, OUT> merge(State<IN, INTER, OUT> other) { + this.state = combineFn.mergeAccumulators(ImmutableList.of(current(), other.current())); + return this; + } + + @Override + public INTER current() { + return state; + } + + @Override + public OUT render() { + return combineFn.extractOutput(state); + } + + @Override + public Combine.CombineFn<IN, INTER, OUT> getCombineFn() { + return combineFn; + } + + private void writeObject(ObjectOutputStream oos) throws IOException { + oos.writeObject(ctxt); + oos.writeObject(combineFn); + oos.writeObject(inCoder); + try { + combineFn.getAccumulatorCoder(ctxt.getCoderRegistry(), inCoder) + .encode(state, oos, Coder.Context.NESTED); + } catch (CannotProvideCoderException e) { + throw new IllegalStateException("Could not determine coder for accumulator", e); + } + } + + @SuppressWarnings("unchecked") + private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException { + ctxt = (SparkRuntimeContext) ois.readObject(); + combineFn = (Combine.CombineFn<IN, INTER, OUT>) ois.readObject(); + inCoder = (Coder<IN>) ois.readObject(); + try { + state = combineFn.getAccumulatorCoder(ctxt.getCoderRegistry(), inCoder) + .decode(ois, Coder.Context.NESTED); + } catch (CannotProvideCoderException e) { + throw new IllegalStateException("Could not determine coder for accumulator", e); + } + } + } + +} 
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/CoderHelpers.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/CoderHelpers.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/CoderHelpers.java new file mode 100644 index 0000000..7d75e7d --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/CoderHelpers.java @@ -0,0 +1,186 @@ +/* + * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved. + * + * Cloudera, Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"). You may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for + * the specific language governing permissions and limitations under the + * License. + */ + +package org.apache.beam.runners.spark.coders; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.common.collect.Iterables; +import org.apache.beam.runners.spark.util.ByteArray; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import scala.Tuple2; + +/** + * Serialization utility class. + */ +public final class CoderHelpers { + private CoderHelpers() { + } + + /** + * Utility method for serializing an object using the specified coder. + * + * @param value Value to serialize. + * @param coder Coder to serialize with. 
+ * @param <T> type of value that is serialized + * @return Byte array representing serialized object. + */ + public static <T> byte[] toByteArray(T value, Coder<T> coder) { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try { + coder.encode(value, baos, new Coder.Context(true)); + } catch (IOException e) { + throw new IllegalStateException("Error encoding value: " + value, e); + } + return baos.toByteArray(); + } + + /** + * Utility method for serializing a Iterable of values using the specified coder. + * + * @param values Values to serialize. + * @param coder Coder to serialize with. + * @param <T> type of value that is serialized + * @return List of bytes representing serialized objects. + */ + public static <T> List<byte[]> toByteArrays(Iterable<T> values, Coder<T> coder) { + List<byte[]> res = new LinkedList<>(); + for (T value : values) { + res.add(toByteArray(value, coder)); + } + return res; + } + + /** + * Utility method for deserializing a byte array using the specified coder. + * + * @param serialized bytearray to be deserialized. + * @param coder Coder to deserialize with. + * @param <T> Type of object to be returned. + * @return Deserialized object. + */ + public static <T> T fromByteArray(byte[] serialized, Coder<T> coder) { + ByteArrayInputStream bais = new ByteArrayInputStream(serialized); + try { + return coder.decode(bais, new Coder.Context(true)); + } catch (IOException e) { + throw new IllegalStateException("Error decoding bytes for coder: " + coder, e); + } + } + + /** + * A function wrapper for converting an object to a bytearray. + * + * @param coder Coder to serialize with. + * @param <T> The type of the object being serialized. + * @return A function that accepts an object and returns its coder-serialized form. 
+ */ + public static <T> Function<T, byte[]> toByteFunction(final Coder<T> coder) { + return new Function<T, byte[]>() { + @Override + public byte[] call(T t) throws Exception { + return toByteArray(t, coder); + } + }; + } + + /** + * A function wrapper for converting a byte array to an object. + * + * @param coder Coder to deserialize with. + * @param <T> The type of the object being deserialized. + * @return A function that accepts a byte array and returns its corresponding object. + */ + public static <T> Function<byte[], T> fromByteFunction(final Coder<T> coder) { + return new Function<byte[], T>() { + @Override + public T call(byte[] bytes) throws Exception { + return fromByteArray(bytes, coder); + } + }; + } + + /** + * A function wrapper for converting a key-value pair to a byte array pair. + * + * @param keyCoder Coder to serialize keys. + * @param valueCoder Coder to serialize values. + * @param <K> The type of the key being serialized. + * @param <V> The type of the value being serialized. + * @return A function that accepts a key-value pair and returns a pair of byte arrays. + */ + public static <K, V> PairFunction<Tuple2<K, V>, ByteArray, byte[]> toByteFunction( + final Coder<K> keyCoder, final Coder<V> valueCoder) { + return new PairFunction<Tuple2<K, V>, ByteArray, byte[]>() { + @Override + public Tuple2<ByteArray, byte[]> call(Tuple2<K, V> kv) { + return new Tuple2<>(new ByteArray(toByteArray(kv._1(), keyCoder)), toByteArray(kv._2(), + valueCoder)); + } + }; + } + + /** + * A function wrapper for converting a byte array pair to a key-value pair. + * + * @param keyCoder Coder to deserialize keys. + * @param valueCoder Coder to deserialize values. + * @param <K> The type of the key being deserialized. + * @param <V> The type of the value being deserialized. + * @return A function that accepts a pair of byte arrays and returns a key-value pair. 
+ */ + public static <K, V> PairFunction<Tuple2<ByteArray, byte[]>, K, V> fromByteFunction( + final Coder<K> keyCoder, final Coder<V> valueCoder) { + return new PairFunction<Tuple2<ByteArray, byte[]>, K, V>() { + @Override + public Tuple2<K, V> call(Tuple2<ByteArray, byte[]> tuple) { + return new Tuple2<>(fromByteArray(tuple._1().getValue(), keyCoder), + fromByteArray(tuple._2(), valueCoder)); + } + }; + } + + /** + * A function wrapper for converting a byte array pair to a key-value pair, where + * values are {@link Iterable}. + * + * @param keyCoder Coder to deserialize keys. + * @param valueCoder Coder to deserialize values. + * @param <K> The type of the key being deserialized. + * @param <V> The type of the value being deserialized. + * @return A function that accepts a pair of byte arrays and returns a key-value pair. + */ + public static <K, V> PairFunction<Tuple2<ByteArray, Iterable<byte[]>>, K, Iterable<V>> + fromByteFunctionIterable(final Coder<K> keyCoder, final Coder<V> valueCoder) { + return new PairFunction<Tuple2<ByteArray, Iterable<byte[]>>, K, Iterable<V>>() { + @Override + public Tuple2<K, Iterable<V>> call(Tuple2<ByteArray, Iterable<byte[]>> tuple) { + return new Tuple2<>(fromByteArray(tuple._1().getValue(), keyCoder), + Iterables.transform(tuple._2(), new com.google.common.base.Function<byte[], V>() { + @Override + public V apply(byte[] bytes) { + return fromByteArray(bytes, valueCoder); + } + })); + } + }; + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/NullWritableCoder.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/NullWritableCoder.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/NullWritableCoder.java new file mode 100644 index 0000000..5b77e97 --- /dev/null +++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/NullWritableCoder.java @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. + * + * Cloudera, Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"). You may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for + * the specific language governing permissions and limitations under the + * License. + */ + +package org.apache.beam.runners.spark.coders; + +import java.io.InputStream; +import java.io.OutputStream; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.google.cloud.dataflow.sdk.coders.Coder; +import org.apache.hadoop.io.NullWritable; + +public final class NullWritableCoder extends WritableCoder<NullWritable> { + private static final long serialVersionUID = 1L; + + @JsonCreator + public static NullWritableCoder of() { + return INSTANCE; + } + + private static final NullWritableCoder INSTANCE = new NullWritableCoder(); + + private NullWritableCoder() { + super(NullWritable.class); + } + + @Override + public void encode(NullWritable value, OutputStream outStream, Context context) { + // nothing to write + } + + @Override + public NullWritable decode(InputStream inStream, Context context) { + return NullWritable.get(); + } + + @Override + public boolean consistentWithEquals() { + return true; + } + + /** + * Returns true since registerByteSizeObserver() runs in constant time. 
+ */ + @Override + public boolean isRegisterByteSizeObserverCheap(NullWritable value, Context context) { + return true; + } + + @Override + protected long getEncodedElementByteSize(NullWritable value, Context context) { + return 0; + } + + @Override + public void verifyDeterministic() throws Coder.NonDeterministicException { + // NullWritableCoder is deterministic + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/WritableCoder.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/WritableCoder.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/WritableCoder.java new file mode 100644 index 0000000..fa73753 --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/WritableCoder.java @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. + * + * Cloudera, Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"). You may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for + * the specific language governing permissions and limitations under the + * License. 
+ */ + +package org.apache.beam.runners.spark.coders; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.lang.reflect.InvocationTargetException; +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import org.apache.hadoop.io.NullWritable; +import org.apache.hadoop.io.Writable; + +/** + * A {@code WritableCoder} is a {@link Coder} for a Java class that implements {@link Writable}. + * + * <p> To use, specify the coder type on a PCollection: + * <pre> + * {@code + * PCollection<MyRecord> records = + * foo.apply(...).setCoder(WritableCoder.of(MyRecord.class)); + * } + * </pre> + * + * @param <T> the type of elements handled by this coder + */ +public class WritableCoder<T extends Writable> extends StandardCoder<T> { + private static final long serialVersionUID = 0L; + + /** + * Returns a {@code WritableCoder} instance for the provided element class. 
+ * @param <T> the element type + * @param clazz the element class + * @return a {@code WritableCoder} instance for the provided element class + */ + public static <T extends Writable> WritableCoder<T> of(Class<T> clazz) { + if (clazz.equals(NullWritable.class)) { + @SuppressWarnings("unchecked") + WritableCoder<T> result = (WritableCoder<T>) NullWritableCoder.of(); + return result; + } + return new WritableCoder<>(clazz); + } + + @JsonCreator + @SuppressWarnings("unchecked") + public static WritableCoder<?> of(@JsonProperty("type") String classType) + throws ClassNotFoundException { + Class<?> clazz = Class.forName(classType); + if (!Writable.class.isAssignableFrom(clazz)) { + throw new ClassNotFoundException( + "Class " + classType + " does not implement Writable"); + } + return of((Class<? extends Writable>) clazz); + } + + private final Class<T> type; + + public WritableCoder(Class<T> type) { + this.type = type; + } + + @Override + public void encode(T value, OutputStream outStream, Context context) throws IOException { + value.write(new DataOutputStream(outStream)); + } + + @Override + public T decode(InputStream inStream, Context context) throws IOException { + try { + T t = type.getConstructor().newInstance(); + t.readFields(new DataInputStream(inStream)); + return t; + } catch (NoSuchMethodException | InstantiationException | IllegalAccessException e) { + throw new CoderException("unable to deserialize record", e); + } catch (InvocationTargetException ite) { + throw new CoderException("unable to deserialize record", ite.getCause()); + } + } + + @Override + public List<Coder<?>> getCoderArguments() { + return null; + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + result.put("type", type.getName()); + return result; + } + + @Override + public void verifyDeterministic() throws Coder.NonDeterministicException { + throw new NonDeterministicException(this, + "Hadoop Writable may be non-deterministic."); + } 
+ +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/ConsoleIO.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/ConsoleIO.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/ConsoleIO.java new file mode 100644 index 0000000..2ee072a --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/ConsoleIO.java @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. + * + * Cloudera, Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"). You may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for + * the specific language governing permissions and limitations under the + * License. + */ +package org.apache.beam.runners.spark.io; + +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PDone; + +/** + * Print to console. 
+ */ +public final class ConsoleIO { + + private ConsoleIO() { + } + + public static final class Write { + + private Write() { + } + + public static <T> Unbound<T> from() { + return new Unbound<>(10); + } + + public static <T> Unbound<T> from(int num) { + return new Unbound<>(num); + } + + public static class Unbound<T> extends PTransform<PCollection<T>, PDone> { + + private final int num; + + Unbound(int num) { + this.num = num; + } + + public int getNum() { + return num; + } + + @Override + public PDone apply(PCollection<T> input) { + return PDone.in(input.getPipeline()); + } + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/CreateStream.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/CreateStream.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/CreateStream.java new file mode 100644 index 0000000..c92f8bf --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/CreateStream.java @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. + * + * Cloudera, Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"). You may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for + * the specific language governing permissions and limitations under the + * License. 
+ */ +package org.apache.beam.runners.spark.io; + +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.common.base.Preconditions; + +/** + * Create an input stream from Queue. + * + * @param <T> stream type + */ +public final class CreateStream<T> { + + private CreateStream() { + } + + /** + * Define the input stream to create from queue. + * + * @param queuedValues defines the input stream + * @param <T> stream type + * @return the queue that defines the input stream + */ + public static <T> QueuedValues<T> fromQueue(Iterable<Iterable<T>> queuedValues) { + return new QueuedValues<>(queuedValues); + } + + public static final class QueuedValues<T> extends PTransform<PInput, PCollection<T>> { + + private final Iterable<Iterable<T>> queuedValues; + + QueuedValues(Iterable<Iterable<T>> queuedValues) { + Preconditions.checkNotNull(queuedValues, + "need to set the queuedValues of an Create.QueuedValues transform"); + this.queuedValues = queuedValues; + } + + public Iterable<Iterable<T>> getQueuedValues() { + return queuedValues; + } + + @Override + public PCollection<T> apply(PInput input) { + // Spark streaming micro batches are bounded by default + return PCollection.createPrimitiveOutputInternal(input.getPipeline(), + WindowingStrategy.globalDefault(), PCollection.IsBounded.UNBOUNDED); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/KafkaIO.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/KafkaIO.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/KafkaIO.java new file mode 100644 index 0000000..9798157 --- /dev/null +++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/KafkaIO.java
@@ -0,0 +1,152 @@
+/*
+ * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+package org.apache.beam.runners.spark.io;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.util.WindowingStrategy;
+import com.google.cloud.dataflow.sdk.values.KV;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PInput;
+import com.google.common.base.Preconditions;
+
+import kafka.serializer.Decoder;
+
+/**
+ * Read stream from Kafka.
+ */
+public final class KafkaIO {
+
+  private KafkaIO() {
+  }
+
+  /**
+   * Factory for Kafka read transforms; see {@link #from}.
+   */
+  public static final class Read {
+
+    private Read() {
+    }
+
+    /**
+     * Define the Kafka consumption.
+     *
+     * @param keyDecoder {@link Decoder} to decode the Kafka message key
+     * @param valueDecoder {@link Decoder} to decode the Kafka message value
+     * @param key Kafka message key Class
+     * @param value Kafka message value Class
+     * @param topics Kafka topics to subscribe
+     * @param kafkaParams map of Kafka parameters
+     * @param <K> Kafka message key Class type
+     * @param <V> Kafka message value Class type
+     * @return KafkaIO Unbound input
+     * @throws NullPointerException if any argument is null
+     */
+    public static <K, V> Unbound<K, V> from(Class<? extends Decoder<K>> keyDecoder,
+        Class<? extends Decoder<V>> valueDecoder,
+        Class<K> key,
+        Class<V> value, Set<String> topics,
+        Map<String, String> kafkaParams) {
+      return new Unbound<>(keyDecoder, valueDecoder, key, value, topics, kafkaParams);
+    }
+
+    /**
+     * Reads key/value pairs from the given Kafka topics. This transform only carries
+     * configuration: {@link #apply} declares a primitive, unbounded output and does not read
+     * from Kafka itself.
+     *
+     * <p>The topic set and parameter map are copied on construction, so later changes to the
+     * caller's collections do not affect the transform.
+     */
+    public static class Unbound<K, V> extends PTransform<PInput, PCollection<KV<K, V>>> {
+
+      private final Class<? extends Decoder<K>> keyDecoderClass;
+      private final Class<? extends Decoder<V>> valueDecoderClass;
+      private final Class<K> keyClass;
+      private final Class<V> valueClass;
+      private final Set<String> topics;
+      private final Map<String, String> kafkaParams;
+
+      Unbound(Class<? extends Decoder<K>> keyDecoder,
+          Class<? extends Decoder<V>> valueDecoder, Class<K> key,
+          Class<V> value, Set<String> topics, Map<String, String> kafkaParams) {
+        Preconditions.checkNotNull(keyDecoder,
+            "need to set the key decoder class of a KafkaIO.Read transform");
+        Preconditions.checkNotNull(valueDecoder,
+            "need to set the value decoder class of a KafkaIO.Read transform");
+        Preconditions.checkNotNull(key,
+            "need to set the key class of a KafkaIO.Read transform");
+        Preconditions.checkNotNull(value,
+            "need to set the value class of a KafkaIO.Read transform");
+        Preconditions.checkNotNull(topics,
+            "need to set the topics of a KafkaIO.Read transform");
+        Preconditions.checkNotNull(kafkaParams,
+            "need to set the kafkaParams of a KafkaIO.Read transform");
+        this.keyDecoderClass = keyDecoder;
+        this.valueDecoderClass = valueDecoder;
+        this.keyClass = key;
+        this.valueClass = value;
+        // Defensive, read-only copies: the transform's configuration must not change after
+        // construction just because the caller later mutates its own collections.
+        this.topics = Collections.unmodifiableSet(new HashSet<>(topics));
+        this.kafkaParams = Collections.unmodifiableMap(new HashMap<>(kafkaParams));
+      }
+
+      /** Returns the {@link Decoder} class used for Kafka message keys. */
+      public Class<? extends Decoder<K>> getKeyDecoderClass() {
+        return keyDecoderClass;
+      }
+
+      /** Returns the {@link Decoder} class used for Kafka message values. */
+      public Class<? extends Decoder<V>> getValueDecoderClass() {
+        return valueDecoderClass;
+      }
+
+      /** Returns the Kafka message value class. */
+      public Class<V> getValueClass() {
+        return valueClass;
+      }
+
+      /** Returns the Kafka message key class. */
+      public Class<K> getKeyClass() {
+        return keyClass;
+      }
+
+      /** Returns an unmodifiable copy of the Kafka topics to subscribe to. */
+      public Set<String> getTopics() {
+        return topics;
+      }
+
+      /** Returns an unmodifiable copy of the Kafka parameters. */
+      public Map<String, String> getKafkaParams() {
+        return kafkaParams;
+      }
+
+      @Override
+      public PCollection<KV<K, V>> apply(PInput input) {
+        // Kafka is consumed as an unbounded stream, so the output is UNBOUNDED even though
+        // each Spark Streaming micro-batch is itself bounded.
+        return PCollection.createPrimitiveOutputInternal(input.getPipeline(),
+            WindowingStrategy.globalDefault(), PCollection.IsBounded.UNBOUNDED);
+      }
+    }
+
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/hadoop/HadoopIO.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/hadoop/HadoopIO.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/hadoop/HadoopIO.java
new file mode 100644
index 0000000..e8d2aa1
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/hadoop/HadoopIO.java
@@ -0,0 +1,200 @@
+/*
+ * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+package org.apache.beam.runners.spark.io.hadoop;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import com.google.cloud.dataflow.sdk.io.ShardNameTemplate;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.util.WindowingStrategy;
+import com.google.cloud.dataflow.sdk.values.KV;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PDone;
+import com.google.cloud.dataflow.sdk.values.PInput;
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+
+public final class HadoopIO {
+
+  private HadoopIO() {
+  }
+
+  public static final class Read {
+
+    private Read() {
+    }
+
+    public static <K, V> Bound<K, V> from(String filepattern,
+        Class<? extends FileInputFormat<K, V>> format, Class<K> key, Class<V> value) {
+      return new Bound<>(filepattern, format, key, value);
+    }
+
+    public static class Bound<K, V> extends PTransform<PInput, PCollection<KV<K, V>>> {
+
+      private final String filepattern;
+      private final Class<? extends FileInputFormat<K, V>> formatClass;
+      private final Class<K> keyClass;
+      private final Class<V> valueClass;
+
+      Bound(String filepattern, Class<? extends FileInputFormat<K, V>> format, Class<K> key,
+          Class<V> value) {
+        Preconditions.checkNotNull(filepattern,
+            "need to set the filepattern of an HadoopIO.Read transform");
+        Preconditions.checkNotNull(format,
+            "need to set the format class of an HadoopIO.Read transform");
+        Preconditions.checkNotNull(key,
+            "need to set the key class of an HadoopIO.Read transform");
+        Preconditions.checkNotNull(value,
+            "need to set the value class of an HadoopIO.Read transform");
+        this.filepattern = filepattern;
+        this.formatClass = format;
+        this.keyClass = key;
+        this.valueClass = value;
+      }
+
+      public String getFilepattern() {
+        return filepattern;
+      }
+
+      public Class<? extends FileInputFormat<K, V>> getFormatClass() {
+        return formatClass;
+      }
+
+      public Class<V> getValueClass() {
+        return valueClass;
+      }
+
+      public Class<K> getKeyClass() {
+        return keyClass;
+      }
+
+      @Override
+      public PCollection<KV<K, V>> apply(PInput input) {
+        return PCollection.createPrimitiveOutputInternal(input.getPipeline(),
+            WindowingStrategy.globalDefault(), PCollection.IsBounded.BOUNDED);
+      }
+
+    }
+
+  }
+
+  public static final class Write {
+
+    private Write() {
+    }
+
+    public static <K, V> Bound<K, V> to(String filenamePrefix,
+        Class<? extends FileOutputFormat<K, V>> format, Class<K> key, Class<V> value) {
+      return new Bound<>(filenamePrefix, format, key, value);
+    }
+
+    public static class Bound<K, V> extends PTransform<PCollection<KV<K, V>>, PDone> {
+
+      /** The filename prefix to write to. */
+      private final String filenamePrefix;
+      /** Suffix to use for each filename. */
+      private final String filenameSuffix;
+      /** Requested number of shards.  0 for automatic. */
+      private final int numShards;
+      /** Shard template; defaults to {@link ShardNameTemplate#INDEX_OF_MAX}. */
+      private final String shardTemplate;
+      private final Class<? extends FileOutputFormat<K, V>> formatClass;
+      private final Class<K> keyClass;
+      private final Class<V> valueClass;
+      private final Map<String, String> configurationProperties;
+
+      Bound(String filenamePrefix, Class<? extends FileOutputFormat<K, V>> format,
+          Class<K> key,
+          Class<V> value) {
+        this(filenamePrefix, "", 0, ShardNameTemplate.INDEX_OF_MAX, format, key, value,
+            new HashMap<String, String>());
+      }
+
+      Bound(String filenamePrefix, String filenameSuffix, int numShards,
+          String shardTemplate, Class<? extends FileOutputFormat<K, V>> format,
+          Class<K> key, Class<V> value, Map<String, String> configurationProperties) {
+        this.filenamePrefix = filenamePrefix;
+        this.filenameSuffix = filenameSuffix;
+        this.numShards = numShards;
+        this.shardTemplate = shardTemplate;
+        this.formatClass = format;
+        this.keyClass = key;
+        this.valueClass = value;
+        this.configurationProperties = configurationProperties;
+      }
+
+      public Bound<K, V> withoutSharding() {
+        return new Bound<>(filenamePrefix, filenameSuffix, 1, "", formatClass,
+            keyClass, valueClass, new HashMap<String, String>(configurationProperties));
+      }
+
+      public Bound<K, V> withConfigurationProperty(String key, String value) {
+        configurationProperties.put(key, value);
+        return this; // updates in place; withoutSharding() hands out a copy of the map
+      }
+
+      public String getFilenamePrefix() {
+        return filenamePrefix;
+      }
+
+      public String getShardTemplate() {
+        return shardTemplate;
+      }
+
+      public int getNumShards() {
+        return numShards;
+      }
+
+      public String getFilenameSuffix() {
+        return filenameSuffix;
+      }
+
+      public Class<? extends FileOutputFormat<K, V>> getFormatClass() {
+        return formatClass;
+      }
+
+      public Class<V> getValueClass() {
+        return valueClass;
+      }
+
+      public Class<K> getKeyClass() {
+        return keyClass;
+      }
+
+      public Map<String, String> getConfigurationProperties() {
+        return configurationProperties;
+      }
+
+      @Override
+      public PDone apply(PCollection<KV<K, V>> input) {
+        Preconditions.checkNotNull(filenamePrefix,
+            "need to set the filename prefix of an HadoopIO.Write transform");
+        Preconditions.checkNotNull(formatClass,
+            "need to set the format class of an HadoopIO.Write transform");
+        Preconditions.checkNotNull(keyClass,
+            "need to set the key class of an HadoopIO.Write transform");
+        Preconditions.checkNotNull(valueClass,
+            "need to set the value class of an HadoopIO.Write transform");
+
+        Preconditions.checkArgument(ShardNameTemplateAware.class.isAssignableFrom(formatClass),
+            "Format class must implement " + ShardNameTemplateAware.class.getName());
+
+        return PDone.in(input.getPipeline());
+      }
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/hadoop/ShardNameBuilder.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/hadoop/ShardNameBuilder.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/hadoop/ShardNameBuilder.java
new file mode 100644
index 0000000..21c7985
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/hadoop/ShardNameBuilder.java
@@ -0,0 +1,106 @@
+/*
+ * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package org.apache.beam.runners.spark.io.hadoop;
+
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import org.apache.hadoop.fs.Path;
+
+public final class ShardNameBuilder {
+
+  private ShardNameBuilder() {
+  }
+
+  /**
+   * Replaces each run of uppercase letters 'N' with the given {@code shardCount},
+   * left-padded with zeros to the length of that run.
+   * @param template the string template containing uppercase letters 'N'
+   * @param shardCount the total number of shards
+   * @return a string template with 'N' replaced by the shard count
+   * @see com.google.cloud.dataflow.sdk.io.ShardNameTemplate
+   */
+  public static String replaceShardCount(String template, int shardCount) {
+    return replaceShardPattern(template, "N+", shardCount);
+  }
+
+  /**
+   * Replaces each run of uppercase letters 'S' with the given {@code shardNumber},
+   * left-padded with zeros to the length of that run.
+   * @param template the string template containing uppercase letters 'S'
+   * @param shardNumber the number of a particular shard
+   * @return a string template with 'S' replaced by the shard number
+   * @see com.google.cloud.dataflow.sdk.io.ShardNameTemplate
+   */
+  public static String replaceShardNumber(String template, int shardNumber) {
+    return replaceShardPattern(template, "S+", shardNumber);
+  }
+
+  private static String replaceShardPattern(String template, String pattern, int n) {
+    Matcher m = Pattern.compile(pattern).matcher(template);
+    StringBuffer sb = new StringBuffer();
+    while (m.find()) {
+      // Format n zero-padded to the width of the matched run (e.g. "SSS" -> "007") and insert
+      // it literally; the template is never used as a format string, so a '%' in it is safe.
+      String padded = String.format("%0" + m.group().length() + "d", n);
+      m.appendReplacement(sb, Matcher.quoteReplacement(padded));
+    }
+    m.appendTail(sb);
+    return sb.toString();
+  }
+
+  /**
+   * @param pathPrefix a relative or absolute path
+   * @param template a template string
+   * @return the parent of {@code pathPrefix + template}, or {@code "./"} if that is empty
+   */
+  public static String getOutputDirectory(String pathPrefix, String template) {
+    String out = new Path(pathPrefix + template).getParent().toString();
+    if (out.isEmpty()) {
+      return "./";
+    }
+    return out;
+  }
+
+  /**
+   * @param pathPrefix a relative or absolute path
+   * @param template a template string
+   * @return the file name minus its {@code template} suffix, or {@code ""} if it lacks one
+   */
+  public static String getOutputFilePrefix(String pathPrefix, String template) {
+    String name = new Path(pathPrefix + template).getName();
+    if (name.endsWith(template)) {
+      return name.substring(0, name.length() - template.length());
+    } else {
+      return "";
+    }
+  }
+
+  /**
+   * @param pathPrefix a relative or absolute path
+   * @param template a template string
+   * @return {@code template} if the last path component of {@code pathPrefix + template}
+   * ends with it, otherwise that whole last path component
+   */
+  public static String getOutputFileTemplate(String pathPrefix, String template) {
+    String name = new Path(pathPrefix + template).getName();
+    if (name.endsWith(template)) {
+      return template;
+    } else {
+      return name;
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/41c4ca6a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/hadoop/ShardNameTemplateAware.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/hadoop/ShardNameTemplateAware.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/hadoop/ShardNameTemplateAware.java
new file mode 100644
index 0000000..fdee42b
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/hadoop/ShardNameTemplateAware.java
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package org.apache.beam.runners.spark.io.hadoop;
+
+/**
+ * Marker interface for {@link org.apache.hadoop.mapreduce.lib.output.FileOutputFormat}
+ * implementations that name their output shards according to the shard template of
+ * {@link HadoopIO.Write}. {@code HadoopIO.Write.Bound#apply} rejects format classes
+ * that do not implement it.
+ *
+ * <p>Some common shard name templates are defined in
+ * {@link com.google.cloud.dataflow.sdk.io.ShardNameTemplate}.
+ */
+public interface ShardNameTemplateAware {
+}
