Author: xuefu
Date: Wed Dec 17 20:56:19 2014
New Revision: 1646336

URL: http://svn.apache.org/r1646336
Log:
HIVE-8843: Release RDD cache when Hive query is done [Spark Branch] (Jimmy via Xuefu)
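
In outline: SparkPlan now records the id of every RDD it persists (MapInput and ShuffleTran register the id just before calling persist()), and the job-status layer unpersists those ids once the query finishes: LocalSparkJobStatus.cleanup() on the local client path, RemoteDriver.releaseCache() on the remote one. A minimal sketch of that pattern, using the same Spark calls as the patch (the CachedRddTracker class and its method names are illustrative, not part of the patch):

import java.util.HashSet;
import java.util.Set;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.storage.StorageLevel;

// Illustrative only: tracks the ids of RDDs persisted for one query
// so they can all be released when the query is done.
public class CachedRddTracker {
  private final Set<Integer> cachedRDDIds = new HashSet<Integer>();

  // Persist an RDD at MEMORY_AND_DISK and remember its id, mirroring
  // what MapInput and ShuffleTran.transform() do in the patch.
  public <K, V> JavaPairRDD<K, V> persistAndTrack(JavaPairRDD<K, V> rdd) {
    cachedRDDIds.add(rdd.id());
    return rdd.persist(StorageLevel.MEMORY_AND_DISK());
  }

  // Unpersist every tracked RDD, the same call the patch makes in
  // cleanup()/releaseCache(); blocking=false so query teardown does
  // not wait for the blocks to actually be dropped.
  public void release(JavaSparkContext sc) {
    for (Integer rddId : cachedRDDIds) {
      sc.sc().unpersistRDD(rddId, false);
    }
    cachedRDDIds.clear();
  }
}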

Modified:
    hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/LocalHiveSparkClient.java
    hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapInput.java
    hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/RemoteHiveSparkClient.java
    hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ShuffleTran.java
    hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlan.java
    hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
    hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/status/impl/LocalSparkJobStatus.java
    hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/JobContext.java
    hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/JobContextImpl.java
    hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/MonitorCallback.java
    hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java
    hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java

Modified: hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/LocalHiveSparkClient.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/LocalHiveSparkClient.java?rev=1646336&r1=1646335&r2=1646336&view=diff
==============================================================================
--- hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/LocalHiveSparkClient.java (original)
+++ hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/LocalHiveSparkClient.java Wed Dec 17 20:56:19 2014
@@ -30,7 +30,6 @@ import org.apache.hadoop.hive.conf.HiveC
 import org.apache.hadoop.hive.ql.Context;
 import org.apache.hadoop.hive.ql.DriverContext;
 import org.apache.hadoop.hive.ql.exec.Utilities;
-import org.apache.hive.spark.counter.SparkCounters;
 import org.apache.hadoop.hive.ql.exec.spark.status.SparkJobRef;
 import org.apache.hadoop.hive.ql.exec.spark.status.impl.JobMetricsListener;
 import org.apache.hadoop.hive.ql.exec.spark.status.impl.LocalSparkJobStatus;
@@ -40,6 +39,7 @@ import org.apache.hadoop.hive.ql.plan.Sp
 import org.apache.hadoop.hive.ql.session.SessionState;
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.hadoop.mapred.JobConf;
+import org.apache.hive.spark.counter.SparkCounters;
 import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaFutureAction;
 import org.apache.spark.api.java.JavaPairRDD;
@@ -129,8 +129,8 @@ public class LocalHiveSparkClient implem
     JavaFutureAction<Void> future = finalRDD.foreachAsync(HiveVoidFunction.getInstance());
     // As we always use foreach action to submit RDD graph, it would only trigger one job.
     int jobId = future.jobIds().get(0);
-    LocalSparkJobStatus sparkJobStatus =
-      new LocalSparkJobStatus(sc, jobId, jobMetricsListener, sparkCounters, future);
+    LocalSparkJobStatus sparkJobStatus = new LocalSparkJobStatus(
+      sc, jobId, jobMetricsListener, sparkCounters, plan.getCachedRDDIds(), future);
     return new SparkJobRef(Integer.toString(jobId), sparkJobStatus);
   }
 

Modified: hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapInput.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapInput.java?rev=1646336&r1=1646335&r2=1646336&view=diff
==============================================================================
--- hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapInput.java (original)
+++ hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapInput.java Wed Dec 17 20:56:19 2014
@@ -23,27 +23,29 @@ import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.WritableComparable;
 import org.apache.hadoop.io.WritableUtils;
 import org.apache.spark.api.java.JavaPairRDD;
-
-import com.google.common.base.Preconditions;
-
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.storage.StorageLevel;
 
 import scala.Tuple2;
 
+import com.google.common.base.Preconditions;
+
 
 public class MapInput implements SparkTran<WritableComparable, Writable,
     WritableComparable, Writable> {
   private JavaPairRDD<WritableComparable, Writable> hadoopRDD;
   private boolean toCache;
+  private final SparkPlan sparkPlan;
 
-  public MapInput(JavaPairRDD<WritableComparable, Writable> hadoopRDD) {
-    this(hadoopRDD, false);
+  public MapInput(SparkPlan sparkPlan, JavaPairRDD<WritableComparable, Writable> hadoopRDD) {
+    this(sparkPlan, hadoopRDD, false);
   }
 
-  public MapInput(JavaPairRDD<WritableComparable, Writable> hadoopRDD, boolean toCache) {
+  public MapInput(SparkPlan sparkPlan,
+      JavaPairRDD<WritableComparable, Writable> hadoopRDD, boolean toCache) {
     this.hadoopRDD = hadoopRDD;
     this.toCache = toCache;
+    this.sparkPlan = sparkPlan;
   }
 
   public void setToCache(boolean toCache) {
@@ -55,8 +57,15 @@ public class MapInput implements SparkTr
       JavaPairRDD<WritableComparable, Writable> input) {
     Preconditions.checkArgument(input == null,
         "AssertionError: MapInput doesn't take any input");
-    return toCache ? hadoopRDD.mapToPair(
-      new CopyFunction()).persist(StorageLevel.MEMORY_AND_DISK()) : hadoopRDD;
+    JavaPairRDD<WritableComparable, Writable> result;
+    if (toCache) {
+      result = hadoopRDD.mapToPair(new CopyFunction());
+      sparkPlan.addCachedRDDId(result.id());
+      result = result.persist(StorageLevel.MEMORY_AND_DISK());
+    } else {
+      result = hadoopRDD;
+    }
+    return result;
   }
 
   private static class CopyFunction implements PairFunction<Tuple2<WritableComparable, Writable>,

Modified: hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/RemoteHiveSparkClient.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/RemoteHiveSparkClient.java?rev=1646336&r1=1646335&r2=1646336&view=diff
==============================================================================
--- hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/RemoteHiveSparkClient.java (original)
+++ hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/RemoteHiveSparkClient.java Wed Dec 17 20:56:19 2014
@@ -34,7 +34,6 @@ import org.apache.hadoop.hive.conf.HiveC
 import org.apache.hadoop.hive.ql.Context;
 import org.apache.hadoop.hive.ql.DriverContext;
 import org.apache.hadoop.hive.ql.exec.Utilities;
-import org.apache.hive.spark.counter.SparkCounters;
 import org.apache.hadoop.hive.ql.exec.spark.status.SparkJobRef;
 import org.apache.hadoop.hive.ql.exec.spark.status.impl.RemoteSparkJobStatus;
 import org.apache.hadoop.hive.ql.io.HiveKey;
@@ -48,6 +47,7 @@ import org.apache.hive.spark.client.JobC
 import org.apache.hive.spark.client.JobHandle;
 import org.apache.hive.spark.client.SparkClient;
 import org.apache.hive.spark.client.SparkClientFactory;
+import org.apache.hive.spark.counter.SparkCounters;
 import org.apache.spark.SparkConf;
 import org.apache.spark.SparkException;
 import org.apache.spark.api.java.JavaFutureAction;
@@ -93,7 +93,6 @@ public class RemoteHiveSparkClient imple
   }
 
   @Override
-  @SuppressWarnings("serial")
   public SparkJobRef execute(final DriverContext driverContext, final SparkWork sparkWork) throws Exception {
     final Context ctx = driverContext.getCtx();
     final HiveConf hiveConf = (HiveConf) ctx.getConf();
@@ -220,7 +219,7 @@ public class RemoteHiveSparkClient imple
       JavaPairRDD<HiveKey, BytesWritable> finalRDD = plan.generateGraph();
       // We use Spark RDD async action to submit job as it's the only way to get jobId now.
       JavaFutureAction<Void> future = finalRDD.foreachAsync(HiveVoidFunction.getInstance());
-      jc.monitor(future, sparkCounters);
+      jc.monitor(future, sparkCounters, plan.getCachedRDDIds());
       return null;
     }
 

Modified: hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ShuffleTran.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ShuffleTran.java?rev=1646336&r1=1646335&r2=1646336&view=diff
==============================================================================
--- hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ShuffleTran.java (original)
+++ hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ShuffleTran.java Wed Dec 17 20:56:19 2014
@@ -27,20 +27,26 @@ public class ShuffleTran implements Spar
   private final SparkShuffler shuffler;
   private final int numOfPartitions;
   private final boolean toCache;
+  private final SparkPlan sparkPlan;
 
-  public ShuffleTran(SparkShuffler sf, int n) {
-    this(sf, n, false);
+  public ShuffleTran(SparkPlan sparkPlan, SparkShuffler sf, int n) {
+    this(sparkPlan, sf, n, false);
   }
 
-  public ShuffleTran(SparkShuffler sf, int n, boolean toCache) {
+  public ShuffleTran(SparkPlan sparkPlan, SparkShuffler sf, int n, boolean toCache) {
     shuffler = sf;
     numOfPartitions = n;
     this.toCache = toCache;
+    this.sparkPlan = sparkPlan;
   }
 
   @Override
   public JavaPairRDD<HiveKey, Iterable<BytesWritable>> transform(JavaPairRDD<HiveKey, BytesWritable> input) {
     JavaPairRDD<HiveKey, Iterable<BytesWritable>> result = shuffler.shuffle(input, numOfPartitions);
-    return toCache ? result.persist(StorageLevel.MEMORY_AND_DISK()) : result;
+    if (toCache) {
+      sparkPlan.addCachedRDDId(result.id());
+      result = result.persist(StorageLevel.MEMORY_AND_DISK());
+    }
+    return result;
   }
 }

Modified: hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlan.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlan.java?rev=1646336&r1=1646335&r2=1646336&view=diff
==============================================================================
--- hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlan.java (original)
+++ hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlan.java Wed Dec 17 20:56:19 2014
@@ -18,12 +18,6 @@
 
 package org.apache.hadoop.hive.ql.exec.spark;
 
-import org.apache.hadoop.hive.ql.io.HiveKey;
-import org.apache.hadoop.io.BytesWritable;
-import org.apache.spark.api.java.JavaPairRDD;
-
-import com.google.common.base.Preconditions;
-
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -32,11 +26,18 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
+import org.apache.hadoop.hive.ql.io.HiveKey;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.spark.api.java.JavaPairRDD;
+
+import com.google.common.base.Preconditions;
+
 public class SparkPlan {
   private final Set<SparkTran> rootTrans = new HashSet<SparkTran>();
   private final Set<SparkTran> leafTrans = new HashSet<SparkTran>();
   private final Map<SparkTran, List<SparkTran>> transGraph = new HashMap<SparkTran, List<SparkTran>>();
   private final Map<SparkTran, List<SparkTran>> invertedTransGraph = new HashMap<SparkTran, List<SparkTran>>();
+  private final Set<Integer> cachedRDDIds = new HashSet<Integer>();
 
   public JavaPairRDD<HiveKey, BytesWritable> generateGraph() throws IllegalStateException {
     Map<SparkTran, JavaPairRDD<HiveKey, BytesWritable>> tranToOutputRDDMap
@@ -82,6 +83,14 @@ public class SparkPlan {
     leafTrans.add(tran);
   }
 
+  public void addCachedRDDId(int rddId) {
+    cachedRDDIds.add(rddId);
+  }
+
+  public Set<Integer> getCachedRDDIds() {
+    return cachedRDDIds;
+  }
+
   /**
    * This method returns a topologically sorted list of SparkTran
    */

Modified: hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java?rev=1646336&r1=1646335&r2=1646336&view=diff
==============================================================================
--- hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java (original)
+++ hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java Wed Dec 17 20:56:19 2014
@@ -22,16 +22,9 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
-import com.google.common.base.Preconditions;
-import org.apache.commons.lang.StringUtils;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.hive.ql.io.merge.MergeFileMapper;
-import org.apache.hadoop.hive.ql.io.merge.MergeFileOutputFormat;
-import org.apache.hadoop.hive.ql.io.merge.MergeFileWork;
-import org.apache.hadoop.mapred.FileOutputFormat;
-import org.apache.hadoop.mapred.Partitioner;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.Context;
 import org.apache.hadoop.hive.ql.ErrorMsg;
@@ -39,6 +32,9 @@ import org.apache.hadoop.hive.ql.exec.Ut
 import org.apache.hadoop.hive.ql.exec.mr.ExecMapper;
 import org.apache.hadoop.hive.ql.exec.mr.ExecReducer;
 import org.apache.hadoop.hive.ql.io.BucketizedHiveInputFormat;
+import org.apache.hadoop.hive.ql.io.merge.MergeFileMapper;
+import org.apache.hadoop.hive.ql.io.merge.MergeFileOutputFormat;
+import org.apache.hadoop.hive.ql.io.merge.MergeFileWork;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.plan.BaseWork;
 import org.apache.hadoop.hive.ql.plan.MapWork;
@@ -47,13 +43,16 @@ import org.apache.hadoop.hive.ql.plan.Sp
 import org.apache.hadoop.hive.ql.plan.SparkWork;
 import org.apache.hadoop.hive.ql.stats.StatsFactory;
 import org.apache.hadoop.hive.ql.stats.StatsPublisher;
-import org.apache.hadoop.hive.shims.ShimLoader;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapred.FileOutputFormat;
 import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.Partitioner;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 
+import com.google.common.base.Preconditions;
+
 public class SparkPlanGenerator {
   private static final Log LOG = LogFactory.getLog(SparkPlanGenerator.class);
 
@@ -111,11 +110,12 @@ public class SparkPlanGenerator {
 
     SparkTran result;
     if (work instanceof MapWork) {
-      result = generateMapInput((MapWork)work);
+      result = generateMapInput(sparkPlan, (MapWork)work);
       sparkPlan.addTran(result);
     } else if (work instanceof ReduceWork) {
       List<BaseWork> parentWorks = sparkWork.getParents(work);
-      result = generate(sparkWork.getEdgeProperty(parentWorks.get(0), work), cloneToWork.containsKey(work));
+      result = generate(sparkPlan,
+        sparkWork.getEdgeProperty(parentWorks.get(0), work), cloneToWork.containsKey(work));
       sparkPlan.addTran(result);
       for (BaseWork parentWork : parentWorks) {
         sparkPlan.connect(workToTranMap.get(parentWork), result);
@@ -158,18 +158,18 @@ public class SparkPlanGenerator {
     return inputFormatClass;
   }
 
-  private MapInput generateMapInput(MapWork mapWork)
+  private MapInput generateMapInput(SparkPlan sparkPlan, MapWork mapWork)
       throws Exception {
     JobConf jobConf = cloneJobConf(mapWork);
     Class ifClass = getInputFormat(jobConf, mapWork);
 
     JavaPairRDD<WritableComparable, Writable> hadoopRDD = sc.hadoopRDD(jobConf, ifClass,
         WritableComparable.class, Writable.class);
-    MapInput result = new MapInput(hadoopRDD, cloneToWork.containsKey(mapWork));
+    MapInput result = new MapInput(sparkPlan, hadoopRDD, cloneToWork.containsKey(mapWork));
     return result;
   }
 
-  private ShuffleTran generate(SparkEdgeProperty edge, boolean toCache) {
+  private ShuffleTran generate(SparkPlan sparkPlan, SparkEdgeProperty edge, boolean toCache) {
     Preconditions.checkArgument(!edge.isShuffleNone(),
         "AssertionError: SHUFFLE_NONE should only be used for UnionWork.");
     SparkShuffler shuffler;
@@ -180,7 +180,7 @@ public class SparkPlanGenerator {
     } else {
       shuffler = new GroupByShuffler();
     }
-    return new ShuffleTran(shuffler, edge.getNumPartitions(), toCache);
+    return new ShuffleTran(sparkPlan, shuffler, edge.getNumPartitions(), toCache);
   }
 
   private SparkTran generate(BaseWork work) throws Exception {

Modified: hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/status/impl/LocalSparkJobStatus.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/status/impl/LocalSparkJobStatus.java?rev=1646336&r1=1646335&r2=1646336&view=diff
==============================================================================
--- hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/status/impl/LocalSparkJobStatus.java (original)
+++ hive/branches/spark/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/status/impl/LocalSparkJobStatus.java Wed Dec 17 20:56:19 2014
@@ -20,13 +20,13 @@ package org.apache.hadoop.hive.ql.exec.s
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
-import com.google.common.collect.Maps;
 import org.apache.hadoop.hive.ql.exec.spark.Statistic.SparkStatistics;
 import org.apache.hadoop.hive.ql.exec.spark.Statistic.SparkStatisticsBuilder;
-import org.apache.hive.spark.counter.SparkCounters;
 import org.apache.hadoop.hive.ql.exec.spark.status.SparkJobStatus;
 import org.apache.hadoop.hive.ql.exec.spark.status.SparkStageProgress;
+import org.apache.hive.spark.counter.SparkCounters;
 import org.apache.spark.JobExecutionStatus;
 import org.apache.spark.SparkJobInfo;
 import org.apache.spark.SparkStageInfo;
@@ -38,6 +38,8 @@ import org.apache.spark.executor.TaskMet
 
 import scala.Option;
 
+import com.google.common.collect.Maps;
+
 public class LocalSparkJobStatus implements SparkJobStatus {
 
   private final JavaSparkContext sparkContext;
@@ -47,14 +49,16 @@ public class LocalSparkJobStatus impleme
   private JobMetricsListener jobMetricsListener;
   private SparkCounters sparkCounters;
   private JavaFutureAction<Void> future;
+  private Set<Integer> cachedRDDIds;
 
   public LocalSparkJobStatus(JavaSparkContext sparkContext, int jobId,
       JobMetricsListener jobMetricsListener, SparkCounters sparkCounters,
-      JavaFutureAction<Void> future) {
+      Set<Integer> cachedRDDIds, JavaFutureAction<Void> future) {
     this.sparkContext = sparkContext;
     this.jobId = jobId;
     this.jobMetricsListener = jobMetricsListener;
     this.sparkCounters = sparkCounters;
+    this.cachedRDDIds = cachedRDDIds;
     this.future = future;
   }
 
@@ -130,6 +134,11 @@ public class LocalSparkJobStatus impleme
   @Override
   public void cleanup() {
     jobMetricsListener.cleanup(jobId);
+    if (cachedRDDIds != null) {
+      for (Integer cachedRDDId: cachedRDDIds) {
+        sparkContext.sc().unpersistRDD(cachedRDDId, false);
+      }
+    }
   }
 
   private Map<String, Long> combineJobLevelMetrics(Map<String, List<TaskMetrics>> jobMetric) {

Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/JobContext.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/JobContext.java?rev=1646336&r1=1646335&r2=1646336&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/JobContext.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/JobContext.java Wed Dec 17 20:56:19 2014
@@ -17,15 +17,15 @@
 
 package org.apache.hive.spark.client;
 
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.hadoop.hive.common.classification.InterfaceAudience;
 import org.apache.hive.spark.counter.SparkCounters;
 import org.apache.spark.api.java.JavaFutureAction;
 import org.apache.spark.api.java.JavaSparkContext;
 
-import org.apache.hadoop.hive.common.classification.InterfaceAudience;
-
-import java.util.List;
-import java.util.Map;
-
 /**
  * Holds runtime information about the job execution context.
  *
@@ -44,7 +44,8 @@ public interface JobContext {
    *
    * @return The job (unmodified).
    */
-  <T> JavaFutureAction<T> monitor(JavaFutureAction<T> job, SparkCounters sparkCounters);
+  <T> JavaFutureAction<T> monitor(
+    JavaFutureAction<T> job, SparkCounters sparkCounters, Set<Integer> cachedRDDIds);
 
   /**
    * Return a map from client job Id to corresponding JavaFutureActions

Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/JobContextImpl.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/JobContextImpl.java?rev=1646336&r1=1646335&r2=1646336&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/JobContextImpl.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/JobContextImpl.java Wed Dec 17 20:56:19 2014
@@ -17,14 +17,15 @@
 
 package org.apache.hive.spark.client;
 
-import org.apache.hive.spark.counter.SparkCounters;
-import org.apache.spark.api.java.JavaFutureAction;
-import org.apache.spark.api.java.JavaSparkContext;
-
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 
+import org.apache.hive.spark.counter.SparkCounters;
+import org.apache.spark.api.java.JavaFutureAction;
+import org.apache.spark.api.java.JavaSparkContext;
+
 class JobContextImpl implements JobContext {
 
   private final JavaSparkContext sc;
@@ -44,8 +45,9 @@ class JobContextImpl implements JobConte
   }
 
   @Override
-  public <T> JavaFutureAction<T> monitor(JavaFutureAction<T> job, SparkCounters sparkCounters) {
-    monitorCb.get().call(job, sparkCounters);
+  public <T> JavaFutureAction<T> monitor(JavaFutureAction<T> job,
+      SparkCounters sparkCounters, Set<Integer> cachedRDDIds) {
+    monitorCb.get().call(job, sparkCounters, cachedRDDIds);
     return job;
   }
 

Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/MonitorCallback.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/MonitorCallback.java?rev=1646336&r1=1646335&r2=1646336&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/MonitorCallback.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/MonitorCallback.java Wed Dec 17 20:56:19 2014
@@ -17,11 +17,13 @@
 
 package org.apache.hive.spark.client;
 
+import java.util.Set;
+
 import org.apache.hive.spark.counter.SparkCounters;
 import org.apache.spark.api.java.JavaFutureAction;
 
 interface MonitorCallback {
 
-  void call(JavaFutureAction<?> future, SparkCounters sparkCounters);
+  void call(JavaFutureAction<?> future, SparkCounters sparkCounters, Set<Integer> cachedRDDIds);
 
 }

Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java?rev=1646336&r1=1646335&r2=1646336&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java Wed Dec 17 20:56:19 2014
@@ -24,6 +24,7 @@ import java.io.Serializable;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.Callable;
 import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.ExecutionException;
@@ -266,6 +267,7 @@ public class RemoteDriver {
     private final List<JavaFutureAction<?>> jobs;
     private final AtomicInteger completed;
     private SparkCounters sparkCounters;
+    private Set<Integer> cachedRDDIds;
 
     private Future<?> future;
 
@@ -274,6 +276,7 @@ public class RemoteDriver {
       this.jobs = Lists.newArrayList();
       this.completed = new AtomicInteger();
       this.sparkCounters = null;
+      this.cachedRDDIds = null;
     }
 
     @Override
@@ -281,8 +284,9 @@ public class RemoteDriver {
       try {
         jc.setMonitorCb(new MonitorCallback() {
           @Override
-          public void call(JavaFutureAction<?> future, SparkCounters sparkCounters) {
-            monitorJob(future, sparkCounters);
+          public void call(JavaFutureAction<?> future,
+              SparkCounters sparkCounters, Set<Integer> cachedRDDIds) {
+            monitorJob(future, sparkCounters, cachedRDDIds);
           }
         });
 
@@ -311,6 +315,7 @@ public class RemoteDriver {
       } finally {
         jc.setMonitorCb(null);
         activeJobs.remove(req.id);
+        releaseCache();
       }
       return null;
     }
@@ -326,13 +331,30 @@ public class RemoteDriver {
       }
     }
 
-    private void monitorJob(JavaFutureAction<?> job, SparkCounters sparkCounters) {
+    /**
+     * Release cached RDDs as soon as the job is done.
+     * This differs from the local Spark client so as to save an
+     * RPC call/trip and to avoid passing cached RDD id information
+     * around. Otherwise, we could follow the local Spark client's
+     * way for consistency.
+     */
+    void releaseCache() {
+      if (cachedRDDIds != null) {
+        for (Integer cachedRDDId: cachedRDDIds) {
+          jc.sc().sc().unpersistRDD(cachedRDDId, false);
+        }
+      }
+    }
+
+    private void monitorJob(JavaFutureAction<?> job,
+        SparkCounters sparkCounters, Set<Integer> cachedRDDIds) {
       jobs.add(job);
       if (!jc.getMonitoredJobs().containsKey(req.id)) {
         jc.getMonitoredJobs().put(req.id, new CopyOnWriteArrayList<JavaFutureAction<?>>());
       }
       jc.getMonitoredJobs().get(req.id).add(job);
       this.sparkCounters = sparkCounters;
+      this.cachedRDDIds = cachedRDDIds;
       protocol.jobSubmitted(req.id, job.jobIds().get(0));
     }
 

Modified: hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java?rev=1646336&r1=1646335&r2=1646336&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java (original)
+++ hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java Wed Dec 17 20:56:19 2014
@@ -17,6 +17,11 @@
 
 package org.apache.hive.spark.client;
 
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
 import java.io.File;
 import java.io.FileInputStream;
 import java.io.FileOutputStream;
@@ -27,24 +32,20 @@ import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
-import java.util.zip.ZipEntry;
 import java.util.jar.JarOutputStream;
+import java.util.zip.ZipEntry;
 
-import com.google.common.base.Objects;
-import com.google.common.base.Strings;
-import com.google.common.io.ByteStreams;
-import org.apache.spark.FutureAction;
+import org.apache.hive.spark.counter.SparkCounters;
 import org.apache.spark.SparkFiles;
 import org.apache.spark.api.java.JavaFutureAction;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.VoidFunction;
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
-import static org.junit.Assert.*;
 
-import org.apache.hive.spark.counter.SparkCounters;
+import com.google.common.base.Objects;
+import com.google.common.base.Strings;
+import com.google.common.io.ByteStreams;
 
 public class TestSparkClient {
 
@@ -258,7 +259,7 @@ public class TestSparkClient {
         public void call(Integer l) throws Exception {
 
         }
-      }), null);
+      }), null, null);
 
       future.get(TIMEOUT, TimeUnit.SECONDS);
 
@@ -332,7 +333,7 @@ public class TestSparkClient {
       counters.createCounter("group2", "counter2");
 
       jc.monitor(jc.sc().parallelize(Arrays.asList(1, 2, 3, 4, 5), 5).foreachAsync(this),
-          counters);
+          counters, null);
 
       return null;
     }
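
For a caller of the new three-argument monitor() that actually caches RDDs, the last argument carries the ids to release. A sketch in the style of the test jobs above (the class name, and the assumption that Job exposes call(JobContext) as these tests do, are illustrative, not code from the patch):

import java.io.Serializable;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.storage.StorageLevel;

// Illustrative only: persists an RDD, tracks its id, and hands the id
// set to monitor() so the driver can unpersist it after the job ends.
class CacheReleasingJob implements Job<Serializable>, VoidFunction<Integer> {

  @Override
  public Serializable call(JobContext jc) throws Exception {
    JavaRDD<Integer> rdd = jc.sc().parallelize(Arrays.asList(1, 2, 3, 4, 5), 5)
        .persist(StorageLevel.MEMORY_AND_DISK());
    Set<Integer> cachedRDDIds = new HashSet<Integer>();
    cachedRDDIds.add(rdd.id());
    // Passing the ids (instead of null, as the tests do) lets
    // RemoteDriver.releaseCache() unpersist them when the job completes.
    jc.monitor(rdd.foreachAsync(this), null, cachedRDDIds);
    return null;
  }

  @Override
  public void call(Integer i) throws Exception {
    // no-op action, like the test jobs
  }
}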

