zhengruifeng commented on a change in pull request #30468:
URL: https://github.com/apache/spark/pull/30468#discussion_r529220280



##########
File path: mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
##########
@@ -456,34 +457,42 @@ class ALSModel private[ml] (
       num: Int,
       blockSize: Int): DataFrame = {
     import srcFactors.sparkSession.implicits._
+    import ALSModel.TopSelector
 
     val srcFactorsBlocked = blockify(srcFactors.as[(Int, Array[Float])], 
blockSize)
     val dstFactorsBlocked = blockify(dstFactors.as[(Int, Array[Float])], 
blockSize)
-    val ratings = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
-      .as[(Seq[(Int, Array[Float])], Seq[(Int, Array[Float])])]
-      .flatMap { case (srcIter, dstIter) =>
-        val m = srcIter.size
-        val n = math.min(dstIter.size, num)
-        val output = new Array[(Int, Int, Float)](m * n)
-        var i = 0
-        val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2))
-        srcIter.foreach { case (srcId, srcFactor) =>
-          dstIter.foreach { case (dstId, dstFactor) =>
-            // We use F2jBLAS which is faster than a call to native BLAS for 
vector dot product
-            val score = BLAS.f2jBLAS.sdot(rank, srcFactor, 1, dstFactor, 1)
-            pq += dstId -> score
+    val partialRecs = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
+      .as[(Array[Int], Array[Float], Array[Int], Array[Float])]
+      .mapPartitions { iter =>
+        var buffer: Array[Float] = null
+        var selector: TopSelector = null
+        iter.flatMap { case (srcIds, srcMat, dstIds, dstMat) =>
+          require(srcMat.length == srcIds.length * rank)
+          require(dstMat.length == dstIds.length * rank)
+          val m = srcIds.length
+          val n = dstIds.length
+          if (buffer == null || buffer.length < n) {
+            buffer = Array.ofDim[Float](n)
+            selector = new TopSelector(buffer)
           }
-          pq.foreach { case (dstId, score) =>
-            output(i) = (srcId, dstId, score)
-            i += 1
+
+          Iterator.tabulate(m) { i =>
+            // buffer = i-th vec in srcMat * dstMat
+            BLAS.f2jBLAS.sgemv("T", rank, n, 1.0F, dstMat, 0, rank,
+              srcMat, i * rank, 1, 0.0F, buffer, 0, 1)
+            val indices = selector.selectTopKIndices(Iterator.range(0, n), num)
+            (srcIds(i), indices.map(dstIds), indices.map(buffer))
           }
-          pq.clear()
+        } ++ {

Review comment:
       It is there to make sure `buffer` is marked ready for GC, but it does not really matter, so I will remove it.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to