MAHOUT-1749 Mahout DSL for Flink: Implement Atx (closes apache/mahout#204)

Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/a77f1c13
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/a77f1c13
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/a77f1c13

Branch: refs/heads/master
Commit: a77f1c13de58d462eed7ff224ec333b22ac22bf3
Parents: e3c8db5
Author: Andrew Palumbo <[email protected]>
Authored: Sat Mar 26 23:22:44 2016 -0400
Committer: Andrew Palumbo <[email protected]>
Committed: Sat Mar 26 23:22:44 2016 -0400

----------------------------------------------------------------------
 .../mahout/flinkbindings/FlinkEngine.scala      |  8 +---
 .../mahout/flinkbindings/blas/FlinkOpAx.scala   | 42 +++++++++++++++++++-
 2 files changed, 41 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/mahout/blob/a77f1c13/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
index c355cae..0fd2e05 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
@@ -137,13 +137,7 @@ object FlinkEngine extends DistributedEngine {
         FlinkOpAx.blockifiedBroadcastAx(op, flinkTranslate(a))
       case op@OpAt(a) if op.keyClassTag == ClassTag.Int ⇒ FlinkOpAt.sparseTrick(op, flinkTranslate(a)).asInstanceOf[FlinkDrm[K]]
       case op@OpAtx(a, x) if op.keyClassTag == ClassTag.Int ⇒
-        // express Atx as (A.t) %*% x
-        // TODO: create specific implementation of Atx, see MAHOUT-1749
-        val opAt = OpAt(a)
-        val at = FlinkOpAt.sparseTrick(opAt, flinkTranslate(a))
-        val atCast = new CheckpointedFlinkDrm(at.asRowWise.ds, _nrow = opAt.nrow, _ncol = opAt.ncol)
-        val opAx = OpAx(atCast, x)
-        FlinkOpAx.blockifiedBroadcastAx(opAx, flinkTranslate(atCast)).asInstanceOf[FlinkDrm[K]]
+        FlinkOpAx.atx_with_broadcast(op, flinkTranslate(a)).asInstanceOf[FlinkDrm[K]]
       case op@OpAtB(a, b) ⇒ FlinkOpAtB.notZippable(op, flinkTranslate(a),
         flinkTranslate(b)).asInstanceOf[FlinkDrm[K]]
       case op@OpABt(a, b) ⇒

http://git-wip-us.apache.org/repos/asf/mahout/blob/a77f1c13/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAx.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAx.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAx.scala
index ec20b6d..8a333c4 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAx.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAx.scala
@@ -24,9 +24,12 @@ import org.apache.flink.api.common.functions.RichMapFunction
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.scala._
 import org.apache.flink.configuration.Configuration
-import org.apache.mahout.flinkbindings.drm.{BlockifiedFlinkDrm, FlinkDrm}
-import org.apache.mahout.math.drm.logical.OpAx
+import org.apache.mahout.flinkbindings.FlinkEngine
+import org.apache.mahout.flinkbindings.drm.{BlockifiedFlinkDrm, FlinkDrm, RowsFlinkDrm}
+import org.apache.mahout.math.drm._
+import org.apache.mahout.math.drm.logical.{OpAtx, OpAx}
 import org.apache.mahout.math.scalabindings.RLikeOps._
+import org.apache.mahout.math.scalabindings._
 import org.apache.mahout.math.{Matrix, Vector}
 
 
@@ -58,4 +61,56 @@ object FlinkOpAx {
 
     new BlockifiedFlinkDrm(out, op.nrow.toInt)
   }
+
+
+  /**
+   * Physical operator for `A' %*% x` ([[OpAtx]]) that never materializes `A'`.
+   *
+   * `x` is broadcast to the workers; every non-empty block of `A` contributes the sum of its rows,
+   * each scaled by the element of `x` at that row's key. The per-block sums are reduced to one
+   * vector, collected on the driver, and re-parallelized as a single-column DRM.
+   *
+   * @param op   the logical `A' %*% x` operator; provides `x` and the geometry of `A`
+   * @param srcA the Flink representation of `A`
+   * @return a row-wise DRM holding the `A.ncol` x 1 result, built with parallelism degree 1
+   */
+  def atx_with_broadcast(op: OpAtx, srcA: FlinkDrm[Int]): FlinkDrm[Int] = {
+    implicit val ctx = srcA.context
+
+    val dataSetA = srcA.asBlockified.ds
+
+    // broadcast the vector x to the back end
+    val bcastX = drmBroadcast(op.x)
+
+    val partialSums = dataSetA
+      // An empty block has no rows to reduce, and Array.reduce would throw on it.
+      .filter(_._1.nonEmpty)
+      // Per block: sum over local rows idx of A(idx, ::) * x(key of row idx).
+      .map {
+        tuple =>
+          val (keys, blockA) = tuple
+          keys.zipWithIndex.map {
+            case (key, idx) => blockA(idx, ::) * bcastX.value(key)
+          }
+            .reduce(_ += _)
+      }
+      // Sum the per-block vectors into a single vector.
+      .reduce(_ += _)
+      // Bring the result to the driver.
+      .collect()
+
+    // With no non-empty block nothing reaches the driver; A' %*% x is then the zero vector.
+    val inCoreV: Vector = partialSums.headOption
+      .getOrElse(new org.apache.mahout.math.DenseVector(op.A.ncol))
+
+    // Convert back to mtx
+    val inCoreM = inCoreV.toColMatrix
+
+    // It is ridiculous, but in this scheme we will have to re-parallelize it again in order to plug
+    // it back as a Flink drm
+    val res = FlinkEngine.parallelize(inCoreM, parallelismDegree = 1)
+
+    new RowsFlinkDrm[Int](res, 1)
+  }
+
 }
\ No newline at end of file

Reply via email to