Repository: mahout Updated Branches: refs/heads/master 6ac833bdc -> 6ab5a8d64
MAHOUT-1848: drmSampleKRows in FlinkEngine should generate a dense or sparse matrix, this closes apache/mahout#233 Project: http://git-wip-us.apache.org/repos/asf/mahout/repo Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/6ab5a8d6 Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/6ab5a8d6 Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/6ab5a8d6 Branch: refs/heads/master Commit: 6ab5a8d6456dfc52d7a951ae71618a6417516a07 Parents: 6ac833b Author: smarthi <[email protected]> Authored: Mon May 2 23:06:00 2016 -0400 Committer: smarthi <[email protected]> Committed: Mon May 2 23:06:00 2016 -0400 ---------------------------------------------------------------------- .../mahout/flinkbindings/FlinkEngine.scala | 27 ++++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mahout/blob/6ab5a8d6/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala ---------------------------------------------------------------------- diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala index b3b72b0..f1d23b2 100644 --- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala +++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala @@ -357,7 +357,16 @@ object FlinkEngine extends DistributedEngine { implicit val typeInformation = generateTypeInformation[K] val sample = DataSetUtils(drmX.dataset).sample(replacement, fraction) - new CheckpointedFlinkDrm[K](sample) + + val res = if (kTag != ClassTag.Int) { + new CheckpointedFlinkDrm[K](sample) + } + else { + blas.rekeySeqInts(new RowsFlinkDrm[K](sample, ncol = drmX.ncol), computeMap = false)._1 + .asInstanceOf[DrmLike[K]] + } + + res } def drmSampleKRows[K](drmX: DrmLike[K], numSamples:Int, replacement: Boolean = false): Matrix = { @@ -365,15 
+374,17 @@ object FlinkEngine extends DistributedEngine { implicit val typeInformation = generateTypeInformation[K] val sample = DataSetUtils(drmX.dataset).sampleWithSize(replacement, numSamples) + val sampleArray = sample.collect().toArray + val isSparse = sampleArray.exists { case (_, vec) ⇒ !vec.isDense } - val res = if (kTag != ClassTag.Int) { - new CheckpointedFlinkDrm[K](sample) - } - else { - blas.rekeySeqInts(new RowsFlinkDrm[K](sample, ncol = drmX.ncol), computeMap = false)._1 - } + val vectors = sampleArray.map(_._2) + val labels = sampleArray.view.zipWithIndex + .map { case ((key, _), idx) ⇒ key.toString → (idx: Integer) }.toMap + + val mx: Matrix = if (isSparse) sparse(vectors: _*) else dense(vectors) + mx.setRowLabelBindings(labels) - res.collect + mx } /** Engine-specific all reduce tensor operation. */
