Repository: spark
Updated Branches:
  refs/heads/branch-2.0 55c1fac21 -> 1a57bf0f4


[SPARK-15364][ML][PYSPARK] Implement PySpark picklers for ml.Vector and 
ml.Matrix under spark.ml.python

## What changes were proposed in this pull request?

PySpark currently has picklers for both the new (spark.ml) and the old (spark.mllib) 
vector/matrix classes, but they are all implemented under `PythonMLlibAPI`. To 
separate spark.ml from spark.mllib, this patch implements the picklers for the new 
vector/matrix classes under `spark.ml.python` instead.
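
A minimal sketch of the round trip this enables, mirroring the updated `VectorTests` in `python/pyspark/ml/tests.py` below (it assumes a running SparkContext `sc`; `MLSerDe` is the new JVM-side entry point imported by `java_gateway.py`):

```python
# Round-trip a pyspark.ml vector through the new JVM-side MLSerDe picklers.
from pyspark import SparkContext
from pyspark.ml.linalg import Vectors
from pyspark.serializers import PickleSerializer

sc = SparkContext.getOrCreate()
ser = PickleSerializer()

v = Vectors.sparse(4, [1, 3], [2.0, -1.0])
# Python -> JVM: unpickled into an org.apache.spark.ml.linalg.SparseVector
jvec = sc._jvm.MLSerDe.loads(bytearray(ser.dumps(v)))
# JVM -> Python: pickled back into a pyspark.ml.linalg.SparseVector
back = ser.loads(bytes(sc._jvm.MLSerDe.dumps(jvec)))
assert back == v
```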

## How was this patch tested?
Existing tests.

Author: Liang-Chi Hsieh <sim...@tw.ibm.com>

Closes #13219 from viirya/pyspark-pickler-ml.

(cherry picked from commit baa3e633e18c47b12e79fe3ddc01fc8ec010f096)
Signed-off-by: Xiangrui Meng <m...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1a57bf0f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1a57bf0f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1a57bf0f

Branch: refs/heads/branch-2.0
Commit: 1a57bf0f4e0fec2fdb250902089403a589ab8795
Parents: 55c1fac
Author: Liang-Chi Hsieh <sim...@tw.ibm.com>
Authored: Mon Jun 13 19:59:53 2016 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Mon Jun 13 20:00:10 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/python/MLSerDe.scala    | 224 ++++++++++++++
 .../spark/mllib/api/python/PythonMLLibAPI.scala | 309 ++++---------------
 .../apache/spark/ml/python/MLSerDeSuite.scala   |  72 +++++
 python/pyspark/java_gateway.py                  |   1 +
 python/pyspark/ml/base.py                       |   2 +-
 python/pyspark/ml/classification.py             |   2 +-
 python/pyspark/ml/clustering.py                 |   2 +-
 python/pyspark/ml/common.py                     | 137 ++++++++
 python/pyspark/ml/evaluation.py                 |   2 +-
 python/pyspark/ml/feature.py                    |   2 +-
 python/pyspark/ml/pipeline.py                   |   2 +-
 python/pyspark/ml/recommendation.py             |   2 +-
 python/pyspark/ml/regression.py                 |   2 +-
 python/pyspark/ml/tests.py                      |  10 +-
 python/pyspark/ml/tuning.py                     |   2 +-
 python/pyspark/ml/util.py                       |   2 +-
 python/pyspark/ml/wrapper.py                    |   2 +-
 17 files changed, 518 insertions(+), 257 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1a57bf0f/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala b/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala
new file mode 100644
index 0000000..1279c90
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.python
+
+import java.io.OutputStream
+import java.nio.{ByteBuffer, ByteOrder}
+import java.util.{ArrayList => JArrayList}
+
+import scala.collection.JavaConverters._
+
+import net.razorvine.pickle._
+
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.python.SerDeUtil
+import org.apache.spark.ml.linalg._
+import org.apache.spark.mllib.api.python.SerDeBase
+import org.apache.spark.rdd.RDD
+
+/**
+ * SerDe utility functions for pyspark.ml.
+ */
+private[spark] object MLSerDe extends SerDeBase with Serializable {
+
+  override val PYSPARK_PACKAGE = "pyspark.ml"
+
+  // Pickler for DenseVector
+  private[python] class DenseVectorPickler extends BasePickler[DenseVector] {
+
+    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+      val vector: DenseVector = obj.asInstanceOf[DenseVector]
+      val bytes = new Array[Byte](8 * vector.size)
+      val bb = ByteBuffer.wrap(bytes)
+      bb.order(ByteOrder.nativeOrder())
+      val db = bb.asDoubleBuffer()
+      db.put(vector.values)
+
+      out.write(Opcodes.BINSTRING)
+      out.write(PickleUtils.integer_to_bytes(bytes.length))
+      out.write(bytes)
+      out.write(Opcodes.TUPLE1)
+    }
+
+    def construct(args: Array[Object]): Object = {
+      require(args.length == 1)
+      if (args.length != 1) {
+        throw new PickleException("should be 1")
+      }
+      val bytes = getBytes(args(0))
+      val bb = ByteBuffer.wrap(bytes, 0, bytes.length)
+      bb.order(ByteOrder.nativeOrder())
+      val db = bb.asDoubleBuffer()
+      val ans = new Array[Double](bytes.length / 8)
+      db.get(ans)
+      Vectors.dense(ans)
+    }
+  }
+
+  // Pickler for DenseMatrix
+  private[python] class DenseMatrixPickler extends BasePickler[DenseMatrix] {
+
+    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+      val m: DenseMatrix = obj.asInstanceOf[DenseMatrix]
+      val bytes = new Array[Byte](8 * m.values.length)
+      val order = ByteOrder.nativeOrder()
+      val isTransposed = if (m.isTransposed) 1 else 0
+      ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values)
+
+      out.write(Opcodes.MARK)
+      out.write(Opcodes.BININT)
+      out.write(PickleUtils.integer_to_bytes(m.numRows))
+      out.write(Opcodes.BININT)
+      out.write(PickleUtils.integer_to_bytes(m.numCols))
+      out.write(Opcodes.BINSTRING)
+      out.write(PickleUtils.integer_to_bytes(bytes.length))
+      out.write(bytes)
+      out.write(Opcodes.BININT)
+      out.write(PickleUtils.integer_to_bytes(isTransposed))
+      out.write(Opcodes.TUPLE)
+    }
+
+    def construct(args: Array[Object]): Object = {
+      if (args.length != 4) {
+        throw new PickleException("should be 4")
+      }
+      val bytes = getBytes(args(2))
+      val n = bytes.length / 8
+      val values = new Array[Double](n)
+      val order = ByteOrder.nativeOrder()
+      ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values)
+      val isTransposed = args(3).asInstanceOf[Int] == 1
+      new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values, isTransposed)
+    }
+  }
+
+  // Pickler for SparseMatrix
+  private[python] class SparseMatrixPickler extends BasePickler[SparseMatrix] {
+
+    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+      val s = obj.asInstanceOf[SparseMatrix]
+      val order = ByteOrder.nativeOrder()
+
+      val colPtrsBytes = new Array[Byte](4 * s.colPtrs.length)
+      val indicesBytes = new Array[Byte](4 * s.rowIndices.length)
+      val valuesBytes = new Array[Byte](8 * s.values.length)
+      val isTransposed = if (s.isTransposed) 1 else 0
+      ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().put(s.colPtrs)
+      ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().put(s.rowIndices)
+      ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().put(s.values)
+
+      out.write(Opcodes.MARK)
+      out.write(Opcodes.BININT)
+      out.write(PickleUtils.integer_to_bytes(s.numRows))
+      out.write(Opcodes.BININT)
+      out.write(PickleUtils.integer_to_bytes(s.numCols))
+      out.write(Opcodes.BINSTRING)
+      out.write(PickleUtils.integer_to_bytes(colPtrsBytes.length))
+      out.write(colPtrsBytes)
+      out.write(Opcodes.BINSTRING)
+      out.write(PickleUtils.integer_to_bytes(indicesBytes.length))
+      out.write(indicesBytes)
+      out.write(Opcodes.BINSTRING)
+      out.write(PickleUtils.integer_to_bytes(valuesBytes.length))
+      out.write(valuesBytes)
+      out.write(Opcodes.BININT)
+      out.write(PickleUtils.integer_to_bytes(isTransposed))
+      out.write(Opcodes.TUPLE)
+    }
+
+    def construct(args: Array[Object]): Object = {
+      if (args.length != 6) {
+        throw new PickleException("should be 6")
+      }
+      val order = ByteOrder.nativeOrder()
+      val colPtrsBytes = getBytes(args(2))
+      val indicesBytes = getBytes(args(3))
+      val valuesBytes = getBytes(args(4))
+      val colPtrs = new Array[Int](colPtrsBytes.length / 4)
+      val rowIndices = new Array[Int](indicesBytes.length / 4)
+      val values = new Array[Double](valuesBytes.length / 8)
+      ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().get(colPtrs)
+      ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().get(rowIndices)
+      ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().get(values)
+      val isTransposed = args(5).asInstanceOf[Int] == 1
+      new SparseMatrix(
+        args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], colPtrs, rowIndices, values,
+        isTransposed)
+    }
+  }
+
+  // Pickler for SparseVector
+  private[python] class SparseVectorPickler extends BasePickler[SparseVector] {
+
+    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+      val v: SparseVector = obj.asInstanceOf[SparseVector]
+      val n = v.indices.length
+      val indiceBytes = new Array[Byte](4 * n)
+      val order = ByteOrder.nativeOrder()
+      ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().put(v.indices)
+      val valueBytes = new Array[Byte](8 * n)
+      ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().put(v.values)
+
+      out.write(Opcodes.BININT)
+      out.write(PickleUtils.integer_to_bytes(v.size))
+      out.write(Opcodes.BINSTRING)
+      out.write(PickleUtils.integer_to_bytes(indiceBytes.length))
+      out.write(indiceBytes)
+      out.write(Opcodes.BINSTRING)
+      out.write(PickleUtils.integer_to_bytes(valueBytes.length))
+      out.write(valueBytes)
+      out.write(Opcodes.TUPLE3)
+    }
+
+    def construct(args: Array[Object]): Object = {
+      if (args.length != 3) {
+        throw new PickleException("should be 3")
+      }
+      val size = args(0).asInstanceOf[Int]
+      val indiceBytes = getBytes(args(1))
+      val valueBytes = getBytes(args(2))
+      val n = indiceBytes.length / 4
+      val indices = new Array[Int](n)
+      val values = new Array[Double](n)
+      if (n > 0) {
+        val order = ByteOrder.nativeOrder()
+        ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().get(indices)
+        ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().get(values)
+      }
+      new SparseVector(size, indices, values)
+    }
+  }
+
+  var initialized = false
+  // This should be called before trying to serialize any above classes
+  // In cluster mode, this should be put in the closure
+  override def initialize(): Unit = {
+    SerDeUtil.initialize()
+    synchronized {
+      if (!initialized) {
+        new DenseVectorPickler().register()
+        new DenseMatrixPickler().register()
+        new SparseMatrixPickler().register()
+        new SparseVectorPickler().register()
+        initialized = true
+      }
+    }
+  }
+  // will not be called in the executor automatically
+  initialize()
+}
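
For orientation (not part of this patch): each `saveState` above writes the constructor arguments that the matching `pyspark.ml.linalg` class expects; for example, `DenseVectorPickler` emits the vector's values as a single BINSTRING of native-order doubles. A rough Python-side equivalent of that payload, with `struct` standing in for the ByteBuffer packing, would be:

```python
# Illustration only: the DenseVector payload written by saveState is just the raw
# native-order doubles, which pyspark.ml.linalg.DenseVector accepts directly.
import struct

from pyspark.ml.linalg import DenseVector

values = [0.0, -2.0]
raw = struct.pack("=%dd" % len(values), *values)  # native order, 8 bytes per double
v = DenseVector(raw)                              # what the Python unpickler reconstructs
assert v.toArray().tolist() == values
```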

http://git-wip-us.apache.org/repos/asf/spark/blob/1a57bf0f/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index e43469b..7df6160 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -30,7 +30,6 @@ import net.razorvine.pickle._
 
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.api.python.SerDeUtil
-import org.apache.spark.ml.linalg.{DenseMatrix => NewDenseMatrix, DenseVector => NewDenseVector, SparseMatrix => NewSparseMatrix, SparseVector => NewSparseVector, Vectors => NewVectors}
 import org.apache.spark.mllib.classification._
 import org.apache.spark.mllib.clustering._
 import org.apache.spark.mllib.evaluation.RankingMetrics
@@ -1205,23 +1204,21 @@ private[python] class PythonMLLibAPI extends Serializable {
 }
 
 /**
- * SerDe utility functions for PythonMLLibAPI.
+ * Basic SerDe utility class.
  */
-private[spark] object SerDe extends Serializable {
+private[spark] abstract class SerDeBase {
 
-  val PYSPARK_PACKAGE = "pyspark.mllib"
-  val PYSPARK_ML_PACKAGE = "pyspark.ml"
+  val PYSPARK_PACKAGE: String
+  def initialize(): Unit
 
   /**
    * Base class used for pickle
    */
-  private[python] abstract class BasePickler[T: ClassTag]
+  private[spark] abstract class BasePickler[T: ClassTag]
     extends IObjectPickler with IObjectConstructor {
 
-    protected def packageName: String = PYSPARK_PACKAGE
-
     private val cls = implicitly[ClassTag[T]].runtimeClass
-    private val module = packageName + "." + cls.getName.split('.')(4)
+    private val module = PYSPARK_PACKAGE + "." + cls.getName.split('.')(4)
     private val name = cls.getSimpleName
 
     // register this to Pickler and Unpickler
@@ -1268,45 +1265,73 @@ private[spark] object SerDe extends Serializable {
     private[python] def saveState(obj: Object, out: OutputStream, pickler: Pickler)
   }
 
-  // Pickler for (mllib) DenseVector
-  private[python] class DenseVectorPickler extends BasePickler[DenseVector] {
+  def dumps(obj: AnyRef): Array[Byte] = {
+    obj match {
+      // Pickler in Python side cannot deserialize Scala Array normally. See SPARK-12834.
+      case array: Array[_] => new Pickler().dumps(array.toSeq.asJava)
+      case _ => new Pickler().dumps(obj)
+    }
+  }
 
-    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
-      val vector: DenseVector = obj.asInstanceOf[DenseVector]
-      val bytes = new Array[Byte](8 * vector.size)
-      val bb = ByteBuffer.wrap(bytes)
-      bb.order(ByteOrder.nativeOrder())
-      val db = bb.asDoubleBuffer()
-      db.put(vector.values)
+  def loads(bytes: Array[Byte]): AnyRef = {
+    new Unpickler().loads(bytes)
+  }
 
-      out.write(Opcodes.BINSTRING)
-      out.write(PickleUtils.integer_to_bytes(bytes.length))
-      out.write(bytes)
-      out.write(Opcodes.TUPLE1)
+  /* convert object into Tuple */
+  def asTupleRDD(rdd: RDD[Array[Any]]): RDD[(Int, Int)] = {
+    rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int]))
+  }
+
+  /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */
+  def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = {
+    rdd.map(x => Array(x._1, x._2))
+  }
+
+  /**
+   * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
+   * PySpark.
+   */
+  def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
+    jRDD.rdd.mapPartitions { iter =>
+      initialize()  // ensure it is called in the executor
+      new SerDeUtil.AutoBatchedPickler(iter)
     }
+  }
 
-    def construct(args: Array[Object]): Object = {
-      require(args.length == 1)
-      if (args.length != 1) {
-        throw new PickleException("should be 1")
+  /**
+   * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
+   */
+  def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
+    pyRDD.rdd.mapPartitions { iter =>
+      initialize()  // ensure it is called in the executor
+      val unpickle = new Unpickler
+      iter.flatMap { row =>
+        val obj = unpickle.loads(row)
+        if (batched) {
+          obj match {
+            case list: JArrayList[_] => list.asScala
+            case arr: Array[_] => arr
+          }
+        } else {
+          Seq(obj)
+        }
       }
-      val bytes = getBytes(args(0))
-      val bb = ByteBuffer.wrap(bytes, 0, bytes.length)
-      bb.order(ByteOrder.nativeOrder())
-      val db = bb.asDoubleBuffer()
-      val ans = new Array[Double](bytes.length / 8)
-      db.get(ans)
-      Vectors.dense(ans)
-    }
+    }.toJavaRDD()
   }
+}
 
-  // Pickler for (new) DenseVector
-  private[python] class NewDenseVectorPickler extends BasePickler[NewDenseVector] {
+/**
+ * SerDe utility functions for PythonMLLibAPI.
+ */
+private[spark] object SerDe extends SerDeBase with Serializable {
+
+  override val PYSPARK_PACKAGE = "pyspark.mllib"
 
-    override protected def packageName = PYSPARK_ML_PACKAGE
+  // Pickler for DenseVector
+  private[python] class DenseVectorPickler extends BasePickler[DenseVector] {
 
     def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
-      val vector: NewDenseVector = obj.asInstanceOf[NewDenseVector]
+      val vector: DenseVector = obj.asInstanceOf[DenseVector]
       val bytes = new Array[Byte](8 * vector.size)
       val bb = ByteBuffer.wrap(bytes)
       bb.order(ByteOrder.nativeOrder())
@@ -1330,11 +1355,11 @@ private[spark] object SerDe extends Serializable {
       val db = bb.asDoubleBuffer()
       val ans = new Array[Double](bytes.length / 8)
       db.get(ans)
-      NewVectors.dense(ans)
+      Vectors.dense(ans)
     }
   }
 
-  // Pickler for (mllib) DenseMatrix
+  // Pickler for DenseMatrix
   private[python] class DenseMatrixPickler extends BasePickler[DenseMatrix] {
 
     def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
@@ -1371,46 +1396,7 @@ private[spark] object SerDe extends Serializable {
     }
   }
 
-  // Pickler for (new) DenseMatrix
-  private[python] class NewDenseMatrixPickler extends BasePickler[NewDenseMatrix] {
-
-    override protected def packageName = PYSPARK_ML_PACKAGE
-
-    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
-      val m: NewDenseMatrix = obj.asInstanceOf[NewDenseMatrix]
-      val bytes = new Array[Byte](8 * m.values.length)
-      val order = ByteOrder.nativeOrder()
-      val isTransposed = if (m.isTransposed) 1 else 0
-      ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values)
-
-      out.write(Opcodes.MARK)
-      out.write(Opcodes.BININT)
-      out.write(PickleUtils.integer_to_bytes(m.numRows))
-      out.write(Opcodes.BININT)
-      out.write(PickleUtils.integer_to_bytes(m.numCols))
-      out.write(Opcodes.BINSTRING)
-      out.write(PickleUtils.integer_to_bytes(bytes.length))
-      out.write(bytes)
-      out.write(Opcodes.BININT)
-      out.write(PickleUtils.integer_to_bytes(isTransposed))
-      out.write(Opcodes.TUPLE)
-    }
-
-    def construct(args: Array[Object]): Object = {
-      if (args.length != 4) {
-        throw new PickleException("should be 4")
-      }
-      val bytes = getBytes(args(2))
-      val n = bytes.length / 8
-      val values = new Array[Double](n)
-      val order = ByteOrder.nativeOrder()
-      ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values)
-      val isTransposed = args(3).asInstanceOf[Int] == 1
-      new NewDenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values, isTransposed)
-    }
-  }
-
-  // Pickler for (mllib) SparseMatrix
+  // Pickler for SparseMatrix
   private[python] class SparseMatrixPickler extends BasePickler[SparseMatrix] {
 
     def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
@@ -1465,64 +1451,7 @@ private[spark] object SerDe extends Serializable {
     }
   }
 
-  // Pickler for (new) SparseMatrix
-  private[python] class NewSparseMatrixPickler extends BasePickler[NewSparseMatrix] {
-
-    override protected def packageName = PYSPARK_ML_PACKAGE
-
-    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
-      val s = obj.asInstanceOf[NewSparseMatrix]
-      val order = ByteOrder.nativeOrder()
-
-      val colPtrsBytes = new Array[Byte](4 * s.colPtrs.length)
-      val indicesBytes = new Array[Byte](4 * s.rowIndices.length)
-      val valuesBytes = new Array[Byte](8 * s.values.length)
-      val isTransposed = if (s.isTransposed) 1 else 0
-      ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().put(s.colPtrs)
-      ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().put(s.rowIndices)
-      ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().put(s.values)
-
-      out.write(Opcodes.MARK)
-      out.write(Opcodes.BININT)
-      out.write(PickleUtils.integer_to_bytes(s.numRows))
-      out.write(Opcodes.BININT)
-      out.write(PickleUtils.integer_to_bytes(s.numCols))
-      out.write(Opcodes.BINSTRING)
-      out.write(PickleUtils.integer_to_bytes(colPtrsBytes.length))
-      out.write(colPtrsBytes)
-      out.write(Opcodes.BINSTRING)
-      out.write(PickleUtils.integer_to_bytes(indicesBytes.length))
-      out.write(indicesBytes)
-      out.write(Opcodes.BINSTRING)
-      out.write(PickleUtils.integer_to_bytes(valuesBytes.length))
-      out.write(valuesBytes)
-      out.write(Opcodes.BININT)
-      out.write(PickleUtils.integer_to_bytes(isTransposed))
-      out.write(Opcodes.TUPLE)
-    }
-
-    def construct(args: Array[Object]): Object = {
-      if (args.length != 6) {
-        throw new PickleException("should be 6")
-      }
-      val order = ByteOrder.nativeOrder()
-      val colPtrsBytes = getBytes(args(2))
-      val indicesBytes = getBytes(args(3))
-      val valuesBytes = getBytes(args(4))
-      val colPtrs = new Array[Int](colPtrsBytes.length / 4)
-      val rowIndices = new Array[Int](indicesBytes.length / 4)
-      val values = new Array[Double](valuesBytes.length / 8)
-      ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().get(colPtrs)
-      ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().get(rowIndices)
-      ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().get(values)
-      val isTransposed = args(5).asInstanceOf[Int] == 1
-      new NewSparseMatrix(
-        args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], colPtrs, rowIndices, values,
-        isTransposed)
-    }
-  }
-
-  // Pickler for (mllib) SparseVector
+  // Pickler for SparseVector
   private[python] class SparseVectorPickler extends BasePickler[SparseVector] {
 
     def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
@@ -1564,50 +1493,6 @@ private[spark] object SerDe extends Serializable {
     }
   }
 
-  // Pickler for (new) SparseVector
-  private[python] class NewSparseVectorPickler extends BasePickler[NewSparseVector] {
-
-    override protected def packageName = PYSPARK_ML_PACKAGE
-
-    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
-      val v: NewSparseVector = obj.asInstanceOf[NewSparseVector]
-      val n = v.indices.length
-      val indiceBytes = new Array[Byte](4 * n)
-      val order = ByteOrder.nativeOrder()
-      ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().put(v.indices)
-      val valueBytes = new Array[Byte](8 * n)
-      ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().put(v.values)
-
-      out.write(Opcodes.BININT)
-      out.write(PickleUtils.integer_to_bytes(v.size))
-      out.write(Opcodes.BINSTRING)
-      out.write(PickleUtils.integer_to_bytes(indiceBytes.length))
-      out.write(indiceBytes)
-      out.write(Opcodes.BINSTRING)
-      out.write(PickleUtils.integer_to_bytes(valueBytes.length))
-      out.write(valueBytes)
-      out.write(Opcodes.TUPLE3)
-    }
-
-    def construct(args: Array[Object]): Object = {
-      if (args.length != 3) {
-        throw new PickleException("should be 3")
-      }
-      val size = args(0).asInstanceOf[Int]
-      val indiceBytes = getBytes(args(1))
-      val valueBytes = getBytes(args(2))
-      val n = indiceBytes.length / 4
-      val indices = new Array[Int](n)
-      val values = new Array[Double](n)
-      if (n > 0) {
-        val order = ByteOrder.nativeOrder()
-        ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().get(indices)
-        ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().get(values)
-      }
-      new NewSparseVector(size, indices, values)
-    }
-  }
-
   // Pickler for MLlib LabeledPoint
   private[python] class LabeledPointPickler extends BasePickler[LabeledPoint] {
 
@@ -1654,7 +1539,7 @@ private[spark] object SerDe extends Serializable {
   var initialized = false
   // This should be called before trying to serialize any above classes
   // In cluster mode, this should be put in the closure
-  def initialize(): Unit = {
+  override def initialize(): Unit = {
     SerDeUtil.initialize()
     synchronized {
       if (!initialized) {
@@ -1662,10 +1547,6 @@ private[spark] object SerDe extends Serializable {
         new DenseMatrixPickler().register()
         new SparseMatrixPickler().register()
         new SparseVectorPickler().register()
-        new NewDenseVectorPickler().register()
-        new NewDenseMatrixPickler().register()
-        new NewSparseMatrixPickler().register()
-        new NewSparseVectorPickler().register()
         new LabeledPointPickler().register()
         new RatingPickler().register()
         initialized = true
@@ -1674,58 +1555,4 @@ private[spark] object SerDe extends Serializable {
   }
   // will not called in Executor automatically
   initialize()
-
-  def dumps(obj: AnyRef): Array[Byte] = {
-    obj match {
-      // Pickler in Python side cannot deserialize Scala Array normally. See SPARK-12834.
-      case array: Array[_] => new Pickler().dumps(array.toSeq.asJava)
-      case _ => new Pickler().dumps(obj)
-    }
-  }
-
-  def loads(bytes: Array[Byte]): AnyRef = {
-    new Unpickler().loads(bytes)
-  }
-
-  /* convert object into Tuple */
-  def asTupleRDD(rdd: RDD[Array[Any]]): RDD[(Int, Int)] = {
-    rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int]))
-  }
-
-  /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */
-  def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = {
-    rdd.map(x => Array(x._1, x._2))
-  }
-
-  /**
-   * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
-   * PySpark.
-   */
-  def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
-    jRDD.rdd.mapPartitions { iter =>
-      initialize()  // let it called in executor
-      new SerDeUtil.AutoBatchedPickler(iter)
-    }
-  }
-
-  /**
-   * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
-   */
-  def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
-    pyRDD.rdd.mapPartitions { iter =>
-      initialize()  // let it called in executor
-      val unpickle = new Unpickler
-      iter.flatMap { row =>
-        val obj = unpickle.loads(row)
-        if (batched) {
-          obj match {
-            case list: JArrayList[_] => list.asScala
-            case arr: Array[_] => arr
-          }
-        } else {
-          Seq(obj)
-        }
-      }
-    }.toJavaRDD()
-  }
 }
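
After this refactoring, `SerDe` (for `pyspark.mllib`) and `MLSerDe` (for `pyspark.ml`) both extend `SerDeBase` but register picklers only for their own linalg classes. A hedged sketch of the resulting split, again assuming a running SparkContext `sc`:

```python
# Each Python package now talks to its own JVM-side SerDe object (sketch only).
from pyspark import SparkContext
from pyspark.serializers import PickleSerializer
from pyspark.mllib.linalg import Vectors as MLlibVectors
from pyspark.ml.linalg import Vectors as MLVectors

sc = SparkContext.getOrCreate()
ser = PickleSerializer()

# Old path: spark.mllib vectors go through org.apache.spark.mllib.api.python.SerDe.
old_vec = sc._jvm.SerDe.loads(bytearray(ser.dumps(MLlibVectors.dense(1.0, 2.0))))
# New path: spark.ml vectors go through org.apache.spark.ml.python.MLSerDe.
new_vec = sc._jvm.MLSerDe.loads(bytearray(ser.dumps(MLVectors.dense(1.0, 2.0))))
```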

http://git-wip-us.apache.org/repos/asf/spark/blob/1a57bf0f/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala
new file mode 100644
index 0000000..5eaef9a
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.python
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, Vectors}
+
+class MLSerDeSuite extends SparkFunSuite {
+
+  MLSerDe.initialize()
+
+  test("pickle vector") {
+    val vectors = Seq(
+      Vectors.dense(Array.empty[Double]),
+      Vectors.dense(0.0),
+      Vectors.dense(0.0, -2.0),
+      Vectors.sparse(0, Array.empty[Int], Array.empty[Double]),
+      Vectors.sparse(1, Array.empty[Int], Array.empty[Double]),
+      Vectors.sparse(2, Array(1), Array(-2.0)))
+    vectors.foreach { v =>
+      val u = MLSerDe.loads(MLSerDe.dumps(v))
+      assert(u.getClass === v.getClass)
+      assert(u === v)
+    }
+  }
+
+  test("pickle double") {
+    for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue, Double.NaN)) {
+      val deser = MLSerDe.loads(MLSerDe.dumps(x.asInstanceOf[AnyRef])).asInstanceOf[Double]
+      // We use `equals` here for comparison because we cannot use `==` for NaN
+      assert(x.equals(deser))
+    }
+  }
+
+  test("pickle matrix") {
+    val values = Array[Double](0, 1.2, 3, 4.56, 7, 8)
+    val matrix = Matrices.dense(2, 3, values)
+    val nm = MLSerDe.loads(MLSerDe.dumps(matrix)).asInstanceOf[DenseMatrix]
+    assert(matrix === nm)
+
+    // Test conversion for empty matrix
+    val empty = Array[Double]()
+    val emptyMatrix = Matrices.dense(0, 0, empty)
+    val ne = MLSerDe.loads(MLSerDe.dumps(emptyMatrix)).asInstanceOf[DenseMatrix]
+    assert(emptyMatrix == ne)
+
+    val sm = new SparseMatrix(3, 2, Array(0, 1, 3), Array(1, 0, 2), Array(0.9, 1.2, 3.4))
+    val nsm = MLSerDe.loads(MLSerDe.dumps(sm)).asInstanceOf[SparseMatrix]
+    assert(sm.toArray === nsm.toArray)
+
+    val smt = new SparseMatrix(
+      3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9),
+      isTransposed = true)
+    val nsmt = MLSerDe.loads(MLSerDe.dumps(smt)).asInstanceOf[SparseMatrix]
+    assert(smt.toArray === nsmt.toArray)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1a57bf0f/python/pyspark/java_gateway.py
----------------------------------------------------------------------
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index cd4c55f..527ca82 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -116,6 +116,7 @@ def launch_gateway():
     java_import(gateway.jvm, "org.apache.spark.SparkConf")
     java_import(gateway.jvm, "org.apache.spark.api.java.*")
     java_import(gateway.jvm, "org.apache.spark.api.python.*")
+    java_import(gateway.jvm, "org.apache.spark.ml.python.*")
     java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
     # TODO(davies): move into sql
     java_import(gateway.jvm, "org.apache.spark.sql.*")

http://git-wip-us.apache.org/repos/asf/spark/blob/1a57bf0f/python/pyspark/ml/base.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index a7a58e1..339e5d6 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -19,7 +19,7 @@ from abc import ABCMeta, abstractmethod
 
 from pyspark import since
 from pyspark.ml.param import Params
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
 
 
 @inherit_doc

http://git-wip-us.apache.org/repos/asf/spark/blob/1a57bf0f/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 77badeb..121b926 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -26,7 +26,7 @@ from pyspark.ml.regression import (
 from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
 from pyspark.ml.wrapper import JavaWrapper
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
 from pyspark.sql import DataFrame
 from pyspark.sql.functions import udf, when
 from pyspark.sql.types import ArrayType, DoubleType

http://git-wip-us.apache.org/repos/asf/spark/blob/1a57bf0f/python/pyspark/ml/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 92df19e..75d9a0e 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -19,7 +19,7 @@ from pyspark import since, keyword_only
 from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaEstimator, JavaModel
 from pyspark.ml.param.shared import *
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
 
 __all__ = ['BisectingKMeans', 'BisectingKMeansModel',
            'KMeans', 'KMeansModel',

http://git-wip-us.apache.org/repos/asf/spark/blob/1a57bf0f/python/pyspark/ml/common.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py
new file mode 100644
index 0000000..256e91e
--- /dev/null
+++ b/python/pyspark/ml/common.py
@@ -0,0 +1,137 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+if sys.version >= '3':
+    long = int
+    unicode = str
+
+import py4j.protocol
+from py4j.protocol import Py4JJavaError
+from py4j.java_gateway import JavaObject
+from py4j.java_collections import ListConverter, JavaArray, JavaList
+
+from pyspark import RDD, SparkContext
+from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+from pyspark.sql import DataFrame, SQLContext
+
+# Hack for support float('inf') in Py4j
+_old_smart_decode = py4j.protocol.smart_decode
+
+_float_str_mapping = {
+    'nan': 'NaN',
+    'inf': 'Infinity',
+    '-inf': '-Infinity',
+}
+
+
+def _new_smart_decode(obj):
+    if isinstance(obj, float):
+        s = str(obj)
+        return _float_str_mapping.get(s, s)
+    return _old_smart_decode(obj)
+
+py4j.protocol.smart_decode = _new_smart_decode
+
+
+_picklable_classes = [
+    'SparseVector',
+    'DenseVector',
+    'DenseMatrix',
+]
+
+
+# this will call the ML version of pythonToJava()
+def _to_java_object_rdd(rdd):
+    """ Return an JavaRDD of Object by unpickling
+
+    It will convert each Python object into Java object by Pyrolite, whenever 
the
+    RDD is serialized in batch or not.
+    """
+    rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
+    return rdd.ctx._jvm.MLSerDe.pythonToJava(rdd._jrdd, True)
+
+
+def _py2java(sc, obj):
+    """ Convert Python object into Java """
+    if isinstance(obj, RDD):
+        obj = _to_java_object_rdd(obj)
+    elif isinstance(obj, DataFrame):
+        obj = obj._jdf
+    elif isinstance(obj, SparkContext):
+        obj = obj._jsc
+    elif isinstance(obj, list):
+        obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client)
+    elif isinstance(obj, JavaObject):
+        pass
+    elif isinstance(obj, (int, long, float, bool, bytes, unicode)):
+        pass
+    else:
+        data = bytearray(PickleSerializer().dumps(obj))
+        obj = sc._jvm.MLSerDe.loads(data)
+    return obj
+
+
+def _java2py(sc, r, encoding="bytes"):
+    if isinstance(r, JavaObject):
+        clsName = r.getClass().getSimpleName()
+        # convert RDD into JavaRDD
+        if clsName != 'JavaRDD' and clsName.endswith("RDD"):
+            r = r.toJavaRDD()
+            clsName = 'JavaRDD'
+
+        if clsName == 'JavaRDD':
+            jrdd = sc._jvm.MLSerDe.javaToPython(r)
+            return RDD(jrdd, sc)
+
+        if clsName == 'Dataset':
+            return DataFrame(r, SQLContext.getOrCreate(sc))
+
+        if clsName in _picklable_classes:
+            r = sc._jvm.MLSerDe.dumps(r)
+        elif isinstance(r, (JavaArray, JavaList)):
+            try:
+                r = sc._jvm.MLSerDe.dumps(r)
+            except Py4JJavaError:
+                pass  # not pickable
+
+    if isinstance(r, (bytearray, bytes)):
+        r = PickleSerializer().loads(bytes(r), encoding=encoding)
+    return r
+
+
+def callJavaFunc(sc, func, *args):
+    """ Call Java Function """
+    args = [_py2java(sc, a) for a in args]
+    return _java2py(sc, func(*args))
+
+
+def inherit_doc(cls):
+    """
+    A decorator that makes a class inherit documentation from its parents.
+    """
+    for name, func in vars(cls).items():
+        # only inherit docstring for public functions
+        if name.startswith("_"):
+            continue
+        if not func.__doc__:
+            for parent in cls.__bases__:
+                parent_func = getattr(parent, name, None)
+                if parent_func and getattr(parent_func, "__doc__", None):
+                    func.__doc__ = parent_func.__doc__
+                    break
+    return cls
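
The helpers above mirror those in `pyspark.mllib.common` but route through `MLSerDe`. A minimal usage sketch (assumes a running SparkContext):

```python
# Sketch: round-trip a pyspark.ml value through the JVM via the new pyspark.ml.common helpers.
from pyspark import SparkContext
from pyspark.ml.common import _java2py, _py2java
from pyspark.ml.linalg import Vectors

sc = SparkContext.getOrCreate()

v = Vectors.dense(1.0, 2.0, 3.0)
jv = _py2java(sc, v)     # pickled in Python, rebuilt on the JVM by MLSerDe.loads
back = _java2py(sc, jv)  # DenseVector is in _picklable_classes, so it returns via MLSerDe.dumps
assert back == v
```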

http://git-wip-us.apache.org/repos/asf/spark/blob/1a57bf0f/python/pyspark/ml/evaluation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index cd071f1..1fe8772 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -21,7 +21,7 @@ from pyspark import since, keyword_only
 from pyspark.ml.wrapper import JavaParams
 from pyspark.ml.param import Param, Params, TypeConverters
 from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
 
 __all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator',
            'MulticlassClassificationEvaluator']

http://git-wip-us.apache.org/repos/asf/spark/blob/1a57bf0f/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index ca77ac3..a28764a 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -25,7 +25,7 @@ from pyspark.ml.linalg import _convert_to_vector
 from pyspark.ml.param.shared import *
 from pyspark.ml.util import JavaMLReadable, JavaMLWritable
 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
 
 __all__ = ['Binarizer',
            'Bucketizer',

http://git-wip-us.apache.org/repos/asf/spark/blob/1a57bf0f/python/pyspark/ml/pipeline.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 0777527..a48f4bb 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -25,7 +25,7 @@ from pyspark.ml import Estimator, Model, Transformer
 from pyspark.ml.param import Param, Params
 from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable
 from pyspark.ml.wrapper import JavaParams
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
 
 
 @inherit_doc

http://git-wip-us.apache.org/repos/asf/spark/blob/1a57bf0f/python/pyspark/ml/recommendation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index 1778bfe..0a70967 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -19,7 +19,7 @@ from pyspark import since, keyword_only
 from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaEstimator, JavaModel
 from pyspark.ml.param.shared import *
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
 
 
 __all__ = ['ALS', 'ALSModel']

http://git-wip-us.apache.org/repos/asf/spark/blob/1a57bf0f/python/pyspark/ml/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 7c79ab7..db31993 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -21,7 +21,7 @@ from pyspark import since, keyword_only
 from pyspark.ml.param.shared import *
 from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
 from pyspark.sql import DataFrame
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1a57bf0f/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 4358175..981ed9d 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -61,7 +61,7 @@ from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, \
     GeneralizedLinearRegression
 from pyspark.ml.tuning import *
 from pyspark.ml.wrapper import JavaParams
-from pyspark.mllib.common import _java2py
+from pyspark.ml.common import _java2py
 from pyspark.serializers import PickleSerializer
 from pyspark.sql import DataFrame, Row, SparkSession
 from pyspark.sql.functions import rand
@@ -1195,12 +1195,12 @@ class VectorTests(MLlibTestCase):
 
     def _test_serialize(self, v):
         self.assertEqual(v, ser.loads(ser.dumps(v)))
-        jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v)))
-        nv = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvec)))
+        jvec = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(v)))
+        nv = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvec)))
         self.assertEqual(v, nv)
         vs = [v] * 100
-        jvecs = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(vs)))
-        nvs = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvecs)))
+        jvecs = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(vs)))
+        nvs = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvecs)))
         self.assertEqual(vs, nvs)
 
     def test_serialize(self):

http://git-wip-us.apache.org/repos/asf/spark/blob/1a57bf0f/python/pyspark/ml/tuning.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index fe87b6c..f857c5e 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -25,7 +25,7 @@ from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasSeed
 from pyspark.ml.wrapper import JavaParams
 from pyspark.sql.functions import rand
-from pyspark.mllib.common import inherit_doc, _py2java
+from pyspark.ml.common import inherit_doc, _py2java
 
 __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
            'TrainValidationSplitModel']

http://git-wip-us.apache.org/repos/asf/spark/blob/1a57bf0f/python/pyspark/ml/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 9d28823..4a31a29 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -23,7 +23,7 @@ if sys.version > '3':
     unicode = str
 
 from pyspark import SparkContext, since
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
 
 
 def _jvm():

http://git-wip-us.apache.org/repos/asf/spark/blob/1a57bf0f/python/pyspark/ml/wrapper.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index fef0040..25c44b7 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -22,7 +22,7 @@ from pyspark.sql import DataFrame
 from pyspark.ml import Estimator, Transformer, Model
 from pyspark.ml.param import Params
 from pyspark.ml.util import _jvm
-from pyspark.mllib.common import inherit_doc, _java2py, _py2java
+from pyspark.ml.common import inherit_doc, _java2py, _py2java
 
 
 class JavaWrapper(object):

