GitHub user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/2378#discussion_r17697397
--- Diff:
mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala ---
@@ -476,259 +436,167 @@ class PythonMLLibAPI extends Serializable {
numRows: Long,
numCols: Int,
numPartitions: java.lang.Integer,
- seed: java.lang.Long): JavaRDD[Array[Byte]] = {
+ seed: java.lang.Long): JavaRDD[Vector] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
- RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts,
s).map(SerDe.serializeDoubleVector)
+ RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s)
}
}
/**
- * :: DeveloperApi ::
- * MultivariateStatisticalSummary with Vector fields serialized.
+ * SerDe utility functions for PythonMLLibAPI.
*/
-@DeveloperApi
-class MultivariateStatisticalSummarySerialized(val summary:
MultivariateStatisticalSummary)
- extends Serializable {
+private[spark] object SerDe extends Serializable {
- def mean: Array[Byte] = SerDe.serializeDoubleVector(summary.mean)
+ val PYSPARK_PACKAGE = "pyspark.mllib"
- def variance: Array[Byte] = SerDe.serializeDoubleVector(summary.variance)
+ /**
+ * Base class used for pickle
+ */
+ private[python] abstract class BasePickler[T: ClassTag]
+ extends IObjectPickler with IObjectConstructor {
+
+ private val cls = implicitly[ClassTag[T]].runtimeClass
+ private val module = PYSPARK_PACKAGE + "." + cls.getName.split('.')(4)
+ private val name = cls.getSimpleName
+
+ // register this to Pickler and Unpickler
+ def register(): Unit = {
+ Pickler.registerCustomPickler(this.getClass, this)
+ Pickler.registerCustomPickler(cls, this)
+ Unpickler.registerConstructor(module, name, this)
+ }
- def count: Long = summary.count
+ def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+ if (obj == this) {
+ out.write(Opcodes.GLOBAL)
+ out.write((module + "\n" + name + "\n").getBytes())
+ } else {
+ pickler.save(this) // it will be memorized by Pickler
+ saveState(obj, out, pickler)
+ out.write(Opcodes.REDUCE)
+ }
+ }
+
+ private[python] def saveObjects(out: OutputStream, pickler: Pickler,
+ objects: Any*) = {
+ if (objects.length == 0 || objects.length > 3) {
+ out.write(Opcodes.MARK)
+ }
+ objects.foreach(pickler.save(_))
+ val code = objects.length match {
+ case 1 => Opcodes.TUPLE1
+ case 2 => Opcodes.TUPLE2
+ case 3 => Opcodes.TUPLE3
+ case _ => Opcodes.TUPLE
+ }
+ out.write(code)
+ }
- def numNonzeros: Array[Byte] =
SerDe.serializeDoubleVector(summary.numNonzeros)
+ private[python] def saveState(obj: Object, out: OutputStream, pickler:
Pickler)
+ }
- def max: Array[Byte] = SerDe.serializeDoubleVector(summary.max)
+ // Pickler for DenseVector
+ private[python] class DenseVectorPickler
+ extends BasePickler[DenseVector] {
- def min: Array[Byte] = SerDe.serializeDoubleVector(summary.min)
-}
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ val vector: DenseVector = obj.asInstanceOf[DenseVector]
+ saveObjects(out, pickler, vector.toArray)
+ }
-/**
- * SerDe utility functions for PythonMLLibAPI.
- */
-private[spark] object SerDe extends Serializable {
- private val DENSE_VECTOR_MAGIC: Byte = 1
- private val SPARSE_VECTOR_MAGIC: Byte = 2
- private val DENSE_MATRIX_MAGIC: Byte = 3
- private val LABELED_POINT_MAGIC: Byte = 4
-
- private[python] def deserializeDoubleVector(bytes: Array[Byte], offset:
Int = 0): Vector = {
- require(bytes.length - offset >= 5, "Byte array too short")
- val magic = bytes(offset)
- if (magic == DENSE_VECTOR_MAGIC) {
- deserializeDenseVector(bytes, offset)
- } else if (magic == SPARSE_VECTOR_MAGIC) {
- deserializeSparseVector(bytes, offset)
- } else {
- throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+ def construct(args: Array[Object]) :Object = {
+ require(args.length == 1)
+ new DenseVector(args(0).asInstanceOf[Array[Double]])
}
}
- private[python] def deserializeDouble(bytes: Array[Byte], offset: Int =
0): Double = {
- require(bytes.length - offset == 8, "Wrong size byte array for Double")
- val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
- bb.order(ByteOrder.nativeOrder())
- bb.getDouble
- }
-
- private[python] def deserializeDenseVector(bytes: Array[Byte], offset:
Int = 0): Vector = {
- val packetLength = bytes.length - offset
- require(packetLength >= 5, "Byte array too short")
- val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
- bb.order(ByteOrder.nativeOrder())
- val magic = bb.get()
- require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic)
- val length = bb.getInt()
- require (packetLength == 5 + 8 * length, "Invalid packet length: " +
packetLength)
- val db = bb.asDoubleBuffer()
- val ans = new Array[Double](length.toInt)
- db.get(ans)
- Vectors.dense(ans)
- }
-
- private[python] def deserializeSparseVector(bytes: Array[Byte], offset:
Int = 0): Vector = {
- val packetLength = bytes.length - offset
- require(packetLength >= 9, "Byte array too short")
- val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
- bb.order(ByteOrder.nativeOrder())
- val magic = bb.get()
- require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic)
- val size = bb.getInt()
- val nonZeros = bb.getInt()
- require (packetLength == 9 + 12 * nonZeros, "Invalid packet length: "
+ packetLength)
- val ib = bb.asIntBuffer()
- val indices = new Array[Int](nonZeros)
- ib.get(indices)
- bb.position(bb.position() + 4 * nonZeros)
- val db = bb.asDoubleBuffer()
- val values = new Array[Double](nonZeros)
- db.get(values)
- Vectors.sparse(size, indices, values)
- }
+ // Pickler for DenseMatrix
+ private[python] class DenseMatrixPickler
+ extends BasePickler[DenseMatrix] {
- /**
- * Returns an 8-byte array for the input Double.
- *
- * Note: we currently do not use a magic byte for double for storage
efficiency.
- * This should be reconsidered when we add Ser/De for other 8-byte types
(e.g. Long), for safety.
- * The corresponding deserializer, deserializeDouble, needs to be
modified as well if the
- * serialization scheme changes.
- */
- private[python] def serializeDouble(double: Double): Array[Byte] = {
- val bytes = new Array[Byte](8)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.putDouble(double)
- bytes
- }
-
- private[python] def serializeDenseVector(doubles: Array[Double]):
Array[Byte] = {
- val len = doubles.length
- val bytes = new Array[Byte](5 + 8 * len)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.put(DENSE_VECTOR_MAGIC)
- bb.putInt(len)
- val db = bb.asDoubleBuffer()
- db.put(doubles)
- bytes
- }
-
- private[python] def serializeSparseVector(vector: SparseVector):
Array[Byte] = {
- val nonZeros = vector.indices.length
- val bytes = new Array[Byte](9 + 12 * nonZeros)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.put(SPARSE_VECTOR_MAGIC)
- bb.putInt(vector.size)
- bb.putInt(nonZeros)
- val ib = bb.asIntBuffer()
- ib.put(vector.indices)
- bb.position(bb.position() + 4 * nonZeros)
- val db = bb.asDoubleBuffer()
- db.put(vector.values)
- bytes
- }
-
- private[python] def serializeDoubleVector(vector: Vector): Array[Byte] =
vector match {
- case s: SparseVector =>
- serializeSparseVector(s)
- case _ =>
- serializeDenseVector(vector.toArray)
- }
-
- private[python] def deserializeDoubleMatrix(bytes: Array[Byte]):
Array[Array[Double]] = {
- val packetLength = bytes.length
- if (packetLength < 9) {
- throw new IllegalArgumentException("Byte array too short.")
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ val m: DenseMatrix = obj.asInstanceOf[DenseMatrix]
+ saveObjects(out, pickler, m.numRows, m.numCols, m.values)
}
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- val magic = bb.get()
- if (magic != DENSE_MATRIX_MAGIC) {
- throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+
+ def construct(args: Array[Object]) :Object = {
+ require(args.length == 3)
+ new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int],
+ args(2).asInstanceOf[Array[Double]])
}
- val rows = bb.getInt()
- val cols = bb.getInt()
- if (packetLength != 9 + 8 * rows * cols) {
- throw new IllegalArgumentException("Size " + rows + "x" + cols + "
is wrong.")
+ }
+
+ // Pickler for SparseVector
+ private[python] class SparseVectorPickler
+ extends BasePickler[SparseVector] {
+
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ val v: SparseVector = obj.asInstanceOf[SparseVector]
+ saveObjects(out, pickler, v.size, v.indices, v.values)
}
- val db = bb.asDoubleBuffer()
- val ans = new Array[Array[Double]](rows.toInt)
- for (i <- 0 until rows.toInt) {
- ans(i) = new Array[Double](cols.toInt)
- db.get(ans(i))
+
+ def construct(args: Array[Object]) :Object = {
+ require(args.length == 3)
+ new SparseVector(args(0).asInstanceOf[Int],
args(1).asInstanceOf[Array[Int]],
+ args(2).asInstanceOf[Array[Double]])
}
- ans
}
- private[python] def serializeDoubleMatrix(doubles:
Array[Array[Double]]): Array[Byte] = {
- val rows = doubles.length
- var cols = 0
- if (rows > 0) {
- cols = doubles(0).length
+ // Pickler for LabeledPoint
+ private[python] class LabeledPointPickler
+ extends BasePickler[LabeledPoint] {
+
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ val point: LabeledPoint = obj.asInstanceOf[LabeledPoint]
+ saveObjects(out, pickler, point.label, point.features)
}
- val bytes = new Array[Byte](9 + 8 * rows * cols)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.put(DENSE_MATRIX_MAGIC)
- bb.putInt(rows)
- bb.putInt(cols)
- val db = bb.asDoubleBuffer()
- for (i <- 0 until rows) {
- db.put(doubles(i))
+
+ def construct(args: Array[Object]) :Object = {
+ if (args.length != 2) {
+ throw new PickleException("should be 2")
--- End diff --
Use a consistent exception type. (The other construct() methods in this diff
validate args.length with require(), whereas this one throws PickleException.)
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]