MAHOUT-1701: Flink: bug with AtB
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/137b3b84 Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/137b3b84 Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/137b3b84 Branch: refs/heads/flink-binding Commit: 137b3b840eff74f0833d5678a55af2dde4a316a7 Parents: 19708f4 Author: Alexey Grigorev <[email protected]> Authored: Thu Aug 27 15:55:01 2015 +0200 Committer: Alexey Grigorev <[email protected]> Committed: Fri Sep 25 17:47:46 2015 +0200 ---------------------------------------------------------------------- .../mahout/flinkbindings/blas/FlinkOpAtB.scala | 28 ++++++++------------ .../mahout/flinkbindings/RLikeOpsSuite.scala | 10 +++++++ 2 files changed, 21 insertions(+), 17 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mahout/blob/137b3b84/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala ---------------------------------------------------------------------- diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala index ebb1064..362c62f 100644 --- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala +++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala @@ -57,12 +57,12 @@ object FlinkOpAtB { val joined = rowsAt.join(rowsB).where(joiner).equalTo(joiner) val ncol = op.ncol - val nrow = op.nrow + val nrow = op.nrow.toInt val blockHeight = 10 - val blockCount = safeToNonNegInt((ncol - 1) / blockHeight + 1) + val blockCount = safeToNonNegInt((nrow - 1) / blockHeight + 1) - val preProduct: DataSet[(Int, Matrix)] = joined.flatMap(new FlatMapFunction[Tuple2[(_, Vector), (_, Vector)], - (Int, Matrix)] { + val preProduct: DataSet[(Int, Matrix)] = + joined.flatMap(new FlatMapFunction[Tuple2[(_, Vector), (_, Vector)], (Int, Matrix)] { def flatMap(in: Tuple2[(_, Vector), (_, Vector)], out: Collector[(Int, Matrix)]): Unit = { val avec = in.f0._2 @@ -70,35 +70,29 @@ object FlinkOpAtB { 0.until(blockCount) map { blockKey => val blockStart = blockKey * blockHeight - val blockEnd = Math.min(nrow.toInt, blockStart + blockHeight) + val blockEnd = Math.min(nrow, blockStart + blockHeight) - // Create block by cross product of proper slice of aRow and qRow val outer = avec(blockStart until blockEnd) cross bvec - out.collect((blockKey, outer)) + out.collect(blockKey -> outer) } } }) - val res: BlockifiedDrmDataSet[Int] = preProduct.groupBy(selector[Matrix, Int]).reduceGroup( - new GroupReduceFunction[(Int, Matrix), BlockifiedDrmTuple[Int]] { + val res: BlockifiedDrmDataSet[Int] = + preProduct.groupBy(selector[Matrix, Int]) + .reduceGroup(new GroupReduceFunction[(Int, Matrix), BlockifiedDrmTuple[Int]] { def reduce(values: Iterable[(Int, Matrix)], out: Collector[BlockifiedDrmTuple[Int]]): Unit = { val it = Lists.newArrayList(values).asScala val (idx, _) = it.head - val block = it.map(t => t._2).reduce((m1, m2) => m1 + m2) + val block = it.map { t => t._2 }.reduce { (m1, m2) => m1 + m2 } val keys = idx.until(block.nrow).toArray[Int] - out.collect((keys, block)) + out.collect(keys -> block) } }) new BlockifiedFlinkDrm(res, ncol) } -} - -class DrmTupleToFlinkTupleMapper[K: ClassTag] extends MapFunction[(K, Vector), Tuple2[Int, Vector]] { - def map(tuple: (K, Vector)): Tuple2[Int, Vector] = tuple match { - case (key, vec) => new Tuple2[Int, Vector](key.asInstanceOf[Int], vec) - } } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/137b3b84/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala ---------------------------------------------------------------------- diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala index 225a956..98318e3 100644 --- a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala +++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala @@ -115,6 +115,16 @@ class RLikeOpsSuite extends FunSuite with DistributedFlinkSuite { assert((res.collect - expected).norm < 1e-6) } + test("A %*% B.t test 2") { + val mxA = Matrices.symmetricUniformView(10, 7, 80085) + val mxB = Matrices.symmetricUniformView(30, 7, 31337) + val A = drmParallelize(mxA, 3) + val B = drmParallelize(mxB, 4) + + val ABt = (A %*% B.t).collect + (ABt - mxA %*% mxB.t).norm should be < 1e-7 + } + test("ABt test") { val mxX = dense((1, 2), (2, 3), (3, 4), (5, 6), (7, 8)) val mxY = dense((1, 2), (2, 3), (3, 4), (5, 6), (7, 8),
