MAHOUT-1570: Flink: nrow and ncol optimized
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/be815fb2
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/be815fb2
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/be815fb2

Branch: refs/heads/flink-binding
Commit: be815fb25f4c008bf3809bb444e3d8562dea96fa
Parents: 851eebc
Author: Alexey Grigorev <[email protected]>
Authored: Thu Jun 25 15:52:56 2015 +0200
Committer: Alexey Grigorev <[email protected]>
Committed: Fri Sep 25 17:41:56 2015 +0200

----------------------------------------------------------------------
 .../drm/CheckpointedFlinkDrm.scala | 32 +++++++++-----------
 1 file changed, 14 insertions(+), 18 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/mahout/blob/be815fb2/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
index ecd8b39..e29b80c 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
@@ -54,28 +54,24 @@ class CheckpointedFlinkDrm[K: ClassTag](val ds: DrmDataSet[K],
     private var _canHaveMissingRows: Boolean = false
   ) extends CheckpointedDrm[K] {
 
-  lazy val nrow: Long = if (_nrow >= 0) _nrow else computeNRow
-  lazy val ncol: Int = if (_ncol >= 0) _ncol else computeNCol
-
-  protected def computeNRow: Long = {
-    val count = ds.map(new MapFunction[DrmTuple[K], Long] {
-      def map(value: DrmTuple[K]): Long = 1L
-    }).reduce(new ReduceFunction[Long] {
-      def reduce(a1: Long, a2: Long) = a1 + a2
-    })
+  lazy val nrow: Long = if (_nrow >= 0) _nrow else dim._1
+  lazy val ncol: Int = if (_ncol >= 0) _ncol else dim._2
 
-    val list = count.collect().asScala.toList
-    list.head
-  }
+  private lazy val dim: (Long, Int) = {
+    // combine computation of ncol and nrow in one pass
 
-  protected def computeNCol: Int = {
-    val max = ds.map(new MapFunction[DrmTuple[K], Int] {
-      def map(value: DrmTuple[K]): Int = value._2.length
-    }).reduce(new ReduceFunction[Int] {
-      def reduce(a1: Int, a2: Int) = Math.max(a1, a2)
+    val res = ds.map(new MapFunction[DrmTuple[K], (Long, Int)] {
+      def map(value: DrmTuple[K]): (Long, Int) = {
+        (1L, value._2.length)
+      }
+    }).reduce(new ReduceFunction[(Long, Int)] {
+      def reduce(t1: (Long, Int), t2: (Long, Int)) = {
+        val ((rowCnt1, colNum1), (rowCnt2, colNum2)) = (t1, t2)
+        (rowCnt1 + rowCnt2, Math.max(colNum1, colNum2))
+      }
     })
 
-    val list = max.collect().asScala.toList
+    val list = res.collect().asScala.toList
     list.head
   }
 
