MAHOUT-1751: Flink: AtA slim
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/9d48487c
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/9d48487c
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/9d48487c

Branch: refs/heads/flink-binding
Commit: 9d48487cf2a7060193b305c152b8ce191f7d15be
Parents: ceb1f05
Author: Alexey Grigorev <[email protected]>
Authored: Fri Aug 21 15:33:32 2015 +0200
Committer: Alexey Grigorev <[email protected]>
Committed: Fri Sep 25 17:42:50 2015 +0200

----------------------------------------------------------------------
 .../mahout/flinkbindings/FlinkEngine.scala      | 12 +---
 .../mahout/flinkbindings/blas/FlinkOpAtA.scala  | 99 ++++++++++++++++++++
 .../mahout/flinkbindings/blas/FlinkOpAtB.scala  |  1 +
 .../drm/CheckpointedFlinkDrm.scala              | 25 ++++---
 4 files changed, 118 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/mahout/blob/9d48487c/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
index 3076933..8e47629 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
@@ -19,10 +19,8 @@
 package org.apache.mahout.flinkbindings

 import java.util.Collection
-
 import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
-
 import org.apache.flink.api.common.functions.MapFunction
 import org.apache.flink.api.common.functions.ReduceFunction
 import org.apache.flink.api.java.tuple.Tuple2
@@ -79,6 +77,7 @@ import org.apache.mahout.math.indexeddataset.IndexedDataset
 import org.apache.mahout.math.indexeddataset.Schema
 import org.apache.mahout.math.scalabindings._
 import org.apache.mahout.math.scalabindings.RLikeOps._
+import org.apache.mahout.flinkbindings.blas.FlinkOpAtA

 object FlinkEngine extends DistributedEngine {

@@ -165,14 +164,7 @@ object FlinkEngine extends DistributedEngine {
         FlinkOpAtB.notZippable(OpAtB(c, d), flinkTranslate(c), flinkTranslate(d))
                   .asInstanceOf[FlinkDrm[K]]
       }
-      case op @ OpAtA(a) => {
-        // express AtA via AtB
-        // TODO: create specific implementation of AtA, see MAHOUT-1751
-        val aInt = a.asInstanceOf[DrmLike[Int]] // TODO: casts!
-        val opAtB = OpAtB(aInt, aInt)
-        val aTranslated = flinkTranslate(aInt)
-        FlinkOpAtB.notZippable(opAtB, aTranslated, aTranslated)
-      }
+      case op @ OpAtA(a) => FlinkOpAtA.at_a(op, flinkTranslate(a)(op.classTagA))
       case op @ OpTimesRightMatrix(a, b) =>
         FlinkOpTimesRightMatrix.drmTimesInCore(op, flinkTranslate(a)(op.classTagA), b)
       case op @ OpAewUnaryFunc(a, f, _) =>


http://git-wip-us.apache.org/repos/asf/mahout/blob/9d48487c/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala
new file mode 100644
index 0000000..63d1845
--- /dev/null
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.mahout.flinkbindings.blas
+
+import scala.collection.JavaConverters.asScalaBufferConverter
+
+import org.apache.flink.api.common.functions.MapFunction
+import org.apache.flink.api.common.functions.ReduceFunction
+import org.apache.flink.api.java.DataSet
+import org.apache.mahout.flinkbindings._
+import org.apache.mahout.flinkbindings.drm.FlinkDrm
+import org.apache.mahout.math._
+import org.apache.mahout.math.drm._
+import org.apache.mahout.math.drm.logical._
+import org.apache.mahout.math.scalabindings.RLikeOps._
+
+/**
+ * Flink implementation of the physical A'A operator ([[OpAtA]]).
+ *
+ * Only the "slim" case is supported: when A has at most `maxInMemNCol` columns
+ * (system property `mahout.math.AtA.maxInMemNCol`, default 200), A'A is
+ * accumulated in memory and redistributed as a single-partition DRM.
+ */
+object FlinkOpAtA {
+
+  /** System property holding the widest A for which A'A is computed in memory. */
+  final val PROPERTY_ATA_MAXINMEMNCOL = "mahout.math.AtA.maxInMemNCol"
+  final val PROPERTY_ATA_MAXINMEMNCOL_DEFAULT = "200"
+
+  /**
+   * Computes A'A, choosing the in-memory ("slim") or distributed ("fat") strategy.
+   *
+   * @param op logical operator; `op.ncol` is the column count of A
+   * @param A  physical operand A
+   * @return A'A as an int-keyed DRM
+   * @throws IllegalArgumentException if the in-memory column limit is not positive
+   * @throws NotImplementedError if A is wider than the in-memory limit
+   */
+  def at_a(op: OpAtA[_], A: FlinkDrm[_]): FlinkDrm[Int] = {
+    val maxInMemStr = System.getProperty(PROPERTY_ATA_MAXINMEMNCOL, PROPERTY_ATA_MAXINMEMNCOL_DEFAULT)
+    val maxInMemNCol = maxInMemStr.toInt
+    // require, unlike ensuring/assert, is not elided when assertions are disabled
+    require(maxInMemNCol > 0, "Invalid A'A in-memory setting for optimizer")
+
+    if (op.ncol <= maxInMemNCol) {
+      implicit val ctx = A.context
+      val inCoreAtA = slim(op, A)
+      drmParallelize(inCoreAtA, numPartitions = 1)
+    } else {
+      fat(op, A)
+    }
+  }
+
+  /**
+   * Computes A'A in memory as the sum of B'B over all row blocks B of A.
+   *
+   * @param op logical operator; `op.ncol` sizes the result
+   * @param A  physical operand A
+   * @return the `op.ncol` x `op.ncol` in-core product (all zeros when A yields no blocks)
+   */
+  def slim(op: OpAtA[_], A: FlinkDrm[_]): Matrix = {
+    val ds = A.blockify.ds.asInstanceOf[DataSet[(Array[Any], Matrix)]]
+
+    val res = ds.map(new MapFunction[(Array[Any], Matrix), Matrix] {
+      // TODO: optimize it: use upper-triangle matrices like in Spark
+      def map(block: (Array[Any], Matrix)): Matrix = {
+        val m = block._2
+        m.t %*% m
+      }
+    }).reduce(new ReduceFunction[Matrix] {
+      def reduce(m1: Matrix, m2: Matrix): Matrix = m1 + m2
+    }).collect()
+
+    // a non-grouped reduce emits no element for an empty input, so collect() may be empty
+    res.asScala.headOption.getOrElse(new DenseMatrix(op.ncol, op.ncol))
+  }
+
+  /** Distributed ("fat") A'A; not implemented yet. */
+  def fat(op: OpAtA[_], A: FlinkDrm[_]): FlinkDrm[Int] = {
+    throw new NotImplementedError("fat matrices are not yet supported")
+  }
+}


http://git-wip-us.apache.org/repos/asf/mahout/blob/9d48487c/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
index 297f676..f02cd84 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
@@ -42,6 +42,7 @@ import org.apache.mahout.math.scalabindings.RLikeOps._

 import com.google.common.collect.Lists

+
 /**
  * Implementation is taken from Spark's AtB
  * https://github.com/apache/mahout/blob/master/spark/src/main/scala/org/apache/mahout/sparkbindings/blas/AtB.scala


http://git-wip-us.apache.org/repos/asf/mahout/blob/9d48487c/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
index e29b80c..f58e05b 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
@@ -94,26 +94,33 @@ class CheckpointedFlinkDrm[K: ClassTag](val ds: DrmDataSet[K],
     val data = ds.collect().asScala.toList
     val isDense = data.forall(_._2.isDense)

+    val cols = ncol
+    val rows = safeToNonNegInt(nrow)
+
     val m = if (isDense) {
-      val cols = data.head._2.size()
-      val rows = data.length
       new DenseMatrix(rows, cols)
     } else {
-      val cols = ncol
-      val rows = safeToNonNegInt(nrow)
       new SparseMatrix(rows, cols)
     }

     val intRowIndices = keyClassTag == implicitly[ClassTag[Int]]

-    if (intRowIndices)
-      data.foreach(t => m(t._1.asInstanceOf[Int], ::) := t._2)
-    else {
+    if (intRowIndices) {
+      data.foreach { case (t, vec) =>
+        val idx = t.asInstanceOf[Int]
+        m(idx, ::) := vec
+      }
+    } else {
       // assign all rows sequentially
       val d = data.zipWithIndex
-      d.foreach(t => m(t._2, ::) := t._1._2)
+      d.foreach {
+        case ((_, vec), idx) => m(idx, ::) := vec
+      }
+
+      val rowBindings = d.map {
+        case ((t, _), idx) => (t.toString, idx: java.lang.Integer)
+      }.toMap.asJava

-      val rowBindings = d.map(t => (t._1._1.toString, t._2: java.lang.Integer)).toMap.asJava
       m.setRowLabelBindings(rowBindings)
     }

