Repository: spark
Updated Branches:
  refs/heads/master d06d5ee9b -> 0ce4e430a
SPARK-3290 [GRAPHX] No unpersist callls in SVDPlusPlus

This just unpersist()s each RDD in this code that was cache()ed.

Author: Sean Owen <so...@cloudera.com>

Closes #4234 from srowen/SPARK-3290 and squashes the following commits:

66c1e11 [Sean Owen] unpersist() each RDD that was cache()ed


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0ce4e430
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0ce4e430
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0ce4e430

Branch: refs/heads/master
Commit: 0ce4e430a81532dc317136f968f28742e087d840
Parents: d06d5ee
Author: Sean Owen <so...@cloudera.com>
Authored: Fri Feb 13 20:12:52 2015 -0800
Committer: Ankur Dave <ankurd...@gmail.com>
Committed: Fri Feb 13 20:12:52 2015 -0800

----------------------------------------------------------------------
 .../apache/spark/graphx/lib/SVDPlusPlus.scala | 40 ++++++++++++++++----
 1 file changed, 32 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0ce4e430/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
----------------------------------------------------------------------
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
index f58587e..112ed09 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
@@ -72,17 +72,22 @@ object SVDPlusPlus {
 
     // construct graph
     var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache()
+    materialize(g)
+    edges.unpersist()
 
     // Calculate initial bias and norm
     val t0 = g.aggregateMessages[(Long, Double)](
       ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) },
       (g1, g2) => (g1._1 + g2._1, g1._2 + g2._2))
 
-    g = g.outerJoinVertices(t0) {
+    val gJoinT0 = g.outerJoinVertices(t0) {
       (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double),
        msg: Option[(Long, Double)]) =>
         (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1))
-    }
+    }.cache()
+    materialize(gJoinT0)
+    g.unpersist()
+    g = gJoinT0
 
     def sendMsgTrainF(conf: Conf, u: Double)
         (ctx: EdgeContext[
@@ -114,12 +119,15 @@
       val t1 = g.aggregateMessages[DoubleMatrix](
         ctx => ctx.sendToSrc(ctx.dstAttr._2),
         (g1, g2) => g1.addColumnVector(g2))
-      g = g.outerJoinVertices(t1) {
+      val gJoinT1 = g.outerJoinVertices(t1) {
         (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double),
          msg: Option[DoubleMatrix]) =>
           if (msg.isDefined) (vd._1, vd._1
             .addColumnVector(msg.get.mul(vd._4)), vd._3, vd._4) else vd
-      }
+      }.cache()
+      materialize(gJoinT1)
+      g.unpersist()
+      g = gJoinT1
 
       // Phase 2, update p for user nodes and q, y for item nodes
       g.cache()
@@ -127,13 +135,16 @@
         sendMsgTrainF(conf, u),
         (g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) =>
           (g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3))
-      g = g.outerJoinVertices(t2) {
+      val gJoinT2 = g.outerJoinVertices(t2) {
        (vid: VertexId,
         vd: (DoubleMatrix, DoubleMatrix, Double, Double),
         msg: Option[(DoubleMatrix, DoubleMatrix, Double)]) =>
          (vd._1.addColumnVector(msg.get._1), vd._2.addColumnVector(msg.get._2),
            vd._3 + msg.get._3, vd._4)
-      }
+      }.cache()
+      materialize(gJoinT2)
+      g.unpersist()
+      g = gJoinT2
     }
 
     // calculate error on training set
@@ -147,13 +158,26 @@ object SVDPlusPlus {
         val err = (ctx.attr - pred) * (ctx.attr - pred)
         ctx.sendToDst(err)
       }
 
+    g.cache()
     val t3 = g.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _)
-    g = g.outerJoinVertices(t3) {
+    val gJoinT3 = g.outerJoinVertices(t3) {
       (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) =>
         if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd
-    }
+    }.cache()
+    materialize(gJoinT3)
+    g.unpersist()
+    g = gJoinT3
 
     (g, u)
   }
+
+  /**
+   * Forces materialization of a Graph by count()ing its RDDs.
+   */
+  private def materialize(g: Graph[_,_]): Unit = {
+    g.vertices.count()
+    g.edges.count()
+  }
+
 }
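The patch applies one idiom at every step: cache() the newly derived graph, force it to materialize with an action (the count()s in the new materialize helper), and only then unpersist() its predecessor. The ordering matters because Spark evaluates lazily; unpersisting the old graph before the new one is computed would force its lineage to be recomputed. Below is a minimal self-contained sketch of that idiom under stated assumptions: the per-iteration mapVertices update and the toy vertex/edge values are hypothetical stand-ins for the outerJoinVertices steps in SVDPlusPlus, not code from the patch.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.graphx.{Edge, Graph}

    object UnpersistPatternSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setMaster("local[*]").setAppName("UnpersistPatternSketch")
        val sc = new SparkContext(conf)

        val edges = sc.parallelize(Seq(Edge(1L, 2L, 1.0), Edge(2L, 3L, 2.0))).cache()

        var g: Graph[Double, Double] = Graph.fromEdges(edges, 0.0).cache()
        // Materialize the graph before unpersisting its input, so the graph's
        // partitions are computed once rather than rebuilt from the edge RDD later.
        g.vertices.count()
        g.edges.count()
        edges.unpersist()

        for (i <- 0 until 3) {
          // Hypothetical per-iteration update standing in for the
          // outerJoinVertices steps: derive a new cached graph from the old one.
          val next = g.mapVertices((_, v) => v + 1.0).cache()
          next.vertices.count() // materialize the new graph first...
          next.edges.count()
          g.unpersist()         // ...then it is safe to drop the old one
          g = next
        }

        println(g.vertices.collect().mkString(", "))
        sc.stop()
      }
    }

This also shows why the patch switches from reassigning g in place to naming each intermediate graph (gJoinT0 through gJoinT3): the old reference must stay visible after the join so it can be unpersisted.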