MAHOUT-1749 Mahout DSL for Flink: Implement Atx closes apache/mahout#204
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/a77f1c13 Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/a77f1c13 Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/a77f1c13 Branch: refs/heads/master Commit: a77f1c13de58d462eed7ff224ec333b22ac22bf3 Parents: e3c8db5 Author: Andrew Palumbo <[email protected]> Authored: Sat Mar 26 23:22:44 2016 -0400 Committer: Andrew Palumbo <[email protected]> Committed: Sat Mar 26 23:22:44 2016 -0400 ---------------------------------------------------------------------- .../mahout/flinkbindings/FlinkEngine.scala | 8 +--- .../mahout/flinkbindings/blas/FlinkOpAx.scala | 42 +++++++++++++++++++- 2 files changed, 41 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mahout/blob/a77f1c13/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala ---------------------------------------------------------------------- diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala index c355cae..0fd2e05 100644 --- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala +++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala @@ -137,13 +137,7 @@ object FlinkEngine extends DistributedEngine { FlinkOpAx.blockifiedBroadcastAx(op, flinkTranslate(a)) case op@OpAt(a) if op.keyClassTag == ClassTag.Int ⇒ FlinkOpAt.sparseTrick(op, flinkTranslate(a)).asInstanceOf[FlinkDrm[K]] case op@OpAtx(a, x) if op.keyClassTag == ClassTag.Int ⇒ - // express Atx as (A.t) %*% x - // TODO: create specific implementation of Atx, see MAHOUT-1749 - val opAt = OpAt(a) - val at = FlinkOpAt.sparseTrick(opAt, flinkTranslate(a)) - val atCast = new CheckpointedFlinkDrm(at.asRowWise.ds, _nrow = opAt.nrow, _ncol = opAt.ncol) - val opAx = OpAx(atCast, x) - 
FlinkOpAx.blockifiedBroadcastAx(opAx, flinkTranslate(atCast)).asInstanceOf[FlinkDrm[K]] + FlinkOpAx.atx_with_broadcast(op, flinkTranslate(a)).asInstanceOf[FlinkDrm[K]] case op@OpAtB(a, b) ⇒ FlinkOpAtB.notZippable(op, flinkTranslate(a), flinkTranslate(b)).asInstanceOf[FlinkDrm[K]] case op@OpABt(a, b) ⇒ http://git-wip-us.apache.org/repos/asf/mahout/blob/a77f1c13/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAx.scala ---------------------------------------------------------------------- diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAx.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAx.scala index ec20b6d..8a333c4 100644 --- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAx.scala +++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAx.scala @@ -24,9 +24,12 @@ import org.apache.flink.api.common.functions.RichMapFunction import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.scala._ import org.apache.flink.configuration.Configuration -import org.apache.mahout.flinkbindings.drm.{BlockifiedFlinkDrm, FlinkDrm} -import org.apache.mahout.math.drm.logical.OpAx +import org.apache.mahout.flinkbindings.FlinkEngine +import org.apache.mahout.flinkbindings.drm.{BlockifiedFlinkDrm, FlinkDrm, RowsFlinkDrm} +import org.apache.mahout.math.drm._ +import org.apache.mahout.math.drm.logical.{OpAtx, OpAx} import org.apache.mahout.math.scalabindings.RLikeOps._ +import org.apache.mahout.math.scalabindings._ import org.apache.mahout.math.{Matrix, Vector} @@ -58,4 +61,39 @@ object FlinkOpAx { new BlockifiedFlinkDrm(out, op.nrow.toInt) } + + + def atx_with_broadcast(op: OpAtx, srcA: FlinkDrm[Int]): FlinkDrm[Int] = { + implicit val ctx = srcA.context + + val dataSetA = srcA.asBlockified.ds + + // broadcast the vector x to the back end + val bcastX = drmBroadcast(op.x) + + implicit val typeInformation = createTypeInformation[(Array[Int],Matrix)] + val 
inCoreM = dataSetA.map { + tuple => + tuple._1.zipWithIndex.map { + case (key, idx) => tuple._2(idx, ::) * bcastX.value(key) + } + .reduce(_ += _) + } + // All-reduce + .reduce(_ += _) + + // collect result + .collect()(0) + + // Convert back to mtx + .toColMatrix + + // It is ridiculous, but in this scheme we will have to re-parallelize it again in order to plug + // it back as a Flink drm + val res = FlinkEngine.parallelize(inCoreM, parallelismDegree = 1) + + new RowsFlinkDrm[Int](res, 1) + + } + } \ No newline at end of file
