MAHOUT-1814: Implement drm2intKeyed in flink bindings, this closes apache/mahout#214
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/430310db Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/430310db Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/430310db Branch: refs/heads/master Commit: 430310dbf7870e263ea0340df9b226655ce82a72 Parents: 7c275f0 Author: smarthi <[email protected]> Authored: Sat Apr 9 23:04:53 2016 -0400 Committer: smarthi <[email protected]> Committed: Sat Apr 9 23:04:53 2016 -0400 ---------------------------------------------------------------------- .../mahout/flinkbindings/FlinkEngine.scala | 30 +++++- .../mahout/flinkbindings/blas/FlinkOpAtA.scala | 18 ++-- .../mahout/flinkbindings/blas/package.scala | 101 ++++++++++++++++++- 3 files changed, 133 insertions(+), 16 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mahout/blob/430310db/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala ---------------------------------------------------------------------- diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala index adff30b..fddb432 100644 --- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala +++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala @@ -320,8 +320,34 @@ object FlinkEngine extends DistributedEngine { * Convert non-int-keyed matrix to an int-keyed, computing optionally mapping from old keys * to row indices in the new one. The mapping, if requested, is returned as a 1-column matrix. */ - def drm2IntKeyed[K](drmX: DrmLike[K], computeMap: Boolean = false): - (DrmLike[Int], Option[DrmLike[K]]) = ??? 
+ def drm2IntKeyed[K](drmX: DrmLike[K], computeMap: Boolean = false): (DrmLike[Int], Option[DrmLike[K]]) = { + implicit val ktag = drmX.keyClassTag + implicit val kTypeInformation = generateTypeInformation[K] + + if (ktag == ClassTag.Int) { + drmX.asInstanceOf[DrmLike[Int]] → None + } else { + val drmXcp = drmX.checkpoint(CacheHint.MEMORY_ONLY) + val ncol = drmXcp.asInstanceOf[CheckpointedFlinkDrm[K]].ncol + val nrow = drmXcp.asInstanceOf[CheckpointedFlinkDrm[K]].nrow + + // Compute sequential int key numbering. + val (intDataset, keyMap) = blas.rekeySeqInts(drmDataSet = drmXcp, computeMap = computeMap) + + // Convert computed key mapping to a matrix. + val mxKeyMap = keyMap.map { dataSet ⇒ + datasetWrap(dataSet.map { + tuple: (K, Int) => { + val ordinal = tuple._2 + val key = tuple._1 + key -> (dvec(ordinal): Vector) + } + }) + } + + intDataset -> mxKeyMap + } + } /** * (Optional) Sampling operation. http://git-wip-us.apache.org/repos/asf/mahout/blob/430310db/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala ---------------------------------------------------------------------- diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala index ab99e4d..6d0221a 100644 --- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala +++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala @@ -21,27 +21,21 @@ package org.apache.mahout.flinkbindings.blas import java.lang.Iterable import org.apache.flink.api.common.functions._ -import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.scala._ import org.apache.flink.configuration.Configuration import org.apache.flink.shaded.com.google.common.collect.Lists import org.apache.flink.util.Collector - -import org.apache.mahout.math.{Matrix, UpperTriangular} -import org.apache.mahout.math.drm.{BlockifiedDrmTuple, _} - -import 
org.apache.mahout.math._ import org.apache.mahout.flinkbindings._ import org.apache.mahout.flinkbindings.drm._ -import org.apache.mahout.math.scalabindings._ -import RLikeOps._ -import collection._ -import JavaConversions._ import org.apache.mahout.math.drm.logical.OpAtA +import org.apache.mahout.math.drm.{BlockifiedDrmTuple, _} +import org.apache.mahout.math.scalabindings.RLikeOps._ +import org.apache.mahout.math.scalabindings._ +import org.apache.mahout.math.{Matrix, UpperTriangular, _} - +import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ -import scala.reflect.ClassTag +import scala.collection._ /** * Inspired by Spark's implementation from http://git-wip-us.apache.org/repos/asf/mahout/blob/430310db/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala ---------------------------------------------------------------------- diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala index 6a3ac0e..265951c 100644 --- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala +++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala @@ -21,18 +21,27 @@ package org.apache.mahout.flinkbindings import java.lang.Iterable import org.apache.flink.api.common.functions.RichMapPartitionFunction +import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.scala._ +import org.apache.flink.configuration.Configuration import org.apache.flink.util.Collector +import org.apache.mahout.flinkbindings.drm.FlinkDrm +import org.apache.mahout.math.drm.DrmLike +import org.apache.mahout.math.{RandomAccessSparseVector, Vector} +import org.apache.mahout.math._ +import scalabindings._ +import RLikeOps._ import scala.collection._ +import scala.reflect.ClassTag package object blas { /** * To compute tuples (PartitionIndex, PartitionElementCount) * - * @param drmDataSet - * @tparam K + * 
@param drmDataSet - DRM Dataset + * @tparam K - Key type * @return (PartitionIndex, PartitionElementCount) */ //TODO: Remove this when FLINK-3657 is merged into Flink codebase and @@ -48,4 +57,92 @@ package object blas { } } } + + /** + * Rekey matrix dataset keys to consecutive int keys. + * @param drmDataSet incoming matrix row-wise dataset + * @param computeMap if true, also compute mapping between old and new keys + * @tparam K existing key parameter + * @return + */ + private[mahout] def rekeySeqInts[K: ClassTag: TypeInformation](drmDataSet: FlinkDrm[K], + computeMap: Boolean = true): (DrmLike[Int], + Option[DataSet[(K, Int)]]) = { + + implicit val dc = drmDataSet.context + + val datasetA = drmDataSet.asRowWise.ds + + val ncols = drmDataSet.asRowWise.ncol + + // Flink environment + val env = datasetA.getExecutionEnvironment + + // First, compute partition sizes. + val partSizes = countsPerPartition(datasetA).collect().toList + + // Starting indices + var startInd = new Array[Int](datasetA.getParallelism) + + // Save counts + for (pc <- partSizes) startInd(pc._1) = pc._2 + + // compute cumulative sum + val cumulativeSum = startInd.scanLeft(0)(_ + _).init + + val vector: Vector = new RandomAccessSparseVector(cumulativeSum.length) + + cumulativeSum.indices.foreach { i => vector(i) = cumulativeSum(i).toDouble } + + val bCast = FlinkEngine.drmBroadcast(vector) + + implicit val typeInformation = createTypeInformation[(K, Int)] + + // Compute key -> int index map: + val keyMap = if (computeMap) { + Some( + datasetA.mapPartition(new RichMapPartitionFunction[(K, Vector), (K, Int)] { + + // partition number + var part: Int = 0 + + // get the index of the partition + override def open(params: Configuration): Unit = { + part = getRuntimeContext.getIndexOfThisSubtask + } + + override def mapPartition(iterable: Iterable[(K, Vector)], collector: Collector[(K, Int)]): Unit = { + val k = iterable.iterator().next._1 + val si = bCast.value.get(part) + collector.collect(k -> 
(part + si).toInt) + } + })) + } else { + None + } + + // Finally, do the transform + val intDataSet = datasetA + + // Re-number each partition + .mapPartition(new RichMapPartitionFunction[(K, Vector), (Int, Vector)] { + + // partition number + var part: Int = 0 + + // get the index of the partition + override def open(params: Configuration): Unit = { + part = getRuntimeContext.getIndexOfThisSubtask + } + + override def mapPartition(iterable: Iterable[(K, Vector)], collector: Collector[(Int, Vector)]): Unit = { + val k = iterable.iterator().next._2 + val si = bCast.value.get(part) + collector.collect((part + si).toInt -> k) + } + }) + + // Finally, return drm -> keymap result + datasetWrap(intDataSet) -> keyMap + } } \ No newline at end of file
