MAHOUT-1701: Flink: AtB implemented, ABt and AtA expressed via AtB
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/f836481b Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/f836481b Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/f836481b Branch: refs/heads/flink-binding Commit: f836481b823a1aaaa70d9ab87c030f60c459de0d Parents: 98d4ff0 Author: Alexey Grigorev <[email protected]> Authored: Tue May 5 20:05:21 2015 +0200 Committer: Alexey Grigorev <[email protected]> Committed: Fri Sep 25 17:41:39 2015 +0200 ---------------------------------------------------------------------- .../mahout/flinkbindings/FlinkEngine.scala | 32 ++++++- .../mahout/flinkbindings/blas/FlinkOpAt.scala | 11 ++- .../mahout/flinkbindings/blas/FlinkOpAtB.scala | 85 +++++++++++++++++++ .../mahout/flinkbindings/blas/package.scala | 15 ++++ .../mahout/flinkbindings/RLikeOpsSuite.scala | 88 +++++++++++--------- .../mahout/flinkbindings/UseCasesSuite.scala | 79 ++++++++++++++++++ .../mahout/flinkbindings/blas/LATestSuit.scala | 24 +++++- 7 files changed, 285 insertions(+), 49 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mahout/blob/f836481b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala ---------------------------------------------------------------------- diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala index 17bf0b6..a124a7c 100644 --- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala +++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala @@ -33,6 +33,10 @@ import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm import org.apache.mahout.math.drm.logical.OpAt import org.apache.mahout.math.drm.logical.OpAtx import org.apache.mahout.math.drm.logical.OpAtx +import org.apache.mahout.math.drm.logical.OpAtB +import org.apache.mahout.math.drm.logical.OpABt +import org.apache.mahout.math.drm.logical.OpAtB +import org.apache.mahout.math.drm.logical.OpAtA object FlinkEngine extends DistributedEngine { @@ -56,14 +60,40 @@ object FlinkEngine extends DistributedEngine { case op @ OpAx(a, x) => FlinkOpAx.blockifiedBroadcastAx(op, flinkTranslate(a)(op.classTagA)) case op @ OpAt(a) => FlinkOpAt.sparseTrick(op, flinkTranslate(a)(op.classTagA)) case op @ OpAtx(a, x) => { + // express Atx as (A.t) %*% x + // TODO: create specific implementation of Atx val opAt = OpAt(a) val at = FlinkOpAt.sparseTrick(opAt, flinkTranslate(a)(op.classTagA)) val atCast = new CheckpointedFlinkDrm(at.deblockify.ds, _nrow=opAt.nrow, _ncol=opAt.ncol) val opAx = OpAx(atCast, x) FlinkOpAx.blockifiedBroadcastAx(opAx, flinkTranslate(atCast)(op.classTagA)) } + case op @ OpAtB(a, b) => FlinkOpAtB.notZippable(op, flinkTranslate(a)(op.classTagA), + flinkTranslate(b)(op.classTagA)) + case op @ OpABt(a, b) => { + // express ABt via AtB: let C=At and D=Bt, and calculate CtD + // TODO: create specific implementation of ABt + val opAt = OpAt(a.asInstanceOf[DrmLike[Int]]) // TODO: casts! + val at = FlinkOpAt.sparseTrick(opAt, flinkTranslate(a.asInstanceOf[DrmLike[Int]])) + val c = new CheckpointedFlinkDrm(at.deblockify.ds, _nrow=opAt.nrow, _ncol=opAt.ncol) + + val opBt = OpAt(b.asInstanceOf[DrmLike[Int]]) // TODO: casts! + val bt = FlinkOpAt.sparseTrick(opBt, flinkTranslate(b.asInstanceOf[DrmLike[Int]])) + val d = new CheckpointedFlinkDrm(bt.deblockify.ds, _nrow=opBt.nrow, _ncol=opBt.ncol) + + FlinkOpAtB.notZippable(OpAtB(c, d), flinkTranslate(c), flinkTranslate(d)) + .asInstanceOf[FlinkDrm[K]] + } + case op @ OpAtA(a) => { + // express AtA via AtB + // TODO: create specific implementation of AtA + val aInt = a.asInstanceOf[DrmLike[Int]] // TODO: casts! + val opAtB = OpAtB(aInt, aInt) + val aTranslated = flinkTranslate(aInt) + FlinkOpAtB.notZippable(opAtB, aTranslated, aTranslated) + } case cp: CheckpointedFlinkDrm[K] => new RowsFlinkDrm(cp.ds, cp.ncol) - case _ => ??? + case _ => throw new NotImplementedError(s"operator $oper is not implemented yet") } http://git-wip-us.apache.org/repos/asf/mahout/blob/f836481b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala ---------------------------------------------------------------------- diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala index be7fc8f..08aea73 100644 --- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala +++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala @@ -20,6 +20,7 @@ import org.apache.flink.api.java.functions.KeySelector import java.util.ArrayList import org.apache.flink.shaded.com.google.common.collect.Lists + /** * Taken from */ @@ -40,7 +41,7 @@ object FlinkOpAt { val columnVector: Vector = new SequentialAccessSparseVector(ncol) keys.zipWithIndex.foreach { case (key, idx) => - columnVector(key) = block(idx, columnIdx) + columnVector(key) = block(idx, columnIdx) } out.collect(new Tuple2(columnIdx, columnVector)) @@ -49,12 +50,10 @@ object FlinkOpAt { } }) - val regrouped = sparseParts.groupBy(new KeySelector[Tuple2[Int, Vector], Integer] { - def getKey(tuple: Tuple2[Int, Vector]): Integer = tuple._1 - }) + val regrouped = sparseParts.groupBy(tuple_1[Vector]) - val sparseTotal = regrouped.reduceGroup(new GroupReduceFunction[Tuple2[Int, Vector], DrmTuple[Int]] { - def reduce(values: Iterable[DrmTuple[Int]], out: Collector[DrmTuple[Int]]): Unit = { + val sparseTotal = regrouped.reduceGroup(new GroupReduceFunction[(Int, Vector), DrmTuple[Int]] { + def reduce(values: Iterable[(Int, Vector)], out: Collector[DrmTuple[Int]]): Unit = { val it = Lists.newArrayList(values).asScala val (idx, _) = it.head val vector = it map { case (idx, vec) => vec } reduce (_ + _) http://git-wip-us.apache.org/repos/asf/mahout/blob/f836481b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala ---------------------------------------------------------------------- diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala new file mode 100644 index 0000000..3b353fc --- /dev/null +++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala @@ -0,0 +1,85 @@ +package org.apache.mahout.flinkbindings.blas + +import scala.reflect.ClassTag +import org.apache.mahout.flinkbindings.drm.FlinkDrm +import org.apache.mahout.math.drm.logical.OpAtB +import org.apache.flink.api.common.functions.MapFunction +import org.apache.flink.api.java.tuple.Tuple2 +import org.apache.mahout.math.Vector +import org.apache.mahout.math.Matrix +import org.apache.flink.api.common.functions.FlatMapFunction +import org.apache.flink.util.Collector +import org.apache.mahout.math.drm._ +import org.apache.mahout.math.scalabindings._ +import RLikeOps._ +import org.apache.flink.api.common.functions.GroupReduceFunction +import java.lang.Iterable +import scala.collection.JavaConverters._ +import com.google.common.collect.Lists +import org.apache.mahout.flinkbindings.drm.BlockifiedFlinkDrm +import org.apache.mahout.flinkbindings.BlockifiedDrmDataSet +import org.apache.flink.api.scala._ +import org.apache.flink.api.common.typeinfo.TypeInformation + + +object FlinkOpAtB { + + def notZippable[K: ClassTag](op: OpAtB[K], At: FlinkDrm[K], B: FlinkDrm[K]): FlinkDrm[Int] = { + // TODO: to help Flink's type inference + // only Int is supported now + val rowsAt = At.deblockify.ds.map(new DrmTupleToDrmTupleInt()) + val rowsB = B.deblockify.ds.map(new DrmTupleToDrmTupleInt()) + val joined = rowsAt.join(rowsB).where(tuple_1[Vector]).equalTo(tuple_1[Vector]) + + val ncol = op.ncol + val nrow = op.nrow + val blockHeight = 10 + val blockCount = safeToNonNegInt((ncol - 1) / blockHeight + 1) + + val preProduct = joined.flatMap(new FlatMapFunction[Tuple2[(Int, Vector), (Int, Vector)], + (Int, Matrix)] { + def flatMap(in: Tuple2[(Int, Vector), (Int, Vector)], + out: Collector[(Int, Matrix)]): Unit = { + val avec = in.f0._2 + val bvec = in.f1._2 + + 0.until(blockCount) map { blockKey => + val blockStart = blockKey * blockHeight + val blockEnd = Math.min(ncol, blockStart + blockHeight) + + // Create block by cross product of proper slice of aRow and qRow + val outer = avec(blockStart until blockEnd) cross bvec + out.collect((blockKey, outer)) + } + } + }) + + val res: BlockifiedDrmDataSet[Int] = preProduct.groupBy(tuple_1[Matrix]).reduceGroup( + new GroupReduceFunction[(Int, Matrix), BlockifiedDrmTuple[Int]] { + def reduce(values: Iterable[(Int, Matrix)], out: Collector[BlockifiedDrmTuple[Int]]): Unit = { + val it = Lists.newArrayList(values).asScala + val (idx, _) = it.head + + val block = it.map(t => t._2).reduce((m1, m2) => m1 + m2) + + val keys = idx.until(block.nrow).toArray[Int] + out.collect((keys, block)) + } + }) + + new BlockifiedFlinkDrm(res, ncol) + } + +} + +class DrmTupleToDrmTupleInt[K: ClassTag] extends MapFunction[(K, Vector), (Int, Vector)] { + def map(tuple: (K, Vector)): (Int, Vector) = tuple match { + case (key, vec) => (key.asInstanceOf[Int], vec) + } +} + +class DrmTupleToFlinkTupleMapper[K: ClassTag] extends MapFunction[(K, Vector), Tuple2[Int, Vector]] { + def map(tuple: (K, Vector)): Tuple2[Int, Vector] = tuple match { + case (key, vec) => new Tuple2[Int, Vector](key.asInstanceOf[Int], vec) + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/f836481b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala ---------------------------------------------------------------------- diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala new file mode 100644 index 0000000..af5ccc8 --- /dev/null +++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala @@ -0,0 +1,15 @@ +package org.apache.mahout.flinkbindings + +import org.apache.flink.api.java.functions.KeySelector +import org.apache.mahout.math.Vector +import scala.reflect.ClassTag + + +package object blas { + + // TODO: remove it once figure out how to make Flink accept interfaces (Vector here) + def tuple_1[K: ClassTag] = new KeySelector[(Int, K), Integer] { + def getKey(tuple: Tuple2[Int, K]): Integer = tuple._1 + } + +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/f836481b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala ---------------------------------------------------------------------- diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala index 07d6a84..2624077 100644 --- a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala +++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala @@ -33,64 +33,72 @@ class RLikeOpsSuite extends FunSuite with DistributedFlinkSuit { assert(b == dvec(8, 11, 14)) } - test("Power interation 1000 x 1000 matrix") { - val dim = 1000 + test("A.t") { + val inCoreA = dense((1, 2, 3), (2, 3, 4)) + val A = drmParallelize(m = inCoreA, numPartitions = 2) + val res = A.t.collect - // we want a symmetric matrix so we can have real eigenvalues - val inCoreA = symmtericMatrix(dim, max = 2000) + val expected = inCoreA.t + assert((res - expected).norm < 1e-6) + } + test("A.t %*% x") { + val inCoreA = dense((1, 2, 3), (2, 3, 4)) val A = drmParallelize(m = inCoreA, numPartitions = 2) + val x = dvec(3, 11) + val res = (A.t %*% x).collect(::, 0) - var x: Vector = 1 to dim map (_ => 1.0 / Math.sqrt(dim)) - var converged = false + val expected = inCoreA.t %*% x + assert((res - expected).norm(2) < 1e-6) + } - var iteration = 1 + test("A.t %*% B") { + val inCoreA = dense((1, 2), (2, 3), (3, 4)) + val inCoreB = dense((1, 2), (3, 4), (11, 4)) - while (!converged) { - LOGGER.info(s"iteration #$iteration...") + val A = drmParallelize(m = inCoreA, numPartitions = 2) + val B = drmParallelize(m = inCoreB, numPartitions = 2) - val Ax = A %*% x - var x_new = Ax.collect(::, 0) - x_new = x_new / x_new.norm(2) + val res = A.t %*% B - val diff = (x_new - x).norm(2) - LOGGER.info(s"difference norm is $diff") + val expected = inCoreA.t %*% inCoreB + assert((res.collect - expected).norm < 1e-6) + } - converged = diff < 1e-6 - iteration = iteration + 1 - x = x_new - } + test("A %*% B.t") { + val inCoreA = dense((1, 2), (2, 3), (3, 4)) + val inCoreB = dense((1, 2), (3, 4), (11, 4)) - LOGGER.info("converged") - // TODO: add test that it's the 1st PC - } + val A = drmParallelize(m = inCoreA, numPartitions = 2) + val B = drmParallelize(m = inCoreB, numPartitions = 2) + + val res = A %*% B.t - def symmtericMatrix(dim: Int, max: Int, seed: Int = 0x31337) = { - Matrices.functionalMatrixView(dim, dim, new IntIntFunction { - def apply(i: Int, j: Int): Double = { - val arr = Array(i + j, i * j, i + j + 31, i / (j + 1) + j / (i + 1)) - Math.abs(MurmurHash3.arrayHash(arr, seed) % max) - } - }) + val expected = inCoreA %*% inCoreB.t + assert((res.collect - expected).norm < 1e-6) } - test("A.t") { - val inCoreA = dense((1, 2, 3), (2, 3, 4)) + test("A.t %*% A") { + val inCoreA = dense((1, 2), (2, 3), (3, 4)) val A = drmParallelize(m = inCoreA, numPartitions = 2) - val res = A.t.collect - val expected = inCoreA.t - assert((res - expected).norm < 1e-6) + val res = A.t %*% A + + val expected = inCoreA.t %*% inCoreA + assert((res.collect - expected).norm < 1e-6) } - test("A.t %*% x") { - val inCoreA = dense((1, 2, 3), (2, 3, 4)) + test("A %*% B") { + val inCoreA = dense((1, 2), (2, 3), (3, 4)).t + val inCoreB = dense((1, 2), (3, 4), (11, 4)) + val A = drmParallelize(m = inCoreA, numPartitions = 2) - val x = dvec(3, 11) - val res = (A.t %*% x).collect(::, 0) + val B = drmParallelize(m = inCoreB, numPartitions = 2) - val expected = inCoreA.t %*% x - assert((res - expected).norm(2) < 1e-6) + val res = A %*% B + + val expected = inCoreA %*% inCoreB + assert((res.collect - expected).norm < 1e-6) } - + } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/f836481b/flink/src/test/scala/org/apache/mahout/flinkbindings/UseCasesSuite.scala ---------------------------------------------------------------------- diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/UseCasesSuite.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/UseCasesSuite.scala new file mode 100644 index 0000000..8cdaca3 --- /dev/null +++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/UseCasesSuite.scala @@ -0,0 +1,79 @@ +package org.apache.mahout.flinkbindings + +import org.junit.runner.RunWith +import org.scalatest.junit.JUnitRunner +import org.scalatest.FunSuite +import org.apache.mahout.math._ +import scalabindings._ +import RLikeOps._ +import org.apache.mahout.math.drm._ +import RLikeDrmOps._ +import org.apache.mahout.flinkbindings._ +import org.apache.mahout.math.function.IntIntFunction +import scala.util.Random +import scala.util.MurmurHash +import scala.util.hashing.MurmurHash3 +import org.slf4j.Logger +import org.slf4j.LoggerFactory +import org.scalatest.Ignore + +@RunWith(classOf[JUnitRunner]) +class UseCasesSuite extends FunSuite with DistributedFlinkSuit { + + val LOGGER = LoggerFactory.getLogger(getClass()) + + test("use case: Power interation 1000 x 1000 matrix") { + val dim = 1000 + + // we want a symmetric matrix so we can have real eigenvalues + val inCoreA = symmtericMatrix(dim, max = 2000) + + val A = drmParallelize(m = inCoreA, numPartitions = 2) + + var x: Vector = 1 to dim map (_ => 1.0 / Math.sqrt(dim)) + var converged = false + + var iteration = 1 + + while (!converged) { + LOGGER.info(s"iteration #$iteration...") + + val Ax = A %*% x + var x_new = Ax.collect(::, 0) + x_new = x_new / x_new.norm(2) + + val diff = (x_new - x).norm(2) + LOGGER.info(s"difference norm is $diff") + + converged = diff < 1e-6 + iteration = iteration + 1 + x = x_new + } + + LOGGER.info("converged") + // TODO: add test that it's the 1st PC + } + + def symmtericMatrix(dim: Int, max: Int, seed: Int = 0x31337) = { + Matrices.functionalMatrixView(dim, dim, new IntIntFunction { + def apply(i: Int, j: Int): Double = { + val arr = Array(i + j, i * j, i + j + 31, i / (j + 1) + j / (i + 1)) + Math.abs(MurmurHash3.arrayHash(arr, seed) % max) + } + }) + } + + test("use case: OLS Regression") { + val inCoreA = dense((1, 2), (2, 3), (3, 4), (5, 6), (7, 8), (9, 10)) + val x = dvec(1, 2, 2, 3, 3, 3) + val A = drmParallelize(m = inCoreA, numPartitions = 2) + val AtA = A.t %*% A + val Atx = A.t %*% x + + val w = solve(AtA, Atx) + + val expected = solve(inCoreA.t %*% inCoreA, inCoreA.t %*% x) + assert((w(::, 0) - expected).norm(2) < 1e-6) + } + +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/f836481b/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala ---------------------------------------------------------------------- diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala index 3ce8895..baf23d6 100644 --- a/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala +++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala @@ -12,11 +12,12 @@ import org.apache.mahout.math.drm.logical.OpAx import org.apache.mahout.flinkbindings.drm.CheckpointedFlinkDrm import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm import org.apache.mahout.math.drm.logical.OpAt +import org.apache.mahout.math.drm.logical.OpAtB @RunWith(classOf[JUnitRunner]) class LATestSuit extends FunSuite with DistributedFlinkSuit { - test("Ax") { + test("Ax blockified") { val inCoreA = dense((1, 2, 3), (2, 3, 4), (3, 4, 5)) val A = drmParallelize(m = inCoreA, numPartitions = 2) val x: Vector = (0, 1, 2) @@ -30,7 +31,7 @@ class LATestSuit extends FunSuite with DistributedFlinkSuit { assert(b == dvec(8, 11, 14)) } - test("At") { + test("At sparseTrick") { val inCoreA = dense((1, 2, 3), (2, 3, 4)) val A = drmParallelize(m = inCoreA, numPartitions = 2) @@ -42,4 +43,23 @@ class LATestSuit extends FunSuite with DistributedFlinkSuit { assert((output - inCoreA.t).norm < 1e-6) } + test("AtB notZippable") { + val inCoreAt = dense((1, 2), (2, 3), (3, 4)) + + val At = drmParallelize(m = inCoreAt, numPartitions = 2) + + val inCoreB = dense((1, 2), (3, 4), (11, 4)) + val B = drmParallelize(m = inCoreB, numPartitions = 2) + + val opAtB = new OpAtB(At, B) + val res = FlinkOpAtB.notZippable(opAtB, At, B) + + val drm = new CheckpointedFlinkDrm(res.deblockify.ds, _nrow=inCoreAt.ncol, _ncol=inCoreB.ncol) + val output = drm.collect + + val expected = inCoreAt.t %*% inCoreB + assert((output - expected).norm < 1e-6) + } + + } \ No newline at end of file
