GitHub user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/2378#discussion_r17760498
--- Diff:
mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala ---
@@ -476,259 +436,167 @@ class PythonMLLibAPI extends Serializable {
numRows: Long,
numCols: Int,
numPartitions: java.lang.Integer,
- seed: java.lang.Long): JavaRDD[Array[Byte]] = {
+ seed: java.lang.Long): JavaRDD[Vector] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
- RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts,
s).map(SerDe.serializeDoubleVector)
+ RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s)
}
}
/**
- * :: DeveloperApi ::
- * MultivariateStatisticalSummary with Vector fields serialized.
+ * SerDe utility functions for PythonMLLibAPI.
*/
-@DeveloperApi
-class MultivariateStatisticalSummarySerialized(val summary:
MultivariateStatisticalSummary)
- extends Serializable {
+private[spark] object SerDe extends Serializable {
- def mean: Array[Byte] = SerDe.serializeDoubleVector(summary.mean)
+ val PYSPARK_PACKAGE = "pyspark.mllib"
- def variance: Array[Byte] = SerDe.serializeDoubleVector(summary.variance)
+ /**
+ * Base class used for pickle
+ */
+ private[python] abstract class BasePickler[T: ClassTag]
+ extends IObjectPickler with IObjectConstructor {
+
+ private val cls = implicitly[ClassTag[T]].runtimeClass
+ private val module = PYSPARK_PACKAGE + "." + cls.getName.split('.')(4)
+ private val name = cls.getSimpleName
+
+ // register this to Pickler and Unpickler
+ def register(): Unit = {
+ Pickler.registerCustomPickler(this.getClass, this)
+ Pickler.registerCustomPickler(cls, this)
+ Unpickler.registerConstructor(module, name, this)
+ }
- def count: Long = summary.count
+ def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+ if (obj == this) {
+ out.write(Opcodes.GLOBAL)
+ out.write((module + "\n" + name + "\n").getBytes())
+ } else {
+ pickler.save(this) // it will be memorized by Pickler
+ saveState(obj, out, pickler)
+ out.write(Opcodes.REDUCE)
+ }
+ }
+
+ private[python] def saveObjects(out: OutputStream, pickler: Pickler,
+ objects: Any*) = {
+ if (objects.length == 0 || objects.length > 3) {
+ out.write(Opcodes.MARK)
+ }
+ objects.foreach(pickler.save(_))
+ val code = objects.length match {
+ case 1 => Opcodes.TUPLE1
+ case 2 => Opcodes.TUPLE2
+ case 3 => Opcodes.TUPLE3
+ case _ => Opcodes.TUPLE
+ }
+ out.write(code)
+ }
- def numNonzeros: Array[Byte] =
SerDe.serializeDoubleVector(summary.numNonzeros)
+ private[python] def saveState(obj: Object, out: OutputStream, pickler:
Pickler)
+ }
- def max: Array[Byte] = SerDe.serializeDoubleVector(summary.max)
+ // Pickler for DenseVector
+ private[python] class DenseVectorPickler
+ extends BasePickler[DenseVector] {
- def min: Array[Byte] = SerDe.serializeDoubleVector(summary.min)
-}
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ val vector: DenseVector = obj.asInstanceOf[DenseVector]
+ saveObjects(out, pickler, vector.toArray)
+ }
-/**
- * SerDe utility functions for PythonMLLibAPI.
- */
-private[spark] object SerDe extends Serializable {
- private val DENSE_VECTOR_MAGIC: Byte = 1
- private val SPARSE_VECTOR_MAGIC: Byte = 2
- private val DENSE_MATRIX_MAGIC: Byte = 3
- private val LABELED_POINT_MAGIC: Byte = 4
-
- private[python] def deserializeDoubleVector(bytes: Array[Byte], offset:
Int = 0): Vector = {
- require(bytes.length - offset >= 5, "Byte array too short")
- val magic = bytes(offset)
- if (magic == DENSE_VECTOR_MAGIC) {
- deserializeDenseVector(bytes, offset)
- } else if (magic == SPARSE_VECTOR_MAGIC) {
- deserializeSparseVector(bytes, offset)
- } else {
- throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+ def construct(args: Array[Object]) :Object = {
+ require(args.length == 1)
+ new DenseVector(args(0).asInstanceOf[Array[Double]])
}
}
- private[python] def deserializeDouble(bytes: Array[Byte], offset: Int =
0): Double = {
- require(bytes.length - offset == 8, "Wrong size byte array for Double")
- val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
- bb.order(ByteOrder.nativeOrder())
- bb.getDouble
- }
-
- private[python] def deserializeDenseVector(bytes: Array[Byte], offset:
Int = 0): Vector = {
- val packetLength = bytes.length - offset
- require(packetLength >= 5, "Byte array too short")
- val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
- bb.order(ByteOrder.nativeOrder())
- val magic = bb.get()
- require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic)
- val length = bb.getInt()
- require (packetLength == 5 + 8 * length, "Invalid packet length: " +
packetLength)
- val db = bb.asDoubleBuffer()
- val ans = new Array[Double](length.toInt)
- db.get(ans)
- Vectors.dense(ans)
- }
-
- private[python] def deserializeSparseVector(bytes: Array[Byte], offset:
Int = 0): Vector = {
- val packetLength = bytes.length - offset
- require(packetLength >= 9, "Byte array too short")
- val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
- bb.order(ByteOrder.nativeOrder())
- val magic = bb.get()
- require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic)
- val size = bb.getInt()
- val nonZeros = bb.getInt()
- require (packetLength == 9 + 12 * nonZeros, "Invalid packet length: "
+ packetLength)
- val ib = bb.asIntBuffer()
- val indices = new Array[Int](nonZeros)
- ib.get(indices)
- bb.position(bb.position() + 4 * nonZeros)
- val db = bb.asDoubleBuffer()
- val values = new Array[Double](nonZeros)
- db.get(values)
- Vectors.sparse(size, indices, values)
- }
+ // Pickler for DenseMatrix
+ private[python] class DenseMatrixPickler
+ extends BasePickler[DenseMatrix] {
- /**
- * Returns an 8-byte array for the input Double.
- *
- * Note: we currently do not use a magic byte for double for storage
efficiency.
- * This should be reconsidered when we add Ser/De for other 8-byte types
(e.g. Long), for safety.
- * The corresponding deserializer, deserializeDouble, needs to be
modified as well if the
- * serialization scheme changes.
- */
- private[python] def serializeDouble(double: Double): Array[Byte] = {
- val bytes = new Array[Byte](8)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.putDouble(double)
- bytes
- }
-
- private[python] def serializeDenseVector(doubles: Array[Double]):
Array[Byte] = {
- val len = doubles.length
- val bytes = new Array[Byte](5 + 8 * len)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.put(DENSE_VECTOR_MAGIC)
- bb.putInt(len)
- val db = bb.asDoubleBuffer()
- db.put(doubles)
- bytes
- }
-
- private[python] def serializeSparseVector(vector: SparseVector):
Array[Byte] = {
- val nonZeros = vector.indices.length
- val bytes = new Array[Byte](9 + 12 * nonZeros)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.put(SPARSE_VECTOR_MAGIC)
- bb.putInt(vector.size)
- bb.putInt(nonZeros)
- val ib = bb.asIntBuffer()
- ib.put(vector.indices)
- bb.position(bb.position() + 4 * nonZeros)
- val db = bb.asDoubleBuffer()
- db.put(vector.values)
- bytes
- }
-
- private[python] def serializeDoubleVector(vector: Vector): Array[Byte] =
vector match {
- case s: SparseVector =>
- serializeSparseVector(s)
- case _ =>
- serializeDenseVector(vector.toArray)
- }
-
- private[python] def deserializeDoubleMatrix(bytes: Array[Byte]):
Array[Array[Double]] = {
- val packetLength = bytes.length
- if (packetLength < 9) {
- throw new IllegalArgumentException("Byte array too short.")
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ val m: DenseMatrix = obj.asInstanceOf[DenseMatrix]
+ saveObjects(out, pickler, m.numRows, m.numCols, m.values)
}
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- val magic = bb.get()
- if (magic != DENSE_MATRIX_MAGIC) {
- throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+
+ def construct(args: Array[Object]) :Object = {
+ require(args.length == 3)
+ new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int],
+ args(2).asInstanceOf[Array[Double]])
}
- val rows = bb.getInt()
- val cols = bb.getInt()
- if (packetLength != 9 + 8 * rows * cols) {
- throw new IllegalArgumentException("Size " + rows + "x" + cols + "
is wrong.")
+ }
+
+ // Pickler for SparseVector
+ private[python] class SparseVectorPickler
+ extends BasePickler[SparseVector] {
+
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ val v: SparseVector = obj.asInstanceOf[SparseVector]
+ saveObjects(out, pickler, v.size, v.indices, v.values)
}
- val db = bb.asDoubleBuffer()
- val ans = new Array[Array[Double]](rows.toInt)
- for (i <- 0 until rows.toInt) {
- ans(i) = new Array[Double](cols.toInt)
- db.get(ans(i))
+
+ def construct(args: Array[Object]) :Object = {
+ require(args.length == 3)
+ new SparseVector(args(0).asInstanceOf[Int],
args(1).asInstanceOf[Array[Int]],
+ args(2).asInstanceOf[Array[Double]])
}
- ans
}
- private[python] def serializeDoubleMatrix(doubles:
Array[Array[Double]]): Array[Byte] = {
- val rows = doubles.length
- var cols = 0
- if (rows > 0) {
- cols = doubles(0).length
+ // Pickler for LabeledPoint
+ private[python] class LabeledPointPickler
+ extends BasePickler[LabeledPoint] {
+
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ val point: LabeledPoint = obj.asInstanceOf[LabeledPoint]
+ saveObjects(out, pickler, point.label, point.features)
}
- val bytes = new Array[Byte](9 + 8 * rows * cols)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.put(DENSE_MATRIX_MAGIC)
- bb.putInt(rows)
- bb.putInt(cols)
- val db = bb.asDoubleBuffer()
- for (i <- 0 until rows) {
- db.put(doubles(i))
+
+ def construct(args: Array[Object]) :Object = {
+ if (args.length != 2) {
+ throw new PickleException("should be 2")
+ }
+ new LabeledPoint(args(0).asInstanceOf[Double],
args(1).asInstanceOf[Vector])
}
- bytes
}
- private[python] def serializeLabeledPoint(p: LabeledPoint): Array[Byte]
= {
- val fb = serializeDoubleVector(p.features)
- val bytes = new Array[Byte](1 + 8 + fb.length)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.put(LABELED_POINT_MAGIC)
- bb.putDouble(p.label)
- bb.put(fb)
- bytes
- }
+ // Pickler for Rating
+ private[python] class RatingPickler
+ extends BasePickler[Rating] {
- private[python] def deserializeLabeledPoint(bytes: Array[Byte]):
LabeledPoint = {
- require(bytes.length >= 9, "Byte array too short")
- val magic = bytes(0)
- if (magic != LABELED_POINT_MAGIC) {
- throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ val rating: Rating = obj.asInstanceOf[Rating]
+ saveObjects(out, pickler, rating.user, rating.product, rating.rating)
}
- val labelBytes = ByteBuffer.wrap(bytes, 1, 8)
- labelBytes.order(ByteOrder.nativeOrder())
- val label = labelBytes.asDoubleBuffer().get(0)
- LabeledPoint(label, deserializeDoubleVector(bytes, 9))
- }
- // Reformat a Matrix into Array[Array[Double]] for serialization
- private[python] def to2dArray(matrix: Matrix): Array[Array[Double]] = {
- val values = matrix.toArray
- Array.tabulate(matrix.numRows, matrix.numCols)((i, j) => values(i + j
* matrix.numRows))
+ def construct(args: Array[Object]) :Object = {
+ if (args.length != 3) {
+ throw new PickleException("should be 3")
+ }
+ new Rating(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int],
+ args(2).asInstanceOf[Double])
+ }
}
+ def initialize(): Unit = {
+ new DenseVectorPickler().register()
+ new DenseMatrixPickler().register()
+ new SparseVectorPickler().register()
+ new LabeledPointPickler().register()
+ new RatingPickler().register()
+ }
- /** Unpack a Rating object from an array of bytes */
- private[python] def unpackRating(ratingBytes: Array[Byte]): Rating = {
- val bb = ByteBuffer.wrap(ratingBytes)
- bb.order(ByteOrder.nativeOrder())
- val user = bb.getInt()
- val product = bb.getInt()
- val rating = bb.getDouble()
- new Rating(user, product, rating)
+ def dumps(obj: AnyRef): Array[Byte] = {
+ new Pickler().dumps(obj)
}
- /** Unpack a tuple of Ints from an array of bytes */
- def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = {
- val bb = ByteBuffer.wrap(tupleBytes)
- bb.order(ByteOrder.nativeOrder())
- val v1 = bb.getInt()
- val v2 = bb.getInt()
- (v1, v2)
+ def loads(bytes: Array[Byte]): AnyRef = {
+ new Unpickler().loads(bytes)
}
- /**
- * Serialize a Rating object into an array of bytes.
- * It can be deserialized using RatingDeserializer().
- *
- * @param rate the Rating object to serialize
- * @return
- */
- def serializeRating(rate: Rating): Array[Byte] = {
- val len = 3
- val bytes = new Array[Byte](4 + 8 * len)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.putInt(len)
- val db = bb.asDoubleBuffer()
- db.put(rate.user.toDouble)
- db.put(rate.product.toDouble)
- db.put(rate.rating)
- bytes
+ /* convert object into Tuple */
+ def asTupleRDD(rdd: RDD[Array[Any]]): RDD[(Int, Int)] = {
--- End diff --
OK, sounds good.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org