MAHOUT-1703: Flink: cbind with scalar
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/f26245b8 Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/f26245b8 Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/f26245b8 Branch: refs/heads/flink-binding Commit: f26245b8f06d2c05a818f47c9162df6f4d7f7b67 Parents: 9d48487 Author: Alexey Grigorev <[email protected]> Authored: Fri Aug 21 16:39:25 2015 +0200 Committer: Alexey Grigorev <[email protected]> Committed: Fri Sep 25 17:43:41 2015 +0200 ---------------------------------------------------------------------- .../mahout/flinkbindings/FlinkEngine.scala | 3 ++ .../flinkbindings/blas/FlinkOpCBind.scala | 43 +++++++++++++++++--- .../mahout/flinkbindings/RLikeOpsSuite.scala | 18 ++++++++ .../mahout/flinkbindings/blas/LATestSuite.scala | 30 ++++++++++++++ 4 files changed, 89 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mahout/blob/f26245b8/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala ---------------------------------------------------------------------- diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala index 8e47629..c7bea7b 100644 --- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala +++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala @@ -78,6 +78,7 @@ import org.apache.mahout.math.indexeddataset.Schema import org.apache.mahout.math.scalabindings._ import org.apache.mahout.math.scalabindings.RLikeOps._ import org.apache.mahout.flinkbindings.blas.FlinkOpAtA +import org.apache.mahout.math.drm.logical.OpCbindScalar object FlinkEngine extends DistributedEngine { @@ -177,6 +178,8 @@ object FlinkEngine extends DistributedEngine { FlinkOpCBind.cbind(op, flinkTranslate(a)(op.classTagA), 
flinkTranslate(b)(op.classTagA)) case op @ OpRbind(a, b) => FlinkOpRBind.rbind(op, flinkTranslate(a)(op.classTagA), flinkTranslate(b)(op.classTagA)) + case op @ OpCbindScalar(a, x, _) => + FlinkOpCBind.cbindScalar(op, flinkTranslate(a)(op.classTagA), x) case op @ OpRowRange(a, _) => FlinkOpRowRange.slice(op, flinkTranslate(a)(op.classTagA)) case op: OpMapBlock[K, _] => http://git-wip-us.apache.org/repos/asf/mahout/blob/f26245b8/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala ---------------------------------------------------------------------- diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala index 88155d6..27237d6 100644 --- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala +++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala @@ -20,18 +20,19 @@ package org.apache.mahout.flinkbindings.blas import java.lang.Iterable + import scala.collection.JavaConverters._ +import scala.collection.JavaConversions._ import scala.reflect.ClassTag import org.apache.flink.api.common.functions.CoGroupFunction +import org.apache.flink.api.common.functions.MapFunction import org.apache.flink.api.java.DataSet import org.apache.flink.util.Collector -import org.apache.mahout.flinkbindings.drm.FlinkDrm -import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm -import org.apache.mahout.math.DenseVector -import org.apache.mahout.math.SequentialAccessSparseVector -import org.apache.mahout.math.Vector +import org.apache.mahout.flinkbindings.drm._ +import org.apache.mahout.math._ import org.apache.mahout.math.drm.logical.OpCbind +import org.apache.mahout.math.drm.logical.OpCbindScalar import org.apache.mahout.math.scalabindings.RLikeOps._ import com.google.common.collect.Lists @@ -99,4 +100,36 @@ object FlinkOpCBind { new RowsFlinkDrm(res.asInstanceOf[DataSet[(K, Vector)]], ncol=op.ncol) } + 
/** Column-binds the scalar x onto every block of A: blockifies A, maps each (keys, matrix) block to (keys, widened matrix) with x in column 0 when op.leftBind is true and in the last column otherwise, and wraps the result in a BlockifiedFlinkDrm with op.ncol columns. */ def cbindScalar[K: ClassTag](op: OpCbindScalar[K], A: FlinkDrm[K], x: Double): FlinkDrm[K] = { + val left = op.leftBind + val ds = A.blockify.ds /* NOTE(review): ds is never read below; the next statement evaluates A.blockify.ds a second time - consider mapping over ds instead. */ + + val out = A.blockify.ds.map(new MapFunction[(Array[K], Matrix), (Array[K], Matrix)] { + def map(tuple: (Array[K], Matrix)): (Array[K], Matrix) = tuple match { + case (keys, mat) => (keys, cbind(mat, x, left)) + } + + /* Creates a matrix like mat with one extra column; every row receives x in the new column and the original row values in the remaining ncol columns. NOTE(review): zip on a Matrix presumably relies on the implicit scala.collection.JavaConversions import added by this patch - confirm. */ def cbind(mat: Matrix, x: Double, left: Boolean): Matrix = { + val ncol = mat.ncol + val newMat = mat.like(mat.nrow, ncol + 1) + + if (left) { + newMat.zip(mat).foreach { case (newVec, origVec) => + newVec(0) = x + newVec(1 to ncol) := origVec + } + } else { + newMat.zip(mat).foreach { case (newVec, origVec) => + newVec(ncol) = x + newVec(0 to (ncol - 1)) := origVec + } + } + + newMat + } + }) + + new BlockifiedFlinkDrm(out, op.ncol) + } + } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/f26245b8/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala ---------------------------------------------------------------------- diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala index 99f6718..77ecaf4 100644 --- a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala +++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala @@ -206,6 +206,24 @@ class RLikeOpsSuite extends FunSuite with DistributedFlinkSuite { assert((res.collect - expected).norm < 1e-6) } + /* Scalar on the left of cbind: the expected matrix has a leading column of ones. */ test("1 cbind A") { + val inCoreA = dense((1, 2), (2, 3), (3, 4)) + val A = drmParallelize(m = inCoreA, numPartitions = 2) + + val res = 1 cbind A + val expected = dense((1, 1, 2), (1, 2, 3), (1, 3, 4)) + assert((res.collect - expected).norm < 1e-6) + } + + /* Scalar on the right of cbind: the expected matrix has a trailing column of ones. */ test("A cbind 1") { + val inCoreA = dense((1, 2), (2, 3), (3, 4)) + val A = drmParallelize(m = inCoreA, numPartitions = 2) + + val res = A cbind 1 + val expected = dense((1, 2, 1), (2, 3, 1), (3, 4, 1)) + 
assert((res.collect - expected).norm < 1e-6) + } + test("A rbind B") { val inCoreA = dense((1, 2), (2, 3), (3, 4)) val inCoreB = dense((1, 2), (3, 4), (11, 4)) http://git-wip-us.apache.org/repos/asf/mahout/blob/f26245b8/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuite.scala ---------------------------------------------------------------------- diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuite.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuite.scala index 2db7f91..42c1f63 100644 --- a/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuite.scala +++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuite.scala @@ -123,6 +123,36 @@ class LATestSuite extends FunSuite with DistributedFlinkSuite { assert((output - expected).norm < 1e-6) } + /* Calls FlinkOpCBind.cbindScalar directly with an OpCbindScalar whose third argument is true, deblockifies the result into a CheckpointedFlinkDrm and expects a leading column of ones. */ test("CbindScalar left") { + val inCoreA = dense((1, 2), (2, 3), (3, 4)) + val A = drmParallelize(m = inCoreA, numPartitions = 2) + + val op = new OpCbindScalar(A, 1, true) + val res = FlinkOpCBind.cbindScalar(op, A, 1) + + val drm = new CheckpointedFlinkDrm(res.deblockify.ds, _nrow=inCoreA.nrow, + _ncol=(inCoreA.ncol + 1)) + val output = drm.collect + + val expected = dense((1, 1, 2), (1, 2, 3), (1, 3, 4)) + assert((output - expected).norm < 1e-6) + } + + /* Same direct call with the third OpCbindScalar argument false; expects a trailing column of ones. */ test("CbindScalar right") { + val inCoreA = dense((1, 2), (2, 3), (3, 4)) + val A = drmParallelize(m = inCoreA, numPartitions = 2) + + val op = new OpCbindScalar(A, 1, false) + val res = FlinkOpCBind.cbindScalar(op, A, 1) + + val drm = new CheckpointedFlinkDrm(res.deblockify.ds, _nrow=inCoreA.nrow, + _ncol=(inCoreA.ncol + 1)) + val output = drm.collect + + val expected = dense((1, 2, 1), (2, 3, 1), (3, 4, 1)) + assert((output - expected).norm < 1e-6) + } + test("slice") { val inCoreA = dense((1, 2), (2, 3), (3, 4), (4, 4), (5, 5), (6, 7)) val A = drmParallelize(m = inCoreA, numPartitions = 2)
