Delay converting PCollection values to bytes in case they are only used for views.
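
The core of the change: keep the original values and their Coder, and pay the
byte-array round-trip only when an RDD is actually demanded. A minimal
plain-Java sketch of that laziness (the materializer stands in for Spark's
parallelize plus the coder round-trip; all names here are illustrative, not
the committed API):

import java.util.List;
import java.util.function.Function;

/**
 * Sketch: hold raw values plus a materializing function, and run the
 * expensive conversion only when the materialized form is first requested.
 */
class LazyRdd<T> {
  private final Iterable<T> values;                          // in-memory source values
  private final Function<Iterable<T>, List<T>> materializer; // stand-in for parallelize + coder work
  private List<T> rdd;                                       // computed on demand

  LazyRdd(Iterable<T> values, Function<Iterable<T>, List<T>> materializer) {
    this.values = values;
    this.materializer = materializer;
  }

  List<T> getRdd() {
    if (rdd == null) {
      rdd = materializer.apply(values); // conversion cost is paid only here
    }
    return rdd;
  }

  Iterable<T> getValues() {
    return values; // a view-only consumer never triggers materialization
  }
}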


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/7cff3049
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/7cff3049
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/7cff3049

Branch: refs/heads/master
Commit: 7cff30498d0bcc0bdafddc0f0a6190d7b28d56a6
Parents: c51bc32
Author: Tom White <t...@cloudera.com>
Authored: Wed Jul 8 20:57:25 2015 +0100
Committer: Tom White <t...@cloudera.com>
Committed: Thu Mar 10 11:15:14 2016 +0000

----------------------------------------------------------------------
 .../dataflow/spark/EvaluationContext.java       | 90 +++++++++++++++-----
 .../dataflow/spark/TransformTranslator.java     |  4 +-
 2 files changed, 68 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/7cff3049/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java
index 56f8521..649cbe9 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java
@@ -51,8 +51,8 @@ public class EvaluationContext implements EvaluationResult {
   private final Pipeline pipeline;
   private final SparkRuntimeContext runtime;
   private final CoderRegistry registry;
-  private final Map<PValue, JavaRDDLike<?, ?>> rdds = new LinkedHashMap<>();
-  private final Set<JavaRDDLike<?, ?>> leafRdds = new LinkedHashSet<>();
+  private final Map<PValue, RDDHolder<?>> pcollections = new LinkedHashMap<>();
+  private final Set<RDDHolder<?>> leafRdds = new LinkedHashSet<>();
   private final Set<PValue> multireads = new LinkedHashSet<>();
   private final Map<PValue, Object> pobjects = new LinkedHashMap<>();
  private final Map<PValue, Iterable<WindowedValue<?>>> pview = new LinkedHashMap<>();
@@ -65,6 +65,52 @@ public class EvaluationContext implements EvaluationResult {
     this.runtime = new SparkRuntimeContext(jsc, pipeline);
   }
 
+  /**
+   * Holds an RDD or values for deferred conversion to an RDD if needed. PCollections are
+   * sometimes created from a collection of objects (using RDD parallelize) and then
+   * only used to create View objects; in which case they do not need to be
+   * converted to bytes since they are not transferred across the network until they are
+   * broadcast.
+   */
+  private class RDDHolder<T> {
+
+    private Iterable<T> values;
+    private Coder<T> coder;
+    private JavaRDDLike<T, ?> rdd;
+
+    public RDDHolder(Iterable<T> values, Coder<T> coder) {
+      this.values = values;
+      this.coder = coder;
+    }
+
+    public RDDHolder(JavaRDDLike<T, ?> rdd) {
+      this.rdd = rdd;
+    }
+
+    public JavaRDDLike<T, ?> getRDD() {
+      if (rdd == null) {
+        rdd = jsc.parallelize(CoderHelpers.toByteArrays(values, coder))
+            .map(CoderHelpers.fromByteFunction(coder));
+      }
+      return rdd;
+    }
+
+    public Iterable<T> getValues(PCollection<T> pcollection) {
+      if (values == null) {
+        coder = pcollection.getCoder();
+        JavaRDDLike<byte[], ?> bytesRDD = rdd.map(CoderHelpers.toByteFunction(coder));
+        List<byte[]> clientBytes = bytesRDD.collect();
+        values = Iterables.transform(clientBytes, new Function<byte[], T>() {
+          @Override
+          public T apply(byte[] bytes) {
+            return CoderHelpers.fromByteArray(bytes, coder);
+          }
+        });
+      }
+      return values;
+    }
+  }
+
   JavaSparkContext getSparkContext() {
     return jsc;
   }
@@ -97,17 +143,23 @@ public class EvaluationContext implements EvaluationResult {
     return output;
   }
 
-  void setOutputRDD(PTransform<?, ?> transform, JavaRDDLike<?, ?> rdd) {
+  <T> void setOutputRDD(PTransform<?, ?> transform, JavaRDDLike<T, ?> rdd) {
     setRDD((PValue) getOutput(transform), rdd);
   }
 
+  <T> void setOutputRDDFromValues(PTransform<?, ?> transform, Iterable<T> values,
+      Coder<T> coder) {
+    pcollections.put((PValue) getOutput(transform), new RDDHolder<>(values, coder));
+  }
+
   void setPView(PValue view, Iterable<WindowedValue<?>> value) {
     pview.put(view, value);
   }
 
   JavaRDDLike<?, ?> getRDD(PValue pvalue) {
-    JavaRDDLike<?, ?> rdd = rdds.get(pvalue);
-    leafRdds.remove(rdd);
+    RDDHolder<?> rddHolder = pcollections.get(pvalue);
+    JavaRDDLike<?, ?> rdd = rddHolder.getRDD();
+    leafRdds.remove(rddHolder);
     if (multireads.contains(pvalue)) {
       // Ensure the RDD is marked as cached
       rdd.rdd().cache();
@@ -117,14 +169,15 @@ public class EvaluationContext implements EvaluationResult {
     return rdd;
   }
 
-  void setRDD(PValue pvalue, JavaRDDLike<?, ?> rdd) {
+  <T> void setRDD(PValue pvalue, JavaRDDLike<T, ?> rdd) {
     try {
       rdd.rdd().setName(pvalue.getName());
     } catch (IllegalStateException e) {
       // name not set, ignore
     }
-    rdds.put(pvalue, rdd);
-    leafRdds.add(rdd);
+    RDDHolder<T> rddHolder = new RDDHolder<>(rdd);
+    pcollections.put(pvalue, rddHolder);
+    leafRdds.add(rddHolder);
   }
 
   JavaRDDLike<?, ?> getInputRDD(PTransform<? extends PInput, ?> transform) {
@@ -142,7 +195,8 @@ public class EvaluationContext implements EvaluationResult {
    * effects).
    */
   void computeOutputs() {
-    for (JavaRDDLike<?, ?> rdd : leafRdds) {
+    for (RDDHolder<?> rddHolder : leafRdds) {
+      JavaRDDLike<?, ?> rdd = rddHolder.getRDD();
       rdd.rdd().cache(); // cache so that any subsequent get() is cheap
       rdd.count(); // force the RDD to be computed
     }
@@ -155,8 +209,8 @@ public class EvaluationContext implements EvaluationResult {
       T result = (T) pobjects.get(value);
       return result;
     }
-    if (rdds.containsKey(value)) {
-      JavaRDDLike<?, ?> rdd = rdds.get(value);
+    if (pcollections.containsKey(value)) {
+      JavaRDDLike<?, ?> rdd = pcollections.get(value).getRDD();
       @SuppressWarnings("unchecked")
       T res = (T) Iterables.getOnlyElement(rdd.collect());
       pobjects.put(value, res);
@@ -179,18 +233,8 @@ public class EvaluationContext implements EvaluationResult {
   @Override
   public <T> Iterable<T> get(PCollection<T> pcollection) {
     @SuppressWarnings("unchecked")
-    JavaRDDLike<T, ?> rdd = (JavaRDDLike<T, ?>) getRDD(pcollection);
-    // Use a coder to convert the objects in the PCollection to byte arrays, so they
-    // can be transferred over the network.
-    final Coder<T> coder = pcollection.getCoder();
-    JavaRDDLike<byte[], ?> bytesRDD = rdd.map(CoderHelpers.toByteFunction(coder));
-    List<byte[]> clientBytes = bytesRDD.collect();
-    return Iterables.transform(clientBytes, new Function<byte[], T>() {
-      @Override
-      public T apply(byte[] bytes) {
-        return CoderHelpers.fromByteArray(bytes, coder);
-      }
-    });
+    RDDHolder<T> rddHolder = (RDDHolder<T>) pcollections.get(pcollection);
+    return rddHolder.getValues(pcollection);
   }
 
   @Override
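
Taken together, the two RDDHolder constructors give EvaluationContext two lazy
paths: values created in memory are parallelized (and coder-encoded) only when
getRDD() is first called, and an RDD produced by a transform is collected (and
coder-decoded) only when getValues() is first called. A hedged sketch of that
contract; RDDHolder is a private inner class, so this is illustrative code
that assumes coder, pcollection, and existingRdd are already in scope:

// Path 1: built from in-memory values; no Spark or coder work happens yet.
RDDHolder<String> fromValues = new RDDHolder<>(Arrays.asList("a", "b"), coder);
fromValues.getValues(pcollection); // returns the original Iterable, free
fromValues.getRDD();               // only now: encode, parallelize, decode

// Path 2: wraps an RDD that already lives on the cluster.
RDDHolder<String> fromRdd = new RDDHolder<>(existingRdd);
fromRdd.getRDD();                  // returns the wrapped RDD, free
fromRdd.getValues(pcollection);    // only now: encode, collect, decode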

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/7cff3049/runners/spark/src/main/java/com/cloudera/dataflow/spark/TransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/TransformTranslator.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/TransformTranslator.java
index b0fd4a3..195766e 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/TransformTranslator.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/TransformTranslator.java
@@ -528,9 +528,7 @@ public final class TransformTranslator {
        // Use a coder to convert the objects in the PCollection to byte arrays, so they
         // can be transferred over the network.
         Coder<T> coder = context.getOutput(transform).getCoder();
-        JavaRDD<byte[]> rdd = context.getSparkContext().parallelize(
-            CoderHelpers.toByteArrays(elems, coder));
-        context.setOutputRDD(transform, rdd.map(CoderHelpers.fromByteFunction(coder)));
+        context.setOutputRDDFromValues(transform, elems, coder);
       }
     };
   }
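
The payoff comes in pipelines where a Create output is consumed only as a
side-input view: the in-memory values now flow straight into the broadcast
without ever being parallelized or round-tripped through byte arrays. A hedged
sketch of such a pipeline, written against the Google Cloud Dataflow SDK this
runner targeted (the package names and View signatures are assumptions of the
sketch, not part of this commit):

import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
import com.google.cloud.dataflow.sdk.transforms.Create;
import com.google.cloud.dataflow.sdk.transforms.View;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.dataflow.sdk.values.PCollectionView;

public class ViewOnlyCreate {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.create());

    // Created from in-memory values; after this commit the Spark translator
    // records them via setOutputRDDFromValues instead of parallelizing eagerly.
    PCollection<String> words = p.apply(Create.of("a", "b", "c"));

    // Consumed only as a view: the values go straight to a broadcast, so the
    // byte-array round-trip and the backing RDD are never materialized.
    PCollectionView<Iterable<String>> view = words.apply(View.asIterable());
  }
}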
