Skip to content
Snippets Groups Projects
Commit db574792 authored by Sean Owen, committed by Ankur Dave
Browse files

SPARK-3290 [GRAPHX] No unpersist calls in SVDPlusPlus


This just unpersist()s each RDD in this code that was cache()ed.

Author: Sean Owen <sowen@cloudera.com>

Closes #4234 from srowen/SPARK-3290 and squashes the following commits:

66c1e11 [Sean Owen] unpersist() each RDD that was cache()ed

(cherry picked from commit 0ce4e430)
Signed-off-by: Ankur Dave <ankurdave@gmail.com>
parent 152147f5
No related branches found
No related tags found
No related merge requests found
...@@ -72,17 +72,22 @@ object SVDPlusPlus { ...@@ -72,17 +72,22 @@ object SVDPlusPlus {
// construct graph // construct graph
var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache() var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache()
materialize(g)
edges.unpersist()
// Calculate initial bias and norm // Calculate initial bias and norm
val t0 = g.aggregateMessages[(Long, Double)]( val t0 = g.aggregateMessages[(Long, Double)](
ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) }, ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) },
(g1, g2) => (g1._1 + g2._1, g1._2 + g2._2)) (g1, g2) => (g1._1 + g2._1, g1._2 + g2._2))
g = g.outerJoinVertices(t0) { val gJoinT0 = g.outerJoinVertices(t0) {
(vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double),
msg: Option[(Long, Double)]) => msg: Option[(Long, Double)]) =>
(vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1))
} }.cache()
materialize(gJoinT0)
g.unpersist()
g = gJoinT0
def sendMsgTrainF(conf: Conf, u: Double) def sendMsgTrainF(conf: Conf, u: Double)
(ctx: EdgeContext[ (ctx: EdgeContext[
...@@ -114,12 +119,15 @@ object SVDPlusPlus { ...@@ -114,12 +119,15 @@ object SVDPlusPlus {
val t1 = g.aggregateMessages[DoubleMatrix]( val t1 = g.aggregateMessages[DoubleMatrix](
ctx => ctx.sendToSrc(ctx.dstAttr._2), ctx => ctx.sendToSrc(ctx.dstAttr._2),
(g1, g2) => g1.addColumnVector(g2)) (g1, g2) => g1.addColumnVector(g2))
g = g.outerJoinVertices(t1) { val gJoinT1 = g.outerJoinVertices(t1) {
(vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double),
msg: Option[DoubleMatrix]) => msg: Option[DoubleMatrix]) =>
if (msg.isDefined) (vd._1, vd._1 if (msg.isDefined) (vd._1, vd._1
.addColumnVector(msg.get.mul(vd._4)), vd._3, vd._4) else vd .addColumnVector(msg.get.mul(vd._4)), vd._3, vd._4) else vd
} }.cache()
materialize(gJoinT1)
g.unpersist()
g = gJoinT1
// Phase 2, update p for user nodes and q, y for item nodes // Phase 2, update p for user nodes and q, y for item nodes
g.cache() g.cache()
...@@ -127,13 +135,16 @@ object SVDPlusPlus { ...@@ -127,13 +135,16 @@ object SVDPlusPlus {
sendMsgTrainF(conf, u), sendMsgTrainF(conf, u),
(g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) => (g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) =>
(g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3)) (g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3))
g = g.outerJoinVertices(t2) { val gJoinT2 = g.outerJoinVertices(t2) {
(vid: VertexId, (vid: VertexId,
vd: (DoubleMatrix, DoubleMatrix, Double, Double), vd: (DoubleMatrix, DoubleMatrix, Double, Double),
msg: Option[(DoubleMatrix, DoubleMatrix, Double)]) => msg: Option[(DoubleMatrix, DoubleMatrix, Double)]) =>
(vd._1.addColumnVector(msg.get._1), vd._2.addColumnVector(msg.get._2), (vd._1.addColumnVector(msg.get._1), vd._2.addColumnVector(msg.get._2),
vd._3 + msg.get._3, vd._4) vd._3 + msg.get._3, vd._4)
} }.cache()
materialize(gJoinT2)
g.unpersist()
g = gJoinT2
} }
// calculate error on training set // calculate error on training set
...@@ -147,13 +158,26 @@ object SVDPlusPlus { ...@@ -147,13 +158,26 @@ object SVDPlusPlus {
val err = (ctx.attr - pred) * (ctx.attr - pred) val err = (ctx.attr - pred) * (ctx.attr - pred)
ctx.sendToDst(err) ctx.sendToDst(err)
} }
g.cache() g.cache()
val t3 = g.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _) val t3 = g.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _)
g = g.outerJoinVertices(t3) { val gJoinT3 = g.outerJoinVertices(t3) {
(vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) => (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) =>
if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd
} }.cache()
materialize(gJoinT3)
g.unpersist()
g = gJoinT3
(g, u) (g, u)
} }
/**
 * Forces the given graph to be computed by running a count() action on its
 * vertices and then on its edges. Both counts are discarded; the calls are
 * made only for the evaluation they trigger.
 *
 * @param g the graph whose vertex and edge RDDs should be evaluated
 */
private def materialize(g: Graph[_,_]): Unit = {
  // Vertices are counted first, then edges.
  val vertexRdd = g.vertices
  vertexRdd.count()
  val edgeRdd = g.edges
  edgeRdd.count()
  ()
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment