MAHOUT-1703: Flink: cbind with scalar

Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/f26245b8
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/f26245b8
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/f26245b8

Branch: refs/heads/flink-binding
Commit: f26245b8f06d2c05a818f47c9162df6f4d7f7b67
Parents: 9d48487
Author: Alexey Grigorev <[email protected]>
Authored: Fri Aug 21 16:39:25 2015 +0200
Committer: Alexey Grigorev <[email protected]>
Committed: Fri Sep 25 17:43:41 2015 +0200

----------------------------------------------------------------------
 .../mahout/flinkbindings/FlinkEngine.scala      |  3 ++
 .../flinkbindings/blas/FlinkOpCBind.scala       | 43 +++++++++++++++++---
 .../mahout/flinkbindings/RLikeOpsSuite.scala    | 18 ++++++++
 .../mahout/flinkbindings/blas/LATestSuite.scala | 30 ++++++++++++++
 4 files changed, 89 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/mahout/blob/f26245b8/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
----------------------------------------------------------------------
diff --git 
a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala 
b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
index 8e47629..c7bea7b 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
@@ -78,6 +78,7 @@ import org.apache.mahout.math.indexeddataset.Schema
 import org.apache.mahout.math.scalabindings._
 import org.apache.mahout.math.scalabindings.RLikeOps._
 import org.apache.mahout.flinkbindings.blas.FlinkOpAtA
+import org.apache.mahout.math.drm.logical.OpCbindScalar
 
 object FlinkEngine extends DistributedEngine {
 
@@ -177,6 +178,8 @@ object FlinkEngine extends DistributedEngine {
       FlinkOpCBind.cbind(op, flinkTranslate(a)(op.classTagA), 
flinkTranslate(b)(op.classTagA))
     case op @ OpRbind(a, b) => 
       FlinkOpRBind.rbind(op, flinkTranslate(a)(op.classTagA), 
flinkTranslate(b)(op.classTagA))
+    case op @ OpCbindScalar(a, x, _) => 
+      FlinkOpCBind.cbindScalar(op, flinkTranslate(a)(op.classTagA), x)
     case op @ OpRowRange(a, _) => 
       FlinkOpRowRange.slice(op, flinkTranslate(a)(op.classTagA))
     case op: OpMapBlock[K, _] => 

http://git-wip-us.apache.org/repos/asf/mahout/blob/f26245b8/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala
----------------------------------------------------------------------
diff --git 
a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala 
b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala
index 88155d6..27237d6 100644
--- 
a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala
+++ 
b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala
@@ -20,18 +20,19 @@ package org.apache.mahout.flinkbindings.blas
 
 import java.lang.Iterable
 
+
 import scala.collection.JavaConverters._
+import scala.collection.JavaConversions._
 import scala.reflect.ClassTag
 
 import org.apache.flink.api.common.functions.CoGroupFunction
+import org.apache.flink.api.common.functions.MapFunction
 import org.apache.flink.api.java.DataSet
 import org.apache.flink.util.Collector
-import org.apache.mahout.flinkbindings.drm.FlinkDrm
-import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm
-import org.apache.mahout.math.DenseVector
-import org.apache.mahout.math.SequentialAccessSparseVector
-import org.apache.mahout.math.Vector
+import org.apache.mahout.flinkbindings.drm._
+import org.apache.mahout.math._
 import org.apache.mahout.math.drm.logical.OpCbind
+import org.apache.mahout.math.drm.logical.OpCbindScalar
 import org.apache.mahout.math.scalabindings.RLikeOps._
 
 import com.google.common.collect.Lists
@@ -99,4 +100,36 @@ object FlinkOpCBind {
     new RowsFlinkDrm(res.asInstanceOf[DataSet[(K, Vector)]], ncol=op.ncol)
   }
 
+  /**
+   * Physical operator for binding a scalar column to a matrix: every block of `A` gets one
+   * additional column in which each cell holds `x`.
+   *
+   * @param op logical operator; `op.leftBind` selects whether the new column is placed first
+   *           (true) or last (false), and `op.ncol` is passed on as the width of the result
+   * @param A  input matrix, processed in its blockified form
+   * @param x  value written into every cell of the new column
+   * @return a blockified DRM whose blocks keep the row keys of the input blocks
+   */
+  def cbindScalar[K: ClassTag](op: OpCbindScalar[K], A: FlinkDrm[K],
+      x: Double): FlinkDrm[K] = {
+    val left = op.leftBind
+    // Blockify exactly once and map over that data set.
+    val ds = A.blockify.ds
+
+    val out = ds.map(new MapFunction[(Array[K], Matrix), (Array[K], Matrix)] {
+      def map(tuple: (Array[K], Matrix)): (Array[K], Matrix) = tuple match {
+        case (keys, mat) => (keys, cbind(mat, x, left))
+      }
+
+      /** Copies `mat` into a matrix that is one column wider and fills the extra column with `x`. */
+      def cbind(mat: Matrix, x: Double, left: Boolean): Matrix = {
+        val ncol = mat.ncol
+        val newMat = mat.like(mat.nrow, ncol + 1)
+
+        // Index of the scalar column, and the column range that receives the original row.
+        val (scalarCol, dataCols) =
+          if (left) (0, 1 to ncol) else (ncol, 0 to (ncol - 1))
+
+        newMat.zip(mat).foreach { case (newVec, origVec) =>
+          newVec(scalarCol) = x
+          newVec(dataCols) := origVec
+        }
+
+        newMat
+      }
+    })
+
+    new BlockifiedFlinkDrm(out, op.ncol)
+  }
+
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/mahout/blob/f26245b8/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
----------------------------------------------------------------------
diff --git 
a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala 
b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
index 99f6718..77ecaf4 100644
--- a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
+++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
@@ -206,6 +206,24 @@ class RLikeOpsSuite extends FunSuite with 
DistributedFlinkSuite {
     assert((res.collect - expected).norm < 1e-6)
   }
 
+  // Scalar on the left-hand side: the expected result has a column of ones in front of A.
+  test("1 cbind A") {
+    val local = dense((1, 2), (2, 3), (3, 4))
+    val drmA = drmParallelize(m = local, numPartitions = 2)
+
+    val expected = dense((1, 1, 2), (1, 2, 3), (1, 3, 4))
+    val bound = 1 cbind drmA
+
+    assert((bound.collect - expected).norm < 1e-6)
+  }
+
+  // Scalar on the right-hand side: the expected result has a column of ones after A.
+  test("A cbind 1") {
+    val local = dense((1, 2), (2, 3), (3, 4))
+    val drmA = drmParallelize(m = local, numPartitions = 2)
+
+    val expected = dense((1, 2, 1), (2, 3, 1), (3, 4, 1))
+    val bound = drmA cbind 1
+
+    assert((bound.collect - expected).norm < 1e-6)
+  }
+
   test("A rbind B") {
     val inCoreA = dense((1, 2), (2, 3), (3, 4))
     val inCoreB = dense((1, 2), (3, 4), (11, 4))

http://git-wip-us.apache.org/repos/asf/mahout/blob/f26245b8/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuite.scala
----------------------------------------------------------------------
diff --git 
a/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuite.scala 
b/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuite.scala
index 2db7f91..42c1f63 100644
--- 
a/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuite.scala
+++ 
b/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuite.scala
@@ -123,6 +123,36 @@ class LATestSuite extends FunSuite with 
DistributedFlinkSuite {
     assert((output - expected).norm < 1e-6)
   }
 
+  // Calls FlinkOpCBind.cbindScalar directly; the expected matrix carries the scalar in its
+  // first column.
+  test("CbindScalar left") {
+    val local = dense((1, 2), (2, 3), (3, 4))
+    val drmA = drmParallelize(m = local, numPartitions = 2)
+
+    val logicalOp = new OpCbindScalar(drmA, 1, true)
+    val bound = FlinkOpCBind.cbindScalar(logicalOp, drmA, 1)
+
+    // Wrap the row-wise data set so that the result can be collected in-core.
+    val drm = new CheckpointedFlinkDrm(
+      bound.deblockify.ds,
+      _nrow = local.nrow,
+      _ncol = local.ncol + 1)
+    val output = drm.collect
+
+    val expected = dense((1, 1, 2), (1, 2, 3), (1, 3, 4))
+    assert((output - expected).norm < 1e-6)
+  }
+
+  // Calls FlinkOpCBind.cbindScalar directly; the expected matrix carries the scalar in its
+  // last column.
+  test("CbindScalar right") {
+    val local = dense((1, 2), (2, 3), (3, 4))
+    val drmA = drmParallelize(m = local, numPartitions = 2)
+
+    val logicalOp = new OpCbindScalar(drmA, 1, false)
+    val bound = FlinkOpCBind.cbindScalar(logicalOp, drmA, 1)
+
+    // Wrap the row-wise data set so that the result can be collected in-core.
+    val drm = new CheckpointedFlinkDrm(
+      bound.deblockify.ds,
+      _nrow = local.nrow,
+      _ncol = local.ncol + 1)
+    val output = drm.collect
+
+    val expected = dense((1, 2, 1), (2, 3, 1), (3, 4, 1))
+    assert((output - expected).norm < 1e-6)
+  }
+
   test("slice") {
     val inCoreA = dense((1, 2), (2, 3), (3, 4), (4, 4), (5, 5), (6, 7))
     val A = drmParallelize(m = inCoreA, numPartitions = 2)

Reply via email to