Repository: mahout
Updated Branches:
  refs/heads/master 6ac833bdc -> 6ab5a8d64


MAHOUT-1848: drmSampleKRows in FlinkEngine should generate a dense or sparse matrix; this closes apache/mahout#233


Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/6ab5a8d6
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/6ab5a8d6
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/6ab5a8d6

Branch: refs/heads/master
Commit: 6ab5a8d6456dfc52d7a951ae71618a6417516a07
Parents: 6ac833b
Author: smarthi <[email protected]>
Authored: Mon May 2 23:06:00 2016 -0400
Committer: smarthi <[email protected]>
Committed: Mon May 2 23:06:00 2016 -0400

----------------------------------------------------------------------
 .../mahout/flinkbindings/FlinkEngine.scala      | 27 ++++++++++++++------
 1 file changed, 19 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/mahout/blob/6ab5a8d6/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
index b3b72b0..f1d23b2 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
@@ -357,7 +357,16 @@ object FlinkEngine extends DistributedEngine {
     implicit val typeInformation = generateTypeInformation[K]
 
     val sample = DataSetUtils(drmX.dataset).sample(replacement, fraction)
-    new CheckpointedFlinkDrm[K](sample)
+
+    val res = if (kTag != ClassTag.Int) {
+      new CheckpointedFlinkDrm[K](sample)
+    }
+    else {
+      blas.rekeySeqInts(new RowsFlinkDrm[K](sample, ncol = drmX.ncol), 
computeMap = false)._1
+        .asInstanceOf[DrmLike[K]]
+    }
+
+    res
   }
 
   def drmSampleKRows[K](drmX: DrmLike[K], numSamples:Int, replacement: Boolean 
= false): Matrix = {
@@ -365,15 +374,17 @@ object FlinkEngine extends DistributedEngine {
     implicit val typeInformation = generateTypeInformation[K]
 
     val sample = DataSetUtils(drmX.dataset).sampleWithSize(replacement, 
numSamples)
+    val sampleArray = sample.collect().toArray
+    val isSparse = sampleArray.exists { case (_, vec) ⇒ !vec.isDense }
 
-    val res = if (kTag != ClassTag.Int) {
-      new CheckpointedFlinkDrm[K](sample)
-    }
-    else {
-      blas.rekeySeqInts(new RowsFlinkDrm[K](sample, ncol = drmX.ncol), 
computeMap = false)._1
-    }
+    val vectors = sampleArray.map(_._2)
+    val labels = sampleArray.view.zipWithIndex
+      .map { case ((key, _), idx) ⇒ key.toString → (idx: Integer) }.toMap
+
+    val mx: Matrix = if (isSparse) sparse(vectors: _*) else dense(vectors)
+    mx.setRowLabelBindings(labels)
 
-    res.collect
+    mx
   }
 
   /** Engine-specific all reduce tensor operation. */

Reply via email to