Delay converting PCollection values to bytes in case they are only used for views.
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/7cff3049
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/7cff3049
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/7cff3049

Branch: refs/heads/master
Commit: 7cff30498d0bcc0bdafddc0f0a6190d7b28d56a6
Parents: c51bc32
Author: Tom White <t...@cloudera.com>
Authored: Wed Jul 8 20:57:25 2015 +0100
Committer: Tom White <t...@cloudera.com>
Committed: Thu Mar 10 11:15:14 2016 +0000

----------------------------------------------------------------------
 .../dataflow/spark/EvaluationContext.java   | 90 +++++++++++++++-----
 .../dataflow/spark/TransformTranslator.java |  4 +-
 2 files changed, 68 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/7cff3049/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java
index 56f8521..649cbe9 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java
@@ -51,8 +51,8 @@ public class EvaluationContext implements EvaluationResult {
   private final Pipeline pipeline;
   private final SparkRuntimeContext runtime;
   private final CoderRegistry registry;
-  private final Map<PValue, JavaRDDLike<?, ?>> rdds = new LinkedHashMap<>();
-  private final Set<JavaRDDLike<?, ?>> leafRdds = new LinkedHashSet<>();
+  private final Map<PValue, RDDHolder<?>> pcollections = new LinkedHashMap<>();
+  private final Set<RDDHolder<?>> leafRdds = new LinkedHashSet<>();
   private final Set<PValue> multireads = new LinkedHashSet<>();
   private final Map<PValue, Object> pobjects = new LinkedHashMap<>();
   private final Map<PValue, Iterable<WindowedValue<?>>> pview = new LinkedHashMap<>();
@@ -65,6 +65,52 @@ public class EvaluationContext implements EvaluationResult {
     this.runtime = new SparkRuntimeContext(jsc, pipeline);
   }
 
+  /**
+   * Holds an RDD or values for deferred conversion to an RDD if needed. PCollections are
+   * sometimes created from a collection of objects (using RDD parallelize) and then
+   * only used to create View objects; in which case they do not need to be
+   * converted to bytes since they are not transferred across the network until they are
+   * broadcast.
+   */
+  private class RDDHolder<T> {
+
+    private Iterable<T> values;
+    private Coder<T> coder;
+    private JavaRDDLike<T, ?> rdd;
+
+    public RDDHolder(Iterable<T> values, Coder<T> coder) {
+      this.values = values;
+      this.coder = coder;
+    }
+
+    public RDDHolder(JavaRDDLike<T, ?> rdd) {
+      this.rdd = rdd;
+    }
+
+    public JavaRDDLike<T, ?> getRDD() {
+      if (rdd == null) {
+        rdd = jsc.parallelize(CoderHelpers.toByteArrays(values, coder))
+            .map(CoderHelpers.fromByteFunction(coder));
+      }
+      return rdd;
+    }
+
+    public Iterable<T> getValues(PCollection<T> pcollection) {
+      if (values == null) {
+        coder = pcollection.getCoder();
+        JavaRDDLike<byte[], ?> bytesRDD = rdd.map(CoderHelpers.toByteFunction(coder));
+        List<byte[]> clientBytes = bytesRDD.collect();
+        values = Iterables.transform(clientBytes, new Function<byte[], T>() {
+          @Override
+          public T apply(byte[] bytes) {
+            return CoderHelpers.fromByteArray(bytes, coder);
+          }
+        });
+      }
+      return values;
+    }
+  }
+
   JavaSparkContext getSparkContext() {
     return jsc;
   }
@@ -97,17 +143,23 @@ public class EvaluationContext implements EvaluationResult {
     return output;
   }
 
-  void setOutputRDD(PTransform<?, ?> transform, JavaRDDLike<?, ?> rdd) {
+  <T> void setOutputRDD(PTransform<?, ?> transform, JavaRDDLike<T, ?> rdd) {
     setRDD((PValue) getOutput(transform), rdd);
   }
 
+  <T> void setOutputRDDFromValues(PTransform<?, ?> transform, Iterable<T> values,
+      Coder<T> coder) {
+    pcollections.put((PValue) getOutput(transform), new RDDHolder<>(values, coder));
+  }
+
   void setPView(PValue view, Iterable<WindowedValue<?>> value) {
     pview.put(view, value);
   }
 
   JavaRDDLike<?, ?> getRDD(PValue pvalue) {
-    JavaRDDLike<?, ?> rdd = rdds.get(pvalue);
-    leafRdds.remove(rdd);
+    RDDHolder<?> rddHolder = pcollections.get(pvalue);
+    JavaRDDLike<?, ?> rdd = rddHolder.getRDD();
+    leafRdds.remove(rddHolder);
     if (multireads.contains(pvalue)) {
       // Ensure the RDD is marked as cached
       rdd.rdd().cache();
@@ -117,14 +169,15 @@ public class EvaluationContext implements EvaluationResult {
     return rdd;
   }
 
-  void setRDD(PValue pvalue, JavaRDDLike<?, ?> rdd) {
+  <T> void setRDD(PValue pvalue, JavaRDDLike<T, ?> rdd) {
     try {
       rdd.rdd().setName(pvalue.getName());
     } catch (IllegalStateException e) {
       // name not set, ignore
     }
-    rdds.put(pvalue, rdd);
-    leafRdds.add(rdd);
+    RDDHolder<T> rddHolder = new RDDHolder<>(rdd);
+    pcollections.put(pvalue, rddHolder);
+    leafRdds.add(rddHolder);
   }
 
   JavaRDDLike<?, ?> getInputRDD(PTransform<? extends PInput, ?> transform) {
@@ -142,7 +195,8 @@
    * effects).
    */
   void computeOutputs() {
-    for (JavaRDDLike<?, ?> rdd : leafRdds) {
+    for (RDDHolder<?> rddHolder : leafRdds) {
+      JavaRDDLike<?, ?> rdd = rddHolder.getRDD();
       rdd.rdd().cache(); // cache so that any subsequent get() is cheap
       rdd.count(); // force the RDD to be computed
     }
@@ -155,8 +209,8 @@
       T result = (T) pobjects.get(value);
       return result;
     }
-    if (rdds.containsKey(value)) {
-      JavaRDDLike<?, ?> rdd = rdds.get(value);
+    if (pcollections.containsKey(value)) {
+      JavaRDDLike<?, ?> rdd = pcollections.get(value).getRDD();
       @SuppressWarnings("unchecked")
       T res = (T) Iterables.getOnlyElement(rdd.collect());
       pobjects.put(value, res);
@@ -179,18 +233,8 @@
   @Override
   public <T> Iterable<T> get(PCollection<T> pcollection) {
     @SuppressWarnings("unchecked")
-    JavaRDDLike<T, ?> rdd = (JavaRDDLike<T, ?>) getRDD(pcollection);
-    // Use a coder to convert the objects in the PCollection to byte arrays, so they
-    // can be transferred over the network.
-    final Coder<T> coder = pcollection.getCoder();
-    JavaRDDLike<byte[], ?> bytesRDD = rdd.map(CoderHelpers.toByteFunction(coder));
-    List<byte[]> clientBytes = bytesRDD.collect();
-    return Iterables.transform(clientBytes, new Function<byte[], T>() {
-      @Override
-      public T apply(byte[] bytes) {
-        return CoderHelpers.fromByteArray(bytes, coder);
-      }
-    });
+    RDDHolder<T> rddHolder = (RDDHolder<T>) pcollections.get(pcollection);
+    return rddHolder.getValues(pcollection);
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/7cff3049/runners/spark/src/main/java/com/cloudera/dataflow/spark/TransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/TransformTranslator.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/TransformTranslator.java
index b0fd4a3..195766e 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/TransformTranslator.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/TransformTranslator.java
@@ -528,9 +528,7 @@ public final class TransformTranslator {
       // Use a coder to convert the objects in the PCollection to byte arrays, so they
       // can be transferred over the network.
      Coder<T> coder = context.getOutput(transform).getCoder();
-      JavaRDD<byte[]> rdd = context.getSparkContext().parallelize(
-          CoderHelpers.toByteArrays(elems, coder));
-      context.setOutputRDD(transform, rdd.map(CoderHelpers.fromByteFunction(coder)));
+      context.setOutputRDDFromValues(transform, elems, coder);
     }
   };
 }
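----------------------------------------------------------------------

The heart of the change is the defer-and-materialize pattern in RDDHolder:
keep whichever representation a PCollection was created with, and build the
other representation only on first request. The sketch below restates that
pattern in plain, self-contained Java with no Spark or Dataflow
dependencies. LazyHolder and its members are illustrative names invented
for this sketch, not part of the commit; a pair of Functions stands in for
the Coder, and a List<byte[]> stands in for the byte-encoded RDD.

import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

final class LazyHolder<T> {
  private List<T> values;        // local objects, if we were given them
  private List<byte[]> encoded;  // stand-in for the byte-array RDD
  private final Function<T, byte[]> encode;  // plays the coder's role
  private final Function<byte[], T> decode;

  private LazyHolder(List<T> values, List<byte[]> encoded,
      Function<T, byte[]> encode, Function<byte[], T> decode) {
    this.values = values;
    this.encoded = encoded;
    this.encode = encode;
    this.decode = decode;
  }

  // Mirrors RDDHolder(Iterable<T>, Coder<T>): nothing is encoded yet.
  static <T> LazyHolder<T> ofValues(List<T> values,
      Function<T, byte[]> encode, Function<byte[], T> decode) {
    return new LazyHolder<>(values, null, encode, decode);
  }

  // Mirrors RDDHolder(JavaRDDLike<T, ?>): nothing is decoded yet.
  static <T> LazyHolder<T> ofEncoded(List<byte[]> encoded,
      Function<T, byte[]> encode, Function<byte[], T> decode) {
    return new LazyHolder<>(null, encoded, encode, decode);
  }

  // Analogue of getRDD(): pay the encoding cost only when first asked.
  List<byte[]> getEncoded() {
    if (encoded == null) {
      encoded = values.stream().map(encode).collect(Collectors.toList());
    }
    return encoded;
  }

  // Analogue of getValues(): decode only if the values were never local.
  List<T> getValues() {
    if (values == null) {
      values = encoded.stream().map(decode).collect(Collectors.toList());
    }
    return values;
  }

  public static void main(String[] args) {
    LazyHolder<String> h = LazyHolder.ofValues(
        Arrays.asList("a", "b"), String::getBytes, String::new);
    h.getValues();   // free: values are already local, no coder round trip
    h.getEncoded();  // encoding cost is paid here, on first request
  }
}

This is why the TransformTranslator change above is a win: a Create
transform whose output is only ever turned into a view never has getRDD()
called on it, so the parallelize/coder round trip is skipped entirely until
Spark broadcasts the values.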