You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by an...@apache.org on 2015/02/14 05:13:15 UTC

spark git commit: SPARK-3290 [GRAPHX] No unpersist calls in SVDPlusPlus

Repository: spark
Updated Branches:
  refs/heads/master d06d5ee9b -> 0ce4e430a


SPARK-3290 [GRAPHX] No unpersist calls in SVDPlusPlus

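This change calls unpersist() on each RDD in this code that was cached with cache(). Each newly derived graph is first materialized by counting its vertices and edges, and only then is its predecessor unpersisted. This ordering ensures the cached parent is not dropped before the child has been computed from it.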

Author: Sean Owen <so...@cloudera.com>

Closes #4234 from srowen/SPARK-3290 and squashes the following commits:

66c1e11 [Sean Owen] unpersist() each RDD that was cache()ed


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0ce4e430
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0ce4e430
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0ce4e430

Branch: refs/heads/master
Commit: 0ce4e430a81532dc317136f968f28742e087d840
Parents: d06d5ee
Author: Sean Owen <so...@cloudera.com>
Authored: Fri Feb 13 20:12:52 2015 -0800
Committer: Ankur Dave <an...@gmail.com>
Committed: Fri Feb 13 20:12:52 2015 -0800

----------------------------------------------------------------------
 .../apache/spark/graphx/lib/SVDPlusPlus.scala   | 40 ++++++++++++++++----
 1 file changed, 32 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0ce4e430/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
----------------------------------------------------------------------
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
index f58587e..112ed09 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
@@ -72,17 +72,22 @@ object SVDPlusPlus {
 
     // construct graph
     var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache()
+    materialize(g)
+    edges.unpersist()
 
     // Calculate initial bias and norm
     val t0 = g.aggregateMessages[(Long, Double)](
       ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) },
       (g1, g2) => (g1._1 + g2._1, g1._2 + g2._2))
 
-    g = g.outerJoinVertices(t0) {
+    val gJoinT0 = g.outerJoinVertices(t0) {
       (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double),
        msg: Option[(Long, Double)]) =>
         (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1))
-    }
+    }.cache()
+    materialize(gJoinT0)
+    g.unpersist()
+    g = gJoinT0
 
     def sendMsgTrainF(conf: Conf, u: Double)
         (ctx: EdgeContext[
@@ -114,12 +119,15 @@ object SVDPlusPlus {
       val t1 = g.aggregateMessages[DoubleMatrix](
         ctx => ctx.sendToSrc(ctx.dstAttr._2),
         (g1, g2) => g1.addColumnVector(g2))
-      g = g.outerJoinVertices(t1) {
+      val gJoinT1 = g.outerJoinVertices(t1) {
         (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double),
          msg: Option[DoubleMatrix]) =>
           if (msg.isDefined) (vd._1, vd._1
             .addColumnVector(msg.get.mul(vd._4)), vd._3, vd._4) else vd
-      }
+      }.cache()
+      materialize(gJoinT1)
+      g.unpersist()
+      g = gJoinT1
 
       // Phase 2, update p for user nodes and q, y for item nodes
       g.cache()
@@ -127,13 +135,16 @@ object SVDPlusPlus {
         sendMsgTrainF(conf, u),
         (g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) =>
           (g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3))
-      g = g.outerJoinVertices(t2) {
+      val gJoinT2 = g.outerJoinVertices(t2) {
         (vid: VertexId,
          vd: (DoubleMatrix, DoubleMatrix, Double, Double),
          msg: Option[(DoubleMatrix, DoubleMatrix, Double)]) =>
           (vd._1.addColumnVector(msg.get._1), vd._2.addColumnVector(msg.get._2),
             vd._3 + msg.get._3, vd._4)
-      }
+      }.cache()
+      materialize(gJoinT2)
+      g.unpersist()
+      g = gJoinT2
     }
 
     // calculate error on training set
@@ -147,13 +158,26 @@ object SVDPlusPlus {
       val err = (ctx.attr - pred) * (ctx.attr - pred)
       ctx.sendToDst(err)
     }
+
     g.cache()
     val t3 = g.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _)
-    g = g.outerJoinVertices(t3) {
+    val gJoinT3 = g.outerJoinVertices(t3) {
       (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) =>
         if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd
-    }
+    }.cache()
+    materialize(gJoinT3)
+    g.unpersist()
+    g = gJoinT3
 
     (g, u)
   }
+
+  /**
+   * Forces materialization of a Graph by count()ing its RDDs.
+   */
+  private def materialize(g: Graph[_,_]): Unit = {
+    g.vertices.count()
+    g.edges.count()
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org