MAHOUT-1710: Flink: A times incoreB operator

Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/522f3d51
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/522f3d51
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/522f3d51

Branch: refs/heads/flink-binding
Commit: 522f3d516cefd44272cdbf5b4b4f52574200dd6c
Parents: de7a75f
Author: Alexey Grigorev <[email protected]>
Authored: Tue May 26 15:51:36 2015 +0200
Committer: Alexey Grigorev <[email protected]>
Committed: Fri Sep 25 17:41:44 2015 +0200

----------------------------------------------------------------------
 .../mahout/flinkbindings/FlinkEngine.scala      |  2 +
 .../blas/FlinkOpTimesRightMatrix.scala          | 40 ++++++++++++++++++++
 .../mahout/flinkbindings/RLikeOpsSuite.scala    | 12 ++++++
 .../mahout/flinkbindings/blas/LATestSuit.scala  | 17 ++++++++-
 4 files changed, 70 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/mahout/blob/522f3d51/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
index a7082d1..03d1a9c 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
@@ -97,6 +97,8 @@ object FlinkEngine extends DistributedEngine {
       val aTranslated = flinkTranslate(aInt)
       FlinkOpAtB.notZippable(opAtB, aTranslated, aTranslated)
     }
+    case op @ OpTimesRightMatrix(a, b) => 
+      FlinkOpTimesRightMatrix.drmTimesInCore(op, flinkTranslate(a)(op.classTagA), b)
     case op @ OpAewScalar(a, scalar, _) => 
       FlinkOpAewScalar.opScalarNoSideEffect(op, flinkTranslate(a)(op.classTagA), scalar)
     case op @ OpAewB(a, b, _) =>

http://git-wip-us.apache.org/repos/asf/mahout/blob/522f3d51/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpTimesRightMatrix.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpTimesRightMatrix.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpTimesRightMatrix.scala
new file mode 100644
index 0000000..e26ee7d
--- /dev/null
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpTimesRightMatrix.scala
@@ -0,0 +1,40 @@
+package org.apache.mahout.flinkbindings.blas
+
+import scala.reflect.ClassTag
+import org.apache.mahout.math.drm.logical.OpTimesRightMatrix
+import org.apache.mahout.flinkbindings.drm.FlinkDrm
+import org.apache.mahout.math.Matrix
+import org.apache.mahout.math.DiagonalMatrix
+import org.apache.flink.api.common.functions.RichMapFunction
+import org.apache.flink.configuration.Configuration
+import org.apache.mahout.flinkbindings.drm.BlockifiedFlinkDrm
+import org.apache.mahout.math._
+import scalabindings._
+import RLikeOps._
+
+object FlinkOpTimesRightMatrix {
+
+  def drmTimesInCore[K: ClassTag](op: OpTimesRightMatrix[K], A: FlinkDrm[K], inCoreB: Matrix): FlinkDrm[K] = {
+    implicit val ctx = A.context
+
+    val singletonDataSetB = ctx.env.fromElements(inCoreB)
+
+    val res = A.blockify.ds.map(new RichMapFunction[(Array[K], Matrix), (Array[K], Matrix)] {
+      var inCoreB: Matrix = null
+
+      override def open(params: Configuration): Unit = {
+        val runtime = this.getRuntimeContext()
+        val dsB: java.util.List[Matrix] = runtime.getBroadcastVariable("matrix")
+        inCoreB = dsB.get(0)
+      }
+
+      override def map(tuple: (Array[K], Matrix)): (Array[K], Matrix) = tuple match {
+        case (keys, block_A) => (keys, block_A %*% inCoreB)
+      }
+
+    }).withBroadcastSet(singletonDataSetB, "matrix")
+
+    new BlockifiedFlinkDrm(res, op.ncol)
+  }
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/mahout/blob/522f3d51/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
----------------------------------------------------------------------
diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
index 835cf68..a0da308 100644
--- a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
+++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
@@ -221,4 +221,16 @@ class RLikeOpsSuite extends FunSuite with DistributedFlinkSuit {
     assert((res.collect - expected).norm < 1e-6)
   }
 
+  test("A %*% inCoreB") {
+    val inCoreA = dense((1, 2), (2, 3), (3, 4)).t
+    val inCoreB = dense((1, 2), (3, 4), (11, 4))
+
+    val A = drmParallelize(m = inCoreA, numPartitions = 2)
+
+    val res = A %*% inCoreB
+
+    val expected = inCoreA %*% inCoreB
+    assert((res.collect - expected).norm < 1e-6)
+  }
+
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/mahout/blob/522f3d51/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala
----------------------------------------------------------------------
diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala
index dde6402..6706599 100644
--- a/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala
+++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala
@@ -12,7 +12,6 @@ import org.apache.mahout.math.drm.logical.OpAx
 import org.apache.mahout.flinkbindings.drm.CheckpointedFlinkDrm
 import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm
 import org.apache.mahout.math.drm.logical._
-import scala.collection.immutable.Range
 
 @RunWith(classOf[JUnitRunner])
 class LATestSuit extends FunSuite with DistributedFlinkSuit {
@@ -122,4 +121,20 @@ class LATestSuit extends FunSuite with DistributedFlinkSuit {
     assert((output - expected).norm < 1e-6)
   }
 
+  test("A times inCoreB") {
+    val inCoreA = dense((1, 2, 3), (2, 3, 1), (3, 4, 4), (4, 4, 5), (5, 5, 7), (6, 7, 11))
+    val inCoreB = dense((2, 1), (3, 4), (5, 11))
+    val A = drmParallelize(m = inCoreA, numPartitions = 2)
+
+    val op = new OpTimesRightMatrix(A, inCoreB)
+    val res = FlinkOpTimesRightMatrix.drmTimesInCore(op, A, inCoreB)
+
+    val drm = new CheckpointedFlinkDrm(res.deblockify.ds, _nrow=op.nrow, 
+        _ncol=op.ncol)
+    val output = drm.collect
+
+    val expected = inCoreA %*% inCoreB
+    assert((output - expected).norm < 1e-6)
+  }
+
 }
\ No newline at end of file

Reply via email to