MAHOUT-1570: Flink: compute ncol and nrow; implement colSums, colMeans and norm methods
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/df1db7cc Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/df1db7cc Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/df1db7cc Branch: refs/heads/flink-binding Commit: df1db7cc775ed5e10c6416e033e25e430ffdd171 Parents: 522f3d5 Author: Alexey Grigorev <[email protected]> Authored: Tue May 26 16:17:14 2015 +0200 Committer: Alexey Grigorev <[email protected]> Committed: Fri Sep 25 17:41:45 2015 +0200 ---------------------------------------------------------------------- .../mahout/flinkbindings/FlinkEngine.scala | 30 +++++++++-- .../drm/CheckpointedFlinkDrm.scala | 52 +++++++++++++------- 2 files changed, 62 insertions(+), 20 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mahout/blob/df1db7cc/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala ---------------------------------------------------------------------- diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala index 03d1a9c..6696152 100644 --- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala +++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala @@ -44,6 +44,8 @@ import org.apache.mahout.math.drm.logical.OpRbind import org.apache.mahout.math.drm.logical.OpMapBlock import org.apache.mahout.math.drm.logical.OpRowRange import org.apache.mahout.math.drm.logical.OpTimesRightMatrix +import org.apache.flink.api.common.functions.MapFunction +import org.apache.flink.api.common.functions.ReduceFunction object FlinkEngine extends DistributedEngine { @@ -119,15 +121,37 @@ object FlinkEngine extends DistributedEngine { def translate[K: ClassTag](oper: DrmLike[K]): DataSet[K] = ??? /** Engine-specific colSums implementation based on a checkpoint. 
*/ - override def colSums[K: ClassTag](drm: CheckpointedDrm[K]): Vector = ??? + override def colSums[K: ClassTag](drm: CheckpointedDrm[K]): Vector = { + val sum = drm.ds.map(new MapFunction[(K, Vector), Vector] { + def map(tuple: (K, Vector)): Vector = tuple._2 + }).reduce(new ReduceFunction[Vector] { + def reduce(v1: Vector, v2: Vector) = v1 + v2 + }) + + val list = CheckpointedFlinkDrm.flinkCollect(sum, "FlinkEngine colSums()") + list.head + } /** Engine-specific numNonZeroElementsPerColumn implementation based on a checkpoint. */ override def numNonZeroElementsPerColumn[K: ClassTag](drm: CheckpointedDrm[K]): Vector = ??? /** Engine-specific colMeans implementation based on a checkpoint. */ - override def colMeans[K: ClassTag](drm: CheckpointedDrm[K]): Vector = ??? + override def colMeans[K: ClassTag](drm: CheckpointedDrm[K]): Vector = { + drm.colSums() / drm.nrow + } - override def norm[K: ClassTag](drm: CheckpointedDrm[K]): Double = ??? + override def norm[K: ClassTag](drm: CheckpointedDrm[K]): Double = { + val sumOfSquares = drm.ds.map(new MapFunction[(K, Vector), Double] { + def map(tuple: (K, Vector)): Double = tuple match { + case (idx, vec) => vec dot vec + } + }).reduce(new ReduceFunction[Double] { + def reduce(v1: Double, v2: Double) = v1 + v2 + }) + + val list = CheckpointedFlinkDrm.flinkCollect(sumOfSquares, "FlinkEngine norm()") + list.head + } /** Broadcast support */ override def drmBroadcast(v: Vector)(implicit dc: DistributedContext): BCast[Vector] = ??? 
http://git-wip-us.apache.org/repos/asf/mahout/blob/df1db7cc/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala ---------------------------------------------------------------------- diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala index c19920f..e7d9dcd 100644 --- a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala +++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala @@ -5,7 +5,6 @@ import org.apache.mahout.math.drm._ import org.apache.mahout.math.scalabindings._ import RLikeOps._ import org.apache.mahout.flinkbindings._ - import org.apache.mahout.math.drm.CheckpointedDrm import org.apache.mahout.math.Matrix import org.apache.mahout.flinkbindings.FlinkDistributedContext @@ -17,8 +16,10 @@ import org.apache.mahout.math.DenseMatrix import org.apache.mahout.math.SparseMatrix import org.apache.flink.api.java.io.LocalCollectionOutputFormat import java.util.ArrayList - import scala.collection.JavaConverters._ +import org.apache.flink.api.common.functions.MapFunction +import org.apache.flink.api.common.functions.ReduceFunction +import org.apache.flink.api.java.DataSet class CheckpointedFlinkDrm[K: ClassTag](val ds: DrmDataSet[K], private var _nrow: Long = CheckpointedFlinkDrm.UNKNOWN, @@ -27,20 +28,31 @@ class CheckpointedFlinkDrm[K: ClassTag](val ds: DrmDataSet[K], override protected[mahout] val partitioningTag: Long = Random.nextLong(), private var _canHaveMissingRows: Boolean = false) extends CheckpointedDrm[K] { - lazy val nrow = if (_nrow >= 0) _nrow else computeNRow - lazy val ncol = if (_ncol >= 0) _ncol else computeNCol + lazy val nrow: Long = if (_nrow >= 0) _nrow else computeNRow + lazy val ncol: Int = if (_ncol >= 0) _ncol else computeNCol + + protected def computeNRow: Long = { + val count = ds.map(new MapFunction[DrmTuple[K], 
Long] { + def map(value: DrmTuple[K]): Long = 1L + }).reduce(new ReduceFunction[Long] { + def reduce(a1: Long, a2: Long) = a1 + a2 + }) + + val list = CheckpointedFlinkDrm.flinkCollect(count, "CheckpointedFlinkDrm computeNRow()") + list.head + } - protected def computeNRow = ??? - protected def computeNCol = ??? /*{ - TODO: find out how to get one value + protected def computeNCol: Int = { val max = ds.map(new MapFunction[DrmTuple[K], Int] { def map(value: DrmTuple[K]): Int = value._2.length }).reduce(new ReduceFunction[Int] { def reduce(a1: Int, a2: Int) = Math.max(a1, a2) }) - - max - }*/ + + val list = CheckpointedFlinkDrm.flinkCollect(max, "CheckpointedFlinkDrm computeNCol()") + list.head + } + def keyClassTag: ClassTag[K] = implicitly[ClassTag[K]] def cache() = { @@ -57,12 +69,7 @@ class CheckpointedFlinkDrm[K: ClassTag](val ds: DrmDataSet[K], def checkpoint(cacheHint: CacheHint.CacheHint): CheckpointedDrm[K] = this def collect: Matrix = { - val dataJavaList = new ArrayList[DrmTuple[K]] - val outputFormat = new LocalCollectionOutputFormat[DrmTuple[K]](dataJavaList) - ds.output(outputFormat) - val data = dataJavaList.asScala - ds.getExecutionEnvironment.execute("Checkpointed Flink Drm collect()") - + val data = CheckpointedFlinkDrm.flinkCollect(ds, "Checkpointed Flink Drm collect()") val isDense = data.forall(_._2.isDense) val m = if (isDense) { @@ -99,5 +106,16 @@ class CheckpointedFlinkDrm[K: ClassTag](val ds: DrmDataSet[K], } object CheckpointedFlinkDrm { - val UNKNOWN = -1; + val UNKNOWN = -1 + + // needed for backwards compatibility with flink 0.8.1 + def flinkCollect[K](dataset: DataSet[K], jobName: String = "flinkCollect()"): List[K] = { + val dataJavaList = new ArrayList[K] + val outputFormat = new LocalCollectionOutputFormat[K](dataJavaList) + dataset.output(outputFormat) + val data = dataJavaList.asScala + dataset.getExecutionEnvironment.execute(jobName) + data.toList + } + } \ No newline at end of file
