[SPARK-1281] Improve partitioning in ALS

ALS was using HashPartitioner and explicit uses of `%` together.  Further, the 
naked use of `%` meant that, if the number of partitions divided the stride of 
arithmetic progressions appearing in the user and product ids, users and 
products were mapped into buckets in an unfair or unwise way; and since `%` 
yields negative results for negative ids, such ids could not be bucketed 
correctly at all.
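
As a standalone illustration of the failure mode (hypothetical ids, not taken 
from this patch): with 10 partitions and ids generated on a stride of 10, a 
naked `%` piles every id into one bucket, while scrambling through 
scala.util.hashing.byteswap32 (the approach this patch takes) spreads them 
out.  A minimal sketch, inlining an equivalent of Spark's private 
Utils.nonNegativeMod helper:

    import scala.util.hashing.byteswap32

    object ModSkewDemo extends App {
      val numBlocks = 10
      // Hypothetical user ids on a stride of 10, e.g. ids minted in batches.
      val userIds = (0 until 100).map(_ * 10)

      // Equivalent of Spark's private Utils.nonNegativeMod helper.
      def nonNegativeMod(x: Int, mod: Int): Int = {
        val rawMod = x % mod
        rawMod + (if (rawMod < 0) mod else 0)
      }

      // Naked %: every id is a multiple of 10, so all land in bucket 0.
      println(userIds.groupBy(_ % numBlocks).map { case (b, ids) => (b, ids.size) })

      // Scrambled: the same ids spread roughly evenly across the buckets.
      println(userIds.groupBy(id => nonNegativeMod(byteswap32(id), numBlocks))
        .map { case (b, ids) => (b, ids.size) })
    }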

This pull request:
1) Makes the Partitioner an instance variable of ALS.
2) Replaces the direct uses of `%` with calls to a Partitioner.
3) Defines an anonymous Partitioner that scrambles the bits of the object's 
hashCode before reducing to the number of present buckets.

This pull request does not make the partitioner user-configurable.
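
For reference, the anonymous Partitioner introduced in the diff below amounts 
to the following sketch (Utils.nonNegativeMod is Spark-internal, so an 
equivalent is inlined here; numBlocks is assumed to be in scope, as it is 
inside ALS.run):

    import org.apache.spark.Partitioner
    import scala.util.hashing.byteswap32

    val partitioner = new Partitioner {
      val numPartitions = numBlocks

      def getPartition(x: Any): Int = {
        // Scramble the id's bits so strided ids spread across buckets...
        val hash = byteswap32(x.asInstanceOf[Int])
        // ...then reduce into [0, numPartitions); unlike a raw `%`, this
        // never returns a negative index for a negative id.
        val rawMod = hash % numPartitions
        rawMod + (if (rawMod < 0) numPartitions else 0)
      }
    }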

I'm not all that happy about the way I did (1).  It introduces an icky lifetime 
issue and dances around it by nulling something.  However, I don't know a 
better way to make the partitioner visible everywhere it needs to be visible.

Author: Tor Myklebust <tmykl...@gmail.com>

Closes #407 from tmyklebu/master and squashes the following commits:

dcf583a [Tor Myklebust] Remove the partitioner member variable; instead, thread that needle everywhere it needs to go.
23d6f91 [Tor Myklebust] Stop making the partitioner configurable.
495784f [Tor Myklebust] Merge branch 'master' of https://github.com/apache/spark
674933a [Tor Myklebust] Fix style.
40edc23 [Tor Myklebust] Fix missing space.
f841345 [Tor Myklebust] Fix daft bug creating 'pairs', also for -> foreach.
5ec9e6c [Tor Myklebust] Clean a couple of things up using 'map'.
36a0f43 [Tor Myklebust] Make the partitioner private.
d872b09 [Tor Myklebust] Add negative id ALS test.
df27697 [Tor Myklebust] Support custom partitioners.  Currently we use the same partitioner for users and products.
c90b6d8 [Tor Myklebust] Scramble user and product ids before bucketing.
c774d7d [Tor Myklebust] Make the partitioner a member variable and use it instead of modding directly.
(cherry picked from commit bf9d49b6d1f668b49795c2d380ab7d64ec0029da)

Signed-off-by: Patrick Wendell <pwend...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4834adff
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4834adff
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4834adff

Branch: refs/heads/branch-1.0
Commit: 4834adfff26af56094e6bf18d84ca8fe243d8c20
Parents: 4f2f093
Author: Tor Myklebust <tmykl...@gmail.com>
Authored: Tue Apr 22 11:07:30 2014 -0700
Committer: Patrick Wendell <pwend...@gmail.com>
Committed: Tue Apr 22 11:22:30 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/recommendation/ALS.scala | 47 ++++++++++++--------
 .../spark/mllib/recommendation/ALSSuite.scala   | 30 +++++++++++--
 2 files changed, 54 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4834adff/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 1f5c746..60fb73f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable.{ArrayBuffer, BitSet}
 import scala.math.{abs, sqrt}
 import scala.util.Random
 import scala.util.Sorting
+import scala.util.hashing.byteswap32
 
 import com.esotericsoftware.kryo.Kryo
 import org.jblas.{DoubleMatrix, SimpleBlas, Solve}
@@ -32,6 +33,7 @@ import org.apache.spark.storage.StorageLevel
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.KryoRegistrator
 import org.apache.spark.SparkContext._
+import org.apache.spark.util.Utils
 
 /**
  * Out-link information for a user or product block. This includes the original user/product IDs
@@ -169,34 +171,39 @@ class ALS private (
       this.numBlocks
     }
 
-    val partitioner = new HashPartitioner(numBlocks)
+    val partitioner = new Partitioner {
+      val numPartitions = numBlocks
 
-    val ratingsByUserBlock = ratings.map{ rating => (rating.user % numBlocks, rating) }
+      def getPartition(x: Any): Int = {
+        Utils.nonNegativeMod(byteswap32(x.asInstanceOf[Int]), numPartitions)
+      }
+    }
+
+    val ratingsByUserBlock = ratings.map{ rating =>
+      (partitioner.getPartition(rating.user), rating)
+    }
     val ratingsByProductBlock = ratings.map{ rating =>
-      (rating.product % numBlocks, Rating(rating.product, rating.user, rating.rating))
+      (partitioner.getPartition(rating.product),
+        Rating(rating.product, rating.user, rating.rating))
     }
 
-    val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock)
-    val (productInLinks, productOutLinks) = makeLinkRDDs(numBlocks, ratingsByProductBlock)
+    val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock, partitioner)
+    val (productInLinks, productOutLinks) =
+        makeLinkRDDs(numBlocks, ratingsByProductBlock, partitioner)
 
     // Initialize user and product factors randomly, but use a deterministic seed for each
     // partition so that fault recovery works
     val seedGen = new Random(seed)
     val seed1 = seedGen.nextInt()
     val seed2 = seedGen.nextInt()
-    // Hash an integer to propagate random bits at all positions, similar to java.util.HashTable
-    def hash(x: Int): Int = {
-      val r = x ^ (x >>> 20) ^ (x >>> 12)
-      r ^ (r >>> 7) ^ (r >>> 4)
-    }
     var users = userOutLinks.mapPartitionsWithIndex { (index, itr) =>
-      val rand = new Random(hash(seed1 ^ index))
+      val rand = new Random(byteswap32(seed1 ^ index))
       itr.map { case (x, y) =>
         (x, y.elementIds.map(_ => randomFactor(rank, rand)))
       }
     }
     var products = productOutLinks.mapPartitionsWithIndex { (index, itr) =>
-      val rand = new Random(hash(seed2 ^ index))
+      val rand = new Random(byteswap32(seed2 ^ index))
       itr.map { case (x, y) =>
         (x, y.elementIds.map(_ => randomFactor(rank, rand)))
       }
@@ -327,13 +334,14 @@ class ALS private (
    * Make the out-links table for a block of the users (or products) dataset given the list of
    * (user, product, rating) values for the users in that block (or the opposite for products).
    */
-  private def makeOutLinkBlock(numBlocks: Int, ratings: Array[Rating]): OutLinkBlock = {
+  private def makeOutLinkBlock(numBlocks: Int, ratings: Array[Rating],
+      partitioner: Partitioner): OutLinkBlock = {
     val userIds = ratings.map(_.user).distinct.sorted
     val numUsers = userIds.length
     val userIdToPos = userIds.zipWithIndex.toMap
     val shouldSend = Array.fill(numUsers)(new BitSet(numBlocks))
     for (r <- ratings) {
-      shouldSend(userIdToPos(r.user))(r.product % numBlocks) = true
+      shouldSend(userIdToPos(r.user))(partitioner.getPartition(r.product)) = true
     }
     OutLinkBlock(userIds, shouldSend)
   }
@@ -342,14 +350,15 @@ class ALS private (
    * Make the in-links table for a block of the users (or products) dataset given a list of
    * (user, product, rating) values for the users in that block (or the opposite for products).
    */
-  private def makeInLinkBlock(numBlocks: Int, ratings: Array[Rating]): InLinkBlock = {
+  private def makeInLinkBlock(numBlocks: Int, ratings: Array[Rating],
+      partitioner: Partitioner): InLinkBlock = {
     val userIds = ratings.map(_.user).distinct.sorted
     val numUsers = userIds.length
     val userIdToPos = userIds.zipWithIndex.toMap
     // Split out our ratings by product block
     val blockRatings = Array.fill(numBlocks)(new ArrayBuffer[Rating])
     for (r <- ratings) {
-      blockRatings(r.product % numBlocks) += r
+      blockRatings(partitioner.getPartition(r.product)) += r
     }
     val ratingsForBlock = new Array[Array[(Array[Int], Array[Double])]](numBlocks)
     for (productBlock <- 0 until numBlocks) {
@@ -374,14 +383,14 @@ class ALS private (
    * the users (or (blockId, (p, u, r)) for the products). We create these simultaneously to avoid
    * having to shuffle the (blockId, (u, p, r)) RDD twice, or to cache it.
    */
-  private def makeLinkRDDs(numBlocks: Int, ratings: RDD[(Int, Rating)])
+  private def makeLinkRDDs(numBlocks: Int, ratings: RDD[(Int, Rating)], partitioner: Partitioner)
     : (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) =
   {
     val grouped = ratings.partitionBy(new HashPartitioner(numBlocks))
     val links = grouped.mapPartitionsWithIndex((blockId, elements) => {
       val ratings = elements.map{_._2}.toArray
-      val inLinkBlock = makeInLinkBlock(numBlocks, ratings)
-      val outLinkBlock = makeOutLinkBlock(numBlocks, ratings)
+      val inLinkBlock = makeInLinkBlock(numBlocks, ratings, partitioner)
+      val outLinkBlock = makeOutLinkBlock(numBlocks, ratings, partitioner)
       Iterator.single((blockId, (inLinkBlock, outLinkBlock)))
     }, true)
     val inLinks = links.mapValues(_._1)

http://git-wip-us.apache.org/repos/asf/spark/blob/4834adff/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index 5aab9ab..4dfcd4b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -27,6 +27,7 @@ import org.jblas.DoubleMatrix
 
 import org.apache.spark.mllib.util.LocalSparkContext
 import org.apache.spark.SparkContext._
+import org.apache.spark.Partitioner
 
 object ALSSuite {
 
@@ -74,7 +75,6 @@ object ALSSuite {
 
     (sampledRatings, trueRatings, truePrefs)
   }
-
 }
 
 
@@ -128,6 +128,25 @@ class ALSSuite extends FunSuite with LocalSparkContext {
     assert(u11 != u2)
   }
 
+  test("negative ids") {
+    val data = ALSSuite.generateRatings(50, 50, 2, 0.7, false, false)
+    val ratings = sc.parallelize(data._1.map { case Rating(u, p, r) =>
+      Rating(u - 25, p - 25, r)
+    })
+    val correct = data._2
+    val model = ALS.train(ratings, 5, 15)
+
+    val pairs = Array.tabulate(50, 50)((u, p) => (u - 25, p - 25)).flatten
+    val ans = model.predict(sc.parallelize(pairs)).collect()
+    ans.foreach { r =>
+      val u = r.user + 25
+      val p = r.product + 25
+      val v = r.rating
+      val error = v - correct.get(u, p)
+      assert(math.abs(error) < 0.4)
+    }
+  }
+
   /**
    * Test if we can correctly factorize R = U * P where U and P are of known rank.
    *
@@ -140,16 +159,19 @@ class ALSSuite extends FunSuite with LocalSparkContext {
    * @param implicitPrefs  flag to test implicit feedback
    * @param bulkPredict    flag to test bulk prediction
    * @param negativeWeights whether the generated data can contain negative values
+   * @param numBlocks      number of blocks to partition users and products into
    */
   def testALS(users: Int, products: Int, features: Int, iterations: Int,
     samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false,
-    bulkPredict: Boolean = false, negativeWeights: Boolean = false)
+    bulkPredict: Boolean = false, negativeWeights: Boolean = false, numBlocks: Int = -1)
   {
     val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products,
       features, samplingRate, implicitPrefs, negativeWeights)
     val model = implicitPrefs match {
-      case false => ALS.train(sc.parallelize(sampledRatings), features, iterations)
-      case true => ALS.trainImplicit(sc.parallelize(sampledRatings), features, iterations)
+      case false => ALS.train(sc.parallelize(sampledRatings), features, iterations, 0.01,
+          numBlocks, 0L)
+      case true => ALS.trainImplicit(sc.parallelize(sampledRatings), features, iterations, 0.01,
+          numBlocks, 1.0, 0L)
     }
 
     val predictedU = new DoubleMatrix(users, features)
