Repository: spark
Updated Branches:
  refs/heads/master d9670f847 -> 91984978e


[SPARK-13816][GRAPHX] Add parameter checks for algorithms in Graphx

JIRA: https://issues.apache.org/jira/browse/SPARK-13816

## What changes were proposed in this pull request?

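Add parameter checks for the following GraphX algorithms:
Pregel, LabelPropagation, PageRank, SVDPlusPlus.
The existing checks in ConnectedComponents and StronglyConnectedComponents are updated to use the same error-message format.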

## How was this patch tested?

Manual tests.

Author: Zheng RuiFeng <[email protected]>

Closes #11655 from zhengruifeng/graphx_param_check.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/91984978
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/91984978
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/91984978

Branch: refs/heads/master
Commit: 91984978e7c86650d4cf523724b4a1aeaaecf260
Parents: d9670f8
Author: Zheng RuiFeng <[email protected]>
Authored: Wed Mar 16 11:52:25 2016 -0700
Committer: Reynold Xin <[email protected]>
Committed: Wed Mar 16 11:52:25 2016 -0700

----------------------------------------------------------------------
 graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala  | 3 +++
 .../org/apache/spark/graphx/lib/ConnectedComponents.scala   | 4 +++-
 .../org/apache/spark/graphx/lib/LabelPropagation.scala      | 2 ++
 .../main/scala/org/apache/spark/graphx/lib/PageRank.scala   | 9 +++++++++
 .../scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala     | 5 +++++
 .../spark/graphx/lib/StronglyConnectedComponents.scala      | 4 +++-
 6 files changed, 25 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/91984978/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
----------------------------------------------------------------------
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala 
b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index 3ba73b4..efdc248 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -119,6 +119,9 @@ object Pregel extends Logging {
       mergeMsg: (A, A) => A)
     : Graph[VD, ED] =
   {
+    require(maxIterations > 0, s"Maximum of iterations must be greater than 
0," +
+      s" but got ${maxIterations}")
+
     var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, 
initialMsg)).cache()
     // compute the messages
     var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg)

http://git-wip-us.apache.org/repos/asf/spark/blob/91984978/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
----------------------------------------------------------------------
diff --git 
a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala 
b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
index 40cf073..137c512 100644
--- 
a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
+++ 
b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
@@ -36,7 +36,9 @@ object ConnectedComponents {
    */
   def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED],
                                       maxIterations: Int): Graph[VertexId, ED] 
= {
-    require(maxIterations > 0)
+    require(maxIterations > 0, s"Maximum of iterations must be greater than 
0," +
+      s" but got ${maxIterations}")
+
     val ccGraph = graph.mapVertices { case (vid, _) => vid }
     def sendMessage(edge: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, 
VertexId)] = {
       if (edge.srcAttr < edge.dstAttr) {

http://git-wip-us.apache.org/repos/asf/spark/blob/91984978/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala
----------------------------------------------------------------------
diff --git 
a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala 
b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala
index 7a53eca..fc7547a 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala
@@ -43,6 +43,8 @@ object LabelPropagation {
    * @return a graph with vertex attributes containing the label of community 
affiliation
    */
   def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): 
Graph[VertexId, ED] = {
+    require(maxSteps > 0, s"Maximum of steps must be greater than 0, but got 
${maxSteps}")
+
     val lpaGraph = graph.mapVertices { case (vid, _) => vid }
     def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, 
Map[VertexId, Long])] = {
       Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 
1L)))

http://git-wip-us.apache.org/repos/asf/spark/blob/91984978/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
----------------------------------------------------------------------
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala 
b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
index 00ba358..9d9a26e 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
@@ -104,6 +104,11 @@ object PageRank extends Logging {
       graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15,
       srcId: Option[VertexId] = None): Graph[Double, Double] =
   {
+    require(numIter > 0, s"Number of iterations must be greater than 0," +
+      s" but got ${numIter}")
+    require(resetProb >= 0 && resetProb <= 1, s"Random reset probability must 
belong" +
+      s" to [0, 1], but got ${resetProb}")
+
     val personalized = srcId isDefined
     val src: VertexId = srcId.getOrElse(-1L)
 
@@ -197,6 +202,10 @@ object PageRank extends Logging {
       graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15,
       srcId: Option[VertexId] = None): Graph[Double, Double] =
   {
+    require(tol >= 0, s"Tolerance must be no less than 0, but got ${tol}")
+    require(resetProb >= 0 && resetProb <= 1, s"Random reset probability must 
belong" +
+      s" to [0, 1], but got ${resetProb}")
+
     val personalized = srcId.isDefined
     val src: VertexId = srcId.getOrElse(-1L)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/91984978/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
----------------------------------------------------------------------
diff --git 
a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala 
b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
index 78a5cb0..bb2ffab 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
@@ -56,6 +56,11 @@ object SVDPlusPlus {
   def run(edges: RDD[Edge[Double]], conf: Conf)
     : (Graph[(Array[Double], Array[Double], Double, Double), Double], Double) =
   {
+    require(conf.maxIters > 0, s"Maximum of iterations must be greater than 
0," +
+      s" but got ${conf.maxIters}")
+    require(conf.maxVal > conf.minVal, s"MaxVal must be greater than MinVal," +
+      s" but got {maxVal: ${conf.maxVal}, minVal: ${conf.minVal}}")
+
     // Generate default vertex attribute
     def defaultF(rank: Int): (Array[Double], Array[Double], Double, Double) = {
       // TODO: use a fixed random seed

http://git-wip-us.apache.org/repos/asf/spark/blob/91984978/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala
----------------------------------------------------------------------
diff --git 
a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala
 
b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala
index 7063137..1fa92b0 100644
--- 
a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala
+++ 
b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala
@@ -36,7 +36,9 @@ object StronglyConnectedComponents {
    * @return a graph with vertex attributes containing the smallest vertex id 
in each SCC
    */
   def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], numIter: Int): 
Graph[VertexId, ED] = {
-    require(numIter > 0, s"Number of iterations ${numIter} must be greater 
than 0.")
+    require(numIter > 0, s"Number of iterations must be greater than 0," +
+      s" but got ${numIter}")
+
     // the graph we update with final SCC ids, and the graph we return at the 
end
     var sccGraph = graph.mapVertices { case (vid, _) => vid }
     // graph we are going to work with in our iterations


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to