MAHOUT-1701: Flink: bug with AtB

Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/137b3b84
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/137b3b84
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/137b3b84

Branch: refs/heads/flink-binding
Commit: 137b3b840eff74f0833d5678a55af2dde4a316a7
Parents: 19708f4
Author: Alexey Grigorev <[email protected]>
Authored: Thu Aug 27 15:55:01 2015 +0200
Committer: Alexey Grigorev <[email protected]>
Committed: Fri Sep 25 17:47:46 2015 +0200

----------------------------------------------------------------------
 .../mahout/flinkbindings/blas/FlinkOpAtB.scala  | 28 ++++++++------------
 .../mahout/flinkbindings/RLikeOpsSuite.scala    | 10 +++++++
 2 files changed, 21 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/mahout/blob/137b3b84/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
index ebb1064..362c62f 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
@@ -57,12 +57,12 @@ object FlinkOpAtB {
     val joined = rowsAt.join(rowsB).where(joiner).equalTo(joiner)
 
     val ncol = op.ncol
-    val nrow = op.nrow
+    val nrow = op.nrow.toInt
     val blockHeight = 10
-    val blockCount = safeToNonNegInt((ncol - 1) / blockHeight + 1)
+    val blockCount = safeToNonNegInt((nrow - 1) / blockHeight + 1)
 
-    val preProduct: DataSet[(Int, Matrix)] = joined.flatMap(new FlatMapFunction[Tuple2[(_, Vector), (_, Vector)],
-                                                        (Int, Matrix)] {
+    val preProduct: DataSet[(Int, Matrix)] = 
+             joined.flatMap(new FlatMapFunction[Tuple2[(_, Vector), (_, Vector)], (Int, Matrix)] {
       def flatMap(in: Tuple2[(_, Vector), (_, Vector)],
                   out: Collector[(Int, Matrix)]): Unit = {
         val avec = in.f0._2
@@ -70,35 +70,29 @@ object FlinkOpAtB {
 
         0.until(blockCount) map { blockKey =>
           val blockStart = blockKey * blockHeight
-          val blockEnd = Math.min(nrow.toInt, blockStart + blockHeight)
+          val blockEnd = Math.min(nrow, blockStart + blockHeight)
 
-          // Create block by cross product of proper slice of aRow and qRow
           val outer = avec(blockStart until blockEnd) cross bvec
-          out.collect((blockKey, outer))
+          out.collect(blockKey -> outer)
         }
       }
     })
 
-    val res: BlockifiedDrmDataSet[Int] = preProduct.groupBy(selector[Matrix, Int]).reduceGroup(
-            new GroupReduceFunction[(Int, Matrix), BlockifiedDrmTuple[Int]] {
+    val res: BlockifiedDrmDataSet[Int] = 
+      preProduct.groupBy(selector[Matrix, Int])
+                .reduceGroup(new GroupReduceFunction[(Int, Matrix), BlockifiedDrmTuple[Int]] {
       def reduce(values: Iterable[(Int, Matrix)], out: Collector[BlockifiedDrmTuple[Int]]): Unit = {
         val it = Lists.newArrayList(values).asScala
         val (idx, _) = it.head
 
-        val block = it.map(t => t._2).reduce((m1, m2) => m1 + m2)
+        val block = it.map { t => t._2 }.reduce { (m1, m2) => m1 + m2 }
 
         val keys = idx.until(block.nrow).toArray[Int]
-        out.collect((keys, block))
+        out.collect(keys -> block)
       }
     })
 
     new BlockifiedFlinkDrm(res, ncol)
   }
 
-}
-
-class DrmTupleToFlinkTupleMapper[K: ClassTag] extends MapFunction[(K, Vector), Tuple2[Int, Vector]] {
-  def map(tuple: (K, Vector)): Tuple2[Int, Vector] = tuple match {
-    case (key, vec) => new Tuple2[Int, Vector](key.asInstanceOf[Int], vec)
-  }
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/mahout/blob/137b3b84/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
----------------------------------------------------------------------
diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
index 225a956..98318e3 100644
--- a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
+++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
@@ -115,6 +115,16 @@ class RLikeOpsSuite extends FunSuite with DistributedFlinkSuite {
     assert((res.collect - expected).norm < 1e-6)
   }
 
+  test("A %*% B.t test 2") {
+    val mxA = Matrices.symmetricUniformView(10, 7, 80085)
+    val mxB = Matrices.symmetricUniformView(30, 7, 31337)
+    val A = drmParallelize(mxA, 3)
+    val B = drmParallelize(mxB, 4)
+
+    val ABt = (A %*% B.t).collect
+    (ABt - mxA %*% mxB.t).norm should be < 1e-7
+  }
+
   test("ABt test") {
     val mxX = dense((1, 2), (2, 3), (3, 4), (5, 6), (7, 8))
     val mxY = dense((1, 2), (2, 3), (3, 4), (5, 6), (7, 8),

Reply via email to