MAHOUT-1710: Flink: A times inCoreB operator
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/522f3d51 Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/522f3d51 Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/522f3d51 Branch: refs/heads/flink-binding Commit: 522f3d516cefd44272cdbf5b4b4f52574200dd6c Parents: de7a75f Author: Alexey Grigorev <[email protected]> Authored: Tue May 26 15:51:36 2015 +0200 Committer: Alexey Grigorev <[email protected]> Committed: Fri Sep 25 17:41:44 2015 +0200 ---------------------------------------------------------------------- .../mahout/flinkbindings/FlinkEngine.scala | 2 + .../blas/FlinkOpTimesRightMatrix.scala | 40 ++++++++++++++++++++ .../mahout/flinkbindings/RLikeOpsSuite.scala | 12 ++++++ .../mahout/flinkbindings/blas/LATestSuit.scala | 17 ++++++++- 4 files changed, 70 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mahout/blob/522f3d51/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala ---------------------------------------------------------------------- diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala index a7082d1..03d1a9c 100644 --- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala +++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala @@ -97,6 +97,8 @@ object FlinkEngine extends DistributedEngine { val aTranslated = flinkTranslate(aInt) FlinkOpAtB.notZippable(opAtB, aTranslated, aTranslated) } + case op @ OpTimesRightMatrix(a, b) => + FlinkOpTimesRightMatrix.drmTimesInCore(op, flinkTranslate(a)(op.classTagA), b) case op @ OpAewScalar(a, scalar, _) => FlinkOpAewScalar.opScalarNoSideEffect(op, flinkTranslate(a)(op.classTagA), scalar) case op @ OpAewB(a, b, _) => 
http://git-wip-us.apache.org/repos/asf/mahout/blob/522f3d51/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpTimesRightMatrix.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpTimesRightMatrix.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpTimesRightMatrix.scala
new file mode 100644
index 0000000..e26ee7d
--- /dev/null
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpTimesRightMatrix.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.mahout.flinkbindings.blas
+
+import scala.reflect.ClassTag
+
+import org.apache.flink.api.common.functions.RichMapFunction
+import org.apache.flink.configuration.Configuration
+import org.apache.mahout.flinkbindings.drm.BlockifiedFlinkDrm
+import org.apache.mahout.flinkbindings.drm.FlinkDrm
+import org.apache.mahout.math._
+import org.apache.mahout.math.Matrix
+import org.apache.mahout.math.drm.logical.OpTimesRightMatrix
+import org.apache.mahout.math.scalabindings._
+import org.apache.mahout.math.scalabindings.RLikeOps._
+
+/**
+ * Physical Flink operator for `A %*% inCoreB`: distributed DRM `A` times in-core matrix `inCoreB`.
+ *
+ * `inCoreB` is shipped to every task as a Flink broadcast variable and multiplied against each
+ * block of the blockified `A`; the row keys of each block are passed through unchanged.
+ */
+object FlinkOpTimesRightMatrix {
+
+  /** Name under which `inCoreB` is registered as a broadcast set and looked up again in `open`. */
+  private final val BroadcastMatrixName = "matrix"
+
+  /**
+   * Multiplies every row block of `A` by `inCoreB`.
+   *
+   * @param op      logical operator node; only `op.ncol` is read, as the result's column count
+   * @param A       translated left operand
+   * @param inCoreB in-memory right operand, broadcast to all tasks
+   * @tparam K      row key type of `A`
+   * @return blockified DRM holding `A %*% inCoreB`, keyed like `A`
+   */
+  def drmTimesInCore[K: ClassTag](op: OpTimesRightMatrix[K], A: FlinkDrm[K], inCoreB: Matrix): FlinkDrm[K] = {
+    implicit val ctx = A.context
+
+    // withBroadcastSet takes a data set, so wrap inCoreB in a one-element data set.
+    val singletonDataSetB = ctx.env.fromElements(inCoreB)
+
+    val res = A.blockify.ds.map(new RichMapFunction[(Array[K], Matrix), (Array[K], Matrix)] {
+      // Per-task copy of B, read from the broadcast set in open(). Deliberately named differently
+      // from the method parameter: referencing inCoreB here would capture it in the closure.
+      var broadcastB: Matrix = null
+
+      override def open(params: Configuration): Unit = {
+        val fromBroadcast: java.util.List[Matrix] =
+          getRuntimeContext().getBroadcastVariable(BroadcastMatrixName)
+        broadcastB = fromBroadcast.get(0)
+      }
+
+      override def map(tuple: (Array[K], Matrix)): (Array[K], Matrix) = {
+        val (keys, blockA) = tuple
+        (keys, blockA %*% broadcastB)
+      }
+    }).withBroadcastSet(singletonDataSetB, BroadcastMatrixName)
+
+    new BlockifiedFlinkDrm(res, op.ncol)
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/522f3d51/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala ---------------------------------------------------------------------- diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala index 835cf68..a0da308 100644 --- a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala +++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala @@ -221,4 +221,16 @@ class RLikeOpsSuite extends FunSuite with DistributedFlinkSuit { assert((res.collect - expected).norm < 1e-6) } + test("A %*% inCoreB") { + val inCoreA = dense((1, 2), (2, 3), (3, 4)).t + val inCoreB = dense((1, 2), (3, 4), (11, 4)) + + val A = drmParallelize(m = inCoreA, numPartitions = 2) + + val res = A %*% inCoreB + + val expected = inCoreA %*% inCoreB + assert((res.collect - expected).norm < 1e-6) + } + } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/522f3d51/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala ---------------------------------------------------------------------- diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala index dde6402..6706599 100644 --- a/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala +++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala @@ -12,7 +12,6 @@ import org.apache.mahout.math.drm.logical.OpAx import org.apache.mahout.flinkbindings.drm.CheckpointedFlinkDrm import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm import org.apache.mahout.math.drm.logical._ -import scala.collection.immutable.Range @RunWith(classOf[JUnitRunner]) class LATestSuit extends FunSuite with DistributedFlinkSuit { @@ -122,4 +121,20 @@ class LATestSuit extends FunSuite 
with DistributedFlinkSuit { assert((output - expected).norm < 1e-6) } + test("A times inCoreB") { + val inCoreA = dense((1, 2, 3), (2, 3, 1), (3, 4, 4), (4, 4, 5), (5, 5, 7), (6, 7, 11)) + val inCoreB = dense((2, 1), (3, 4), (5, 11)) + val A = drmParallelize(m = inCoreA, numPartitions = 2) + + val op = new OpTimesRightMatrix(A, inCoreB) + val res = FlinkOpTimesRightMatrix.drmTimesInCore(op, A, inCoreB) + + val drm = new CheckpointedFlinkDrm(res.deblockify.ds, _nrow=op.nrow, + _ncol=op.ncol) + val output = drm.collect + + val expected = inCoreA %*% inCoreB + assert((output - expected).norm < 1e-6) + } + } \ No newline at end of file
