Repository: mahout Updated Branches: refs/heads/master 1e0812876 -> b3b72cb65
MAHOUT-1896: Add convenience methods for interacting with SparkML closes apache/mahout-263 Project: http://git-wip-us.apache.org/repos/asf/mahout/repo Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/b3b72cb6 Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/b3b72cb6 Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/b3b72cb6 Branch: refs/heads/master Commit: b3b72cb658ed6ad59a092eb554b06898163a1375 Parents: 1e08128 Author: rawkintrevo <[email protected]> Authored: Fri Jan 13 13:52:35 2017 -0600 Committer: rawkintrevo <[email protected]> Committed: Fri Jan 13 13:52:35 2017 -0600 ---------------------------------------------------------------------- spark/pom.xml | 6 ++ .../apache/mahout/sparkbindings/package.scala | 51 +++++++++++++- .../mahout/sparkbindings/drm/DrmLikeSuite.scala | 74 ++++++++++++++++++++ 3 files changed, 130 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mahout/blob/b3b72cb6/spark/pom.xml ---------------------------------------------------------------------- diff --git a/spark/pom.xml b/spark/pom.xml index 5fc9863..f965d38 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -149,6 +149,12 @@ </dependency> <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-mllib_${scala.compat.version}</artifactId> + <version>${spark.version}</version> + </dependency> + + <dependency> <groupId>org.apache.mahout</groupId> <artifactId>mahout-math-scala_${scala.compat.version}</artifactId> </dependency> http://git-wip-us.apache.org/repos/asf/mahout/blob/b3b72cb6/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala index acca75e..8064cf0 100644 --- a/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala +++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala @@ -21,12 +21,15 @@ import java.io._ import org.apache.mahout.logging._ import org.apache.mahout.math.drm._ -import org.apache.mahout.math.{MatrixWritable, VectorWritable, Matrix, Vector} +import org.apache.mahout.math.{Matrix, MatrixWritable, Vector, VectorWritable} import org.apache.mahout.sparkbindings.drm.{CheckpointedDrmSpark, CheckpointedDrmSparkOps, SparkBCast} import org.apache.mahout.util.IOUtilsScala import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.linalg.{Vector => SparkVector, SparseVector => SparseSparkVector, DenseVector => DenseSparkVector} +import org.apache.spark.sql.DataFrame import collection._ import collection.generic.Growable @@ -141,6 +144,52 @@ package object sparkbindings { new CheckpointedDrmSpark[K](rddInput = rdd, _nrow = nrow, _ncol = ncol, cacheHint = cacheHint, _canHaveMissingRows = canHaveMissingRows) + /** A drmWrap version that takes an RDD[org.apache.spark.mllib.regression.LabeledPoint] + * returns a DRM where column the label is the last column */ + def drmWrapMLLibLabeledPoint(rdd: RDD[LabeledPoint], + nrow: Long = -1, + ncol: Int = -1, + cacheHint: CacheHint.CacheHint = CacheHint.NONE, + canHaveMissingRows: Boolean = false): CheckpointedDrm[Int] = { + val drmRDD: DrmRdd[Int] = rdd.zipWithIndex.map(lv => { + lv._1.features match { + case _: DenseSparkVector => (lv._2.toInt, new org.apache.mahout.math.DenseVector( lv._1.features.toArray ++ Array(lv._1.label) )) + case _: SparseSparkVector => (lv._2.toInt, + new org.apache.mahout.math.RandomAccessSparseVector(new org.apache.mahout.math.DenseVector( lv._1.features.toArray ++ Array(lv._1.label) )) ) + } + }) + + drmWrap(drmRDD, nrow, ncol, cacheHint, canHaveMissingRows) + } + + /** A drmWrap version that takes a DataFrame of Row[Double] */ + def drmWrapDataFrame(df: DataFrame, + nrow: Long = -1, + ncol: Int = -1, + cacheHint: CacheHint.CacheHint = CacheHint.NONE, + canHaveMissingRows: Boolean = false): CheckpointedDrm[Int] = { + val drmRDD: DrmRdd[Int] = df.rdd + .zipWithIndex + .map( o => (o._2.toInt, o._1.mkString(",").split(",").map(s => s.toDouble)) ) + .map(o => (o._1, new org.apache.mahout.math.DenseVector( o._2 ))) + + drmWrap(drmRDD, nrow, ncol, cacheHint, canHaveMissingRows) + } + + /** A drmWrap Version that takes an RDD[org.apache.spark.mllib.linalg.Vector] */ + def drmWrapMLLibVector(rdd: RDD[SparkVector], + nrow: Long = -1, + ncol: Int = -1, + cacheHint: CacheHint.CacheHint = CacheHint.NONE, + canHaveMissingRows: Boolean = false): CheckpointedDrm[Int] = { + val drmRDD: DrmRdd[Int] = rdd.zipWithIndex.map( v => { + v._1 match { + case _: DenseSparkVector => (v._2.toInt, new org.apache.mahout.math.DenseVector(v._1.toArray)) + case _: SparseSparkVector => (v._2.toInt, new org.apache.mahout.math.RandomAccessSparseVector(new org.apache.mahout.math.DenseVector(v._1.toArray)) ) + } + }) + drmWrap(drmRDD, nrow, ncol, cacheHint, canHaveMissingRows) + } /** Another drmWrap version that takes in vertical block-partitioned input to form the matrix. */ def drmWrapBlockified[K: ClassTag](blockifiedDrmRdd: BlockifiedDrmRdd[K], nrow: Long = -1, ncol: Int = -1, http://git-wip-us.apache.org/repos/asf/mahout/blob/b3b72cb6/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeSuite.scala ---------------------------------------------------------------------- diff --git a/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeSuite.scala b/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeSuite.scala index 8f9b00f..e88e7ef 100644 --- a/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeSuite.scala +++ b/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeSuite.scala @@ -23,8 +23,10 @@ import scalabindings._ import drm._ import RLikeOps._ import RLikeDrmOps._ +import org.apache.mahout.sparkbindings._ import org.apache.mahout.sparkbindings.test.DistributedSparkSuite +case class Thingy(thing1: Double, thing2: Double, thing3: Double) /** DRMLike tests -- just run common DRM tests in Spark. */ class DrmLikeSuite extends FunSuite with DistributedSparkSuite with DrmLikeSuiteBase { @@ -63,6 +65,78 @@ class DrmLikeSuite extends FunSuite with DistributedSparkSuite with DrmLikeSuite throw new AssertionError("Block must be dense.") keys -> block }).norm should be < 1e-4 + + } + + test("DRM wrap labeled points") { + + import org.apache.spark.mllib.linalg.{Vectors => SparkVector} + import org.apache.spark.mllib.regression.LabeledPoint + + val sc = mahoutCtx.asInstanceOf[SparkDistributedContext].sc + + val lpRDD = sc.parallelize(Seq(LabeledPoint(1.0, SparkVector.dense(2.0, 0.0, 4.0)), + LabeledPoint(2.0, SparkVector.dense(3.0, 0.0, 5.0)), + LabeledPoint(3.0, SparkVector.dense(4.0, 0.0, 6.0)) )) + + val lpDRM = drmWrapMLLibLabeledPoint(rdd = lpRDD) + val lpM = lpDRM.collect(::,::) + val testM = dense((2,0,4,1), (3,0,5,2), (4,0,6,3)) + assert(lpM === testM) } + test("DRM wrap spark vectors") { + + import org.apache.spark.mllib.linalg.{Vectors => SparkVector} + + val sc = mahoutCtx.asInstanceOf[SparkDistributedContext].sc + + val svRDD = sc.parallelize(Seq(SparkVector.dense(2.0, 0.0, 4.0), + SparkVector.dense(3.0, 0.0, 5.0), + SparkVector.dense(4.0, 0.0, 6.0) )) + + val svDRM = drmWrapMLLibVector(rdd = svRDD) + val svM = svDRM.collect(::,::) + val testM = dense((2,0,4), (3,0,5), (4,0,6)) + + assert(svM === testM) + + val ssvRDD = sc.parallelize(Seq(SparkVector.sparse(3, Array(1,2), Array(3,4)), + SparkVector.sparse(3, Array(0,2), Array(3,4)), + SparkVector.sparse(3, Array(0,1), Array(3,4))) ) + + val ssvDRM = drmWrapMLLibVector(rdd = ssvRDD) + val ssvM = ssvDRM.collect(::,::) + + val testSM = sparse( + (1, 3) :: (2, 4) :: Nil, + (0, 3) :: (2, 4) :: Nil, + (0, 3) :: (1, 4) :: Nil) + + assert(ssvM === testSM) + } + + + + test("DRM wrap spark dataframe") { + + import org.apache.spark.mllib.linalg.{Vectors => SparkVector} + + val sc = mahoutCtx.asInstanceOf[SparkDistributedContext].sc + + val sqlContext= new org.apache.spark.sql.SQLContext(sc) + import sqlContext.implicits._ + + val myDF = sc.parallelize(Seq((2.0, 0.0, 4.0), + (3.0, 0.0, 5.0), + (4.0, 0.0, 6.0) )) + .map(o => Thingy(o._1, o._2, o._3)) + .toDF() + + val dfDRM = drmWrapDataFrame(df = myDF) + val dfM = dfDRM.collect(::,::) + val testM = dense((2,0,4), (3,0,5), (4,0,6)) + + assert(dfM === testM) + } }
