Repository: spark
Updated Branches:
  refs/heads/master 04450d115 -> 1a9c6cdda


[SPARK-3573][MLLIB] Make MLlib's Vector compatible with SQL's SchemaRDD

Register MLlib's Vector as a SQL user-defined type (UDT) in both Scala and 
Python. With this PR, we can easily map an RDD[LabeledPoint] to a SchemaRDD, and 
then select columns or save to a Parquet file. Examples in Scala/Python are 
attached. The Scala code was copied from jkbradley.
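
As a quick illustration, here is a minimal sketch distilled from the attached
Scala example (it assumes an existing SparkContext `sc`; the output path is
arbitrary):

    import org.apache.spark.mllib.util.MLUtils
    import org.apache.spark.sql.SQLContext

    val sqlContext = new SQLContext(sc)
    import sqlContext._ // implicit RDD -> SchemaRDD conversion

    // Vector is now a registered UDT, so an RDD[LabeledPoint] converts
    // directly to a SchemaRDD: select columns or save to Parquet.
    val points = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
    points.select('label, 'features).saveAsParquetFile("/tmp/points.parquet")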

~~This PR contains the changes from #3068. I will rebase after #3068 is merged.~~

marmbrus jkbradley

Author: Xiangrui Meng <m...@databricks.com>

Closes #3070 from mengxr/SPARK-3573 and squashes the following commits:

3a0b6e5 [Xiangrui Meng] organize imports
236f0a0 [Xiangrui Meng] register vector as UDT and provide dataset examples


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1a9c6cdd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1a9c6cdd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1a9c6cdd

Branch: refs/heads/master
Commit: 1a9c6cddadebdc53d083ac3e0da276ce979b5d1f
Parents: 04450d1
Author: Xiangrui Meng <m...@databricks.com>
Authored: Mon Nov 3 22:29:48 2014 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Mon Nov 3 22:29:48 2014 -0800

----------------------------------------------------------------------
 dev/run-tests                                   |   2 +-
 .../src/main/python/mllib/dataset_example.py    |  62 ++++++++++
 .../spark/examples/mllib/DatasetExample.scala   | 121 +++++++++++++++++++
 mllib/pom.xml                                   |   5 +
 .../org/apache/spark/mllib/linalg/Vectors.scala |  69 ++++++++++-
 .../spark/mllib/linalg/VectorsSuite.scala       |  11 ++
 python/pyspark/mllib/linalg.py                  |  50 ++++++++
 python/pyspark/mllib/tests.py                   |  39 +++++-
 8 files changed, 353 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1a9c6cdd/dev/run-tests
----------------------------------------------------------------------
diff --git a/dev/run-tests b/dev/run-tests
index 0e9eefa..de607e4 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -180,7 +180,7 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS
   if [ -n "$_SQL_TESTS_ONLY" ]; then
    # This must be an array of individual arguments. Otherwise, having one long string
     #+ will be interpreted as a single test, which doesn't work.
-    SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test")
+    SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "mllib/test")
   else
     SBT_MAVEN_TEST_ARGS=("test")
   fi
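
(Note: mllib is added to the SQL-only test set because this patch gives mllib a
compile-time dependency on spark-sql, as the pom.xml diff below shows, so SQL
changes can now affect MLlib.)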

http://git-wip-us.apache.org/repos/asf/spark/blob/1a9c6cdd/examples/src/main/python/mllib/dataset_example.py
----------------------------------------------------------------------
diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py
new file mode 100644
index 0000000..540dae7
--- /dev/null
+++ b/examples/src/main/python/mllib/dataset_example.py
@@ -0,0 +1,62 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+An example of how to use SchemaRDD as a dataset for ML. Run with::
+    bin/spark-submit examples/src/main/python/mllib/dataset_example.py
+"""
+
+import os
+import sys
+import tempfile
+import shutil
+
+from pyspark import SparkContext
+from pyspark.sql import SQLContext
+from pyspark.mllib.util import MLUtils
+from pyspark.mllib.stat import Statistics
+
+
+def summarize(dataset):
+    print "schema: %s" % dataset.schema().json()
+    labels = dataset.map(lambda r: r.label)
+    print "label average: %f" % labels.mean()
+    features = dataset.map(lambda r: r.features)
+    summary = Statistics.colStats(features)
+    print "features average: %r" % summary.mean()
+
+if __name__ == "__main__":
+    if len(sys.argv) > 2:
+        print >> sys.stderr, "Usage: dataset_example.py <libsvm file>"
+        exit(-1)
+    sc = SparkContext(appName="DatasetExample")
+    sqlCtx = SQLContext(sc)
+    if len(sys.argv) == 2:
+        input = sys.argv[1]
+    else:
+        input = "data/mllib/sample_libsvm_data.txt"
+    points = MLUtils.loadLibSVMFile(sc, input)
+    dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache()
+    summarize(dataset0)
+    # Get a unique temp path: create a named temp file, then delete it so
+    # saveAsParquetFile can create a directory at that path.
+    tempdir = tempfile.NamedTemporaryFile(delete=False).name
+    os.unlink(tempdir)
+    print "Save dataset as a Parquet file to %s." % tempdir
+    dataset0.saveAsParquetFile(tempdir)
+    print "Load it back and summarize it again."
+    dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache()
+    summarize(dataset1)
+    shutil.rmtree(tempdir)

http://git-wip-us.apache.org/repos/asf/spark/blob/1a9c6cdd/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
new file mode 100644
index 0000000..f8d83f4
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import java.io.File
+
+import com.google.common.io.Files
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext, SchemaRDD}
+
+/**
+ * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object DatasetExample {
+
+  case class Params(
+      input: String = "data/mllib/sample_libsvm_data.txt",
+      dataFormat: String = "libsvm") extends AbstractParams[Params]
+
+  def main(args: Array[String]) {
+    val defaultParams = Params()
+
+    val parser = new OptionParser[Params]("DatasetExample") {
+      head("Dataset: an example app using SchemaRDD as a Dataset for ML.")
+      opt[String]("input")
+        .text(s"input path to dataset")
+        .action((x, c) => c.copy(input = x))
+      opt[String]("dataFormat")
+        .text("data format: libsvm (default), dense (deprecated in Spark 
v1.1)")
+        .action((x, c) => c.copy(input = x))
+      checkConfig { params =>
+        success
+      }
+    }
+
+    parser.parse(args, defaultParams).map { params =>
+      run(params)
+    }.getOrElse {
+      sys.exit(1)
+    }
+  }
+
+  def run(params: Params) {
+
+    val conf = new SparkConf().setAppName(s"DatasetExample with $params")
+    val sc = new SparkContext(conf)
+    val sqlContext = new SQLContext(sc)
+    import sqlContext._ // for implicit conversions
+
+    // Load input data
+    val origData: RDD[LabeledPoint] = params.dataFormat match {
+      case "dense" => MLUtils.loadLabeledPoints(sc, params.input)
+      case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input)
+    }
+    println(s"Loaded ${origData.count()} instances from file: ${params.input}")
+
+    // Convert input data to SchemaRDD explicitly.
+    val schemaRDD: SchemaRDD = origData
+    println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}")
+    println(s"Converted to SchemaRDD with ${schemaRDD.count()} records")
+
+    // Select columns, using implicit conversion to SchemaRDD.
+    val labelsSchemaRDD: SchemaRDD = origData.select('label)
+    val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v }
+    val numLabels = labels.count()
+    val meanLabel = labels.fold(0.0)(_ + _) / numLabels
+    println(s"Selected label column with average value $meanLabel")
+
+    val featuresSchemaRDD: SchemaRDD = origData.select('features)
+    val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v }
+    val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
+      (summary, feat) => summary.add(feat),
+      (sum1, sum2) => sum1.merge(sum2))
+    println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")
+
+    val tmpDir = Files.createTempDir()
+    tmpDir.deleteOnExit()
+    val outputDir = new File(tmpDir, "dataset").toString
+    println(s"Saving to $outputDir as Parquet file.")
+    schemaRDD.saveAsParquetFile(outputDir)
+
+    println(s"Loading Parquet file with UDT from $outputDir.")
+    val newDataset = sqlContext.parquetFile(outputDir)
+
+    println(s"Schema from Parquet: ${newDataset.schema.prettyJson}")
+    val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v }
+    val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())(
+      (summary, feat) => summary.add(feat),
+      (sum1, sum2) => sum1.merge(sum2))
+    println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}")
+
+    sc.stop()
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1a9c6cdd/mllib/pom.xml
----------------------------------------------------------------------
diff --git a/mllib/pom.xml b/mllib/pom.xml
index fb7239e..87a7dda 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -46,6 +46,11 @@
       <version>${project.version}</version>
     </dependency>
     <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
       <groupId>org.eclipse.jetty</groupId>
       <artifactId>jetty-server</artifactId>
     </dependency>
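
This makes spark-sql a compile-time dependency of mllib; VectorUDT needs it for
UserDefinedType, the SQLUserDefinedType annotation, and the catalyst Row and
type classes imported in Vectors.scala below. For an application that depends
on both modules, the sbt equivalent would be (a sketch; the version shown is
hypothetical):

    libraryDependencies ++= Seq(
      "org.apache.spark" %% "spark-mllib" % "1.2.0",
      "org.apache.spark" %% "spark-sql"   % "1.2.0"
    )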

http://git-wip-us.apache.org/repos/asf/spark/blob/1a9c6cdd/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 6af225b..ac217ed 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -17,22 +17,26 @@
 
 package org.apache.spark.mllib.linalg
 
-import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable}
 import java.util
+import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable}
 
 import scala.annotation.varargs
 import scala.collection.JavaConverters._
 
 import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
 
-import org.apache.spark.mllib.util.NumericParser
 import org.apache.spark.SparkException
+import org.apache.spark.mllib.util.NumericParser
+import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
+import org.apache.spark.sql.catalyst.types._
 
 /**
 * Represents a numeric vector, whose index type is Int and value type is Double.
  *
  * Note: Users should not implement this interface.
  */
+@SQLUserDefinedType(udt = classOf[VectorUDT])
 sealed trait Vector extends Serializable {
 
   /**
@@ -75,6 +79,65 @@ sealed trait Vector extends Serializable {
 }
 
 /**
+ * User-defined type for [[Vector]] which allows easy interaction with SQL
+ * via [[org.apache.spark.sql.SchemaRDD]].
+ */
+private[spark] class VectorUDT extends UserDefinedType[Vector] {
+
+  override def sqlType: StructType = {
+    // type: 0 = sparse, 1 = dense
+    // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
+    // vectors. The "values" field is nullable because we might want to add binary vectors later,
+    // which use "size" and "indices", but not "values".
+    StructType(Seq(
+      StructField("type", ByteType, nullable = false),
+      StructField("size", IntegerType, nullable = true),
+      StructField("indices", ArrayType(IntegerType, containsNull = false), 
nullable = true),
+      StructField("values", ArrayType(DoubleType, containsNull = false), 
nullable = true)))
+  }
+
+  override def serialize(obj: Any): Row = {
+    val row = new GenericMutableRow(4)
+    obj match {
+      case sv: SparseVector =>
+        row.setByte(0, 0)
+        row.setInt(1, sv.size)
+        row.update(2, sv.indices.toSeq)
+        row.update(3, sv.values.toSeq)
+      case dv: DenseVector =>
+        row.setByte(0, 1)
+        row.setNullAt(1)
+        row.setNullAt(2)
+        row.update(3, dv.values.toSeq)
+    }
+    row
+  }
+
+  override def deserialize(datum: Any): Vector = {
+    datum match {
+      case row: Row =>
+        require(row.length == 4,
+          s"VectorUDT.deserialize given row with length ${row.length} but 
requires length == 4")
+        val tpe = row.getByte(0)
+        tpe match {
+          case 0 =>
+            val size = row.getInt(1)
+            val indices = row.getAs[Iterable[Int]](2).toArray
+            val values = row.getAs[Iterable[Double]](3).toArray
+            new SparseVector(size, indices, values)
+          case 1 =>
+            val values = row.getAs[Iterable[Double]](3).toArray
+            new DenseVector(values)
+        }
+    }
+  }
+
+  override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT"
+
+  override def userClass: Class[Vector] = classOf[Vector]
+}
+
+/**
  * Factory methods for [[org.apache.spark.mllib.linalg.Vector]].
  * We don't use the name `Vector` because Scala imports
  * [[scala.collection.immutable.Vector]] by default.
@@ -191,6 +254,7 @@ object Vectors {
 /**
  * A dense vector represented by a value array.
  */
+@SQLUserDefinedType(udt = classOf[VectorUDT])
 class DenseVector(val values: Array[Double]) extends Vector {
 
   override def size: Int = values.length
@@ -215,6 +279,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
  * @param indices index array, assumed to be strictly increasing.
  * @param values value array, must have the same length as the index array.
  */
+@SQLUserDefinedType(udt = classOf[VectorUDT])
 class SparseVector(
     override val size: Int,
     val indices: Array[Int],
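
To make the encoding concrete, here is a round-trip sketch (a hypothetical
demo object; because VectorUDT is private[spark], it must live under the
org.apache.spark namespace, just like the test suite below):

    package org.apache.spark.mllib.linalg

    object VectorUDTDemo {
      def main(args: Array[String]): Unit = {
        val udt = new VectorUDT()

        // A sparse vector fills all four fields: type byte 0, size, indices, values.
        val sv = Vectors.sparse(4, Array(1, 3), Array(0.5, 2.0))
        val svRow = udt.serialize(sv)
        assert(svRow.getByte(0) == 0 && svRow.getInt(1) == 4)
        assert(udt.deserialize(svRow) == sv)

        // A dense vector sets type byte 1 and uses only "values";
        // "size" and "indices" are left null.
        val dv = Vectors.dense(1.0, 2.0)
        val dvRow = udt.serialize(dv)
        assert(dvRow.getByte(0) == 1 && dvRow.isNullAt(1) && dvRow.isNullAt(2))
        assert(udt.deserialize(dvRow) == dv)
      }
    }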

http://git-wip-us.apache.org/repos/asf/spark/blob/1a9c6cdd/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index cd651fe..93a84fe 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -155,4 +155,15 @@ class VectorsSuite extends FunSuite {
         throw new RuntimeException(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.")
     }
   }
+
+  test("VectorUDT") {
+    val dv0 = Vectors.dense(Array.empty[Double])
+    val dv1 = Vectors.dense(1.0, 2.0)
+    val sv0 = Vectors.sparse(2, Array.empty, Array.empty)
+    val sv1 = Vectors.sparse(2, Array(1), Array(2.0))
+    val udt = new VectorUDT()
+    for (v <- Seq(dv0, dv1, sv0, sv1)) {
+      assert(v === udt.deserialize(udt.serialize(v)))
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1a9c6cdd/python/pyspark/mllib/linalg.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index d0a0e10..c0c3dff 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -29,6 +29,9 @@ import copy_reg
 
 import numpy as np
 
+from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \
+    IntegerType, ByteType, Row
+
 
 __all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors']
 
@@ -106,7 +109,54 @@ def _format_float(f, digits=4):
     return s
 
 
+class VectorUDT(UserDefinedType):
+    """
+    SQL user-defined type (UDT) for Vector.
+    """
+
+    @classmethod
+    def sqlType(cls):
+        return StructType([
+            StructField("type", ByteType(), False),
+            StructField("size", IntegerType(), True),
+            StructField("indices", ArrayType(IntegerType(), False), True),
+            StructField("values", ArrayType(DoubleType(), False), True)])
+
+    @classmethod
+    def module(cls):
+        return "pyspark.mllib.linalg"
+
+    @classmethod
+    def scalaUDT(cls):
+        return "org.apache.spark.mllib.linalg.VectorUDT"
+
+    def serialize(self, obj):
+        if isinstance(obj, SparseVector):
+            indices = [int(i) for i in obj.indices]
+            values = [float(v) for v in obj.values]
+            return (0, obj.size, indices, values)
+        elif isinstance(obj, DenseVector):
+            values = [float(v) for v in obj]
+            return (1, None, None, values)
+        else:
+            raise ValueError("cannot serialize %r of type %r" % (obj, 
type(obj)))
+
+    def deserialize(self, datum):
+        assert len(datum) == 4, \
+            "VectorUDT.deserialize given row with length %d but requires 4" % 
len(datum)
+        tpe = datum[0]
+        if tpe == 0:
+            return SparseVector(datum[1], datum[2], datum[3])
+        elif tpe == 1:
+            return DenseVector(datum[3])
+        else:
+            raise ValueError("do not recognize type %r" % tpe)
+
+
 class Vector(object):
     """
     Abstract class for DenseVector and SparseVector
     """
+
+    __UDT__ = VectorUDT()

http://git-wip-us.apache.org/repos/asf/spark/blob/1a9c6cdd/python/pyspark/mllib/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index d6fb87b..9fa4d6f 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -33,14 +33,14 @@ if sys.version_info[:2] <= (2, 6):
 else:
     import unittest
 
-from pyspark.serializers import PickleSerializer
-from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector
+from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector
 from pyspark.mllib.regression import LabeledPoint
 from pyspark.mllib.random import RandomRDDs
 from pyspark.mllib.stat import Statistics
+from pyspark.serializers import PickleSerializer
+from pyspark.sql import SQLContext
 from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
 
-
 _have_scipy = False
 try:
     import scipy.sparse
@@ -221,6 +221,39 @@ class StatTests(PySparkTestCase):
         self.assertEqual(10, summary.count())
 
 
+class VectorUDTTests(PySparkTestCase):
+
+    dv0 = DenseVector([])
+    dv1 = DenseVector([1.0, 2.0])
+    sv0 = SparseVector(2, [], [])
+    sv1 = SparseVector(2, [1], [2.0])
+    udt = VectorUDT()
+
+    def test_json_schema(self):
+        self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt)
+
+    def test_serialization(self):
+        for v in [self.dv0, self.dv1, self.sv0, self.sv1]:
+            self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v)))
+
+    def test_infer_schema(self):
+        sqlCtx = SQLContext(self.sc)
+        rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)])
+        srdd = sqlCtx.inferSchema(rdd)
+        schema = srdd.schema()
+        field = [f for f in schema.fields if f.name == "features"][0]
+        self.assertEqual(field.dataType, self.udt)
+        vectors = srdd.map(lambda p: p.features).collect()
+        self.assertEqual(len(vectors), 2)
+        for v in vectors:
+            if isinstance(v, SparseVector):
+                self.assertEqual(v, self.sv1)
+            elif isinstance(v, DenseVector):
+                self.assertEqual(v, self.dv1)
+            else:
+                raise ValueError("expecting a vector but got %r of type %r" % 
(v, type(v)))
+
+
 @unittest.skipIf(not _have_scipy, "SciPy not installed")
 class SciPyTests(PySparkTestCase):
 

