http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSpark.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSpark.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSpark.scala index e5a2b2a..41efc27 100644 --- a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSpark.scala +++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSpark.scala @@ -33,7 +33,7 @@ import org.apache.spark.SparkContext._ /** ==Spark-specific optimizer-checkpointed DRM.== * - * @param rdd underlying rdd to wrap over. + * @param rddInput underlying rdd to wrap over. * @param _nrow number of rows; if unspecified, we will compute with an inexpensive traversal. * @param _ncol number of columns; if unspecified, we will try to guess with an inexpensive traversal. * @param _cacheStorageLevel storage level @@ -44,9 +44,9 @@ import org.apache.spark.SparkContext._ * @tparam K matrix key type (e.g. 
the keys of sequence files once persisted) */ class CheckpointedDrmSpark[K: ClassTag]( - val rdd: DrmRdd[K], - private var _nrow: Long = -1L, - private var _ncol: Int = -1, + private[sparkbindings] val rddInput: DrmRddInput[K], + private[sparkbindings] var _nrow: Long = -1L, + private[sparkbindings] var _ncol: Int = -1, private val _cacheStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY, override protected[mahout] val partitioningTag: Long = Random.nextLong(), private var _canHaveMissingRows: Boolean = false @@ -63,7 +63,7 @@ class CheckpointedDrmSpark[K: ClassTag]( private[mahout] var intFixExtra: Long = 0L private var cached: Boolean = false - override val context: DistributedContext = rdd.context + override val context: DistributedContext = rddInput.backingRdd.context /** Explicit extraction of key class Tag */ def keyClassTag: ClassTag[K] = implicitly[ClassTag[K]] @@ -78,8 +78,8 @@ class CheckpointedDrmSpark[K: ClassTag]( } def cache() = { - if (!cached) { - rdd.persist(_cacheStorageLevel) + if (!cached && _cacheStorageLevel != StorageLevel.NONE) { + rddInput.backingRdd.persist(_cacheStorageLevel) cached = true } this @@ -92,7 +92,7 @@ class CheckpointedDrmSpark[K: ClassTag]( */ def uncache(): this.type = { if (cached) { - rdd.unpersist(blocking = false) + rddInput.backingRdd.unpersist(blocking = false) cached = false } this @@ -115,7 +115,7 @@ class CheckpointedDrmSpark[K: ClassTag]( */ def collect: Matrix = { - val intRowIndices = implicitly[ClassTag[K]] == implicitly[ClassTag[Int]] + val intRowIndices = classTag[K] == ClassTag.Int val cols = ncol val rows = safeToNonNegInt(nrow) @@ -124,7 +124,7 @@ class CheckpointedDrmSpark[K: ClassTag]( // since currently spark #collect() requires Serializeable support, // we serialize DRM vectors into byte arrays on backend and restore Vector // instances on the front end: - val data = rdd.map(t => (t._1, t._2)).collect() + val data = rddInput.toDrmRdd().map(t => (t._1, t._2)).collect() val m = if 
(data.forall(_._2.isDense)) @@ -165,7 +165,7 @@ class CheckpointedDrmSpark[K: ClassTag]( else if (classOf[Writable].isAssignableFrom(ktag.runtimeClass)) (x: K) => x.asInstanceOf[Writable] else throw new IllegalArgumentException("Do not know how to convert class tag %s to Writable.".format(ktag)) - rdd.saveAsSequenceFile(path) + rddInput.toDrmRdd().saveAsSequenceFile(path) } protected def computeNRow = { @@ -173,7 +173,7 @@ class CheckpointedDrmSpark[K: ClassTag]( val intRowIndex = classTag[K] == classTag[Int] if (intRowIndex) { - val rdd = cache().rdd.asInstanceOf[DrmRdd[Int]] + val rdd = cache().rddInput.toDrmRdd().asInstanceOf[DrmRdd[Int]] // I guess it is a suitable place to compute int keys consistency test here because we know // that nrow can be computed lazily, which always happens when rdd is already available, cached, @@ -186,16 +186,21 @@ class CheckpointedDrmSpark[K: ClassTag]( intFixExtra = (maxPlus1 - rowCount) max 0L maxPlus1 } else - cache().rdd.count() + cache().rddInput.toDrmRdd().count() } - protected def computeNCol = - cache().rdd.map(_._2.length).fold(-1)(max(_, _)) + protected def computeNCol = { + rddInput.isBlockified match { + case true ⇒ rddInput.toBlockifiedDrmRdd(throw new AssertionError("not reached")) + .map(_._2.ncol).reduce(max(_, _)) + case false ⇒ cache().rddInput.toDrmRdd().map(_._2.length).fold(-1)(max(_, _)) + } + } protected def computeNNonZero = - cache().rdd.map(_._2.getNumNonZeroElements.toLong).sum().toLong + cache().rddInput.toDrmRdd().map(_._2.getNumNonZeroElements.toLong).sum().toLong /** Changes the number of rows in the DRM without actually touching the underlying data. Used to * redimension a DRM after it has been created, which implies some blank, non-existent rows. 
@@ -205,8 +210,8 @@ class CheckpointedDrmSpark[K: ClassTag]( override def newRowCardinality(n: Int): CheckpointedDrm[K] = { assert(n > -1) assert( n >= nrow) - val newCheckpointedDrm = drmWrap[K](rdd, n, ncol) - newCheckpointedDrm + new CheckpointedDrmSpark(rddInput = rddInput, _nrow = n, _ncol = _ncol, _cacheStorageLevel = _cacheStorageLevel, + partitioningTag = partitioningTag, _canHaveMissingRows = _canHaveMissingRows) } }
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSparkOps.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSparkOps.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSparkOps.scala index 7cf6bd6..abcfc64 100644 --- a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSparkOps.scala +++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSparkOps.scala @@ -11,6 +11,6 @@ class CheckpointedDrmSparkOps[K: ClassTag](drm: CheckpointedDrm[K]) { private[sparkbindings] val sparkDrm = drm.asInstanceOf[CheckpointedDrmSpark[K]] /** Spark matrix customization exposure */ - def rdd = sparkDrm.rdd + def rdd = sparkDrm.rddInput.toDrmRdd() } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/DrmRddInput.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/DrmRddInput.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/DrmRddInput.scala index b72818c..d9dbada 100644 --- a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/DrmRddInput.scala +++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/DrmRddInput.scala @@ -23,22 +23,18 @@ import org.apache.spark.storage.StorageLevel import org.apache.mahout.sparkbindings._ /** Encapsulates either DrmRdd[K] or BlockifiedDrmRdd[K] */ -class DrmRddInput[K: ClassTag]( - private val rowWiseSrc: Option[( /*ncol*/ Int, /*rdd*/ DrmRdd[K])] = None, - private val blockifiedSrc: Option[BlockifiedDrmRdd[K]] = None - ) { +class DrmRddInput[K: ClassTag](private val input: Either[DrmRdd[K], BlockifiedDrmRdd[K]]) { - assert(rowWiseSrc.isDefined || blockifiedSrc.isDefined, "Undefined input") + 
private[sparkbindings] lazy val backingRdd = input.left.getOrElse(input.right.get) - private lazy val backingRdd = rowWiseSrc.map(_._2).getOrElse(blockifiedSrc.get) + def isBlockified: Boolean = input.isRight - def isBlockified:Boolean = blockifiedSrc.isDefined + def isRowWise: Boolean = input.isLeft - def isRowWise:Boolean = rowWiseSrc.isDefined + def toDrmRdd(): DrmRdd[K] = input.left.getOrElse(deblockify(rdd = input.right.get)) - def toDrmRdd(): DrmRdd[K] = rowWiseSrc.map(_._2).getOrElse(deblockify(rdd = blockifiedSrc.get)) - - def toBlockifiedDrmRdd() = blockifiedSrc.getOrElse(blockify(rdd = rowWiseSrc.get._2, blockncol = rowWiseSrc.get._1)) + /** Use late binding for this. It may or may not be needed, depending on current config. */ + def toBlockifiedDrmRdd(ncol: ⇒ Int) = input.right.getOrElse(blockify(rdd = input.left.get, blockncol = ncol)) def sparkContext: SparkContext = backingRdd.sparkContext http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/SparkBCast.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/SparkBCast.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/SparkBCast.scala index ac36f60..0371f9b 100644 --- a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/SparkBCast.scala +++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/SparkBCast.scala @@ -22,4 +22,6 @@ import org.apache.spark.broadcast.Broadcast class SparkBCast[T](val sbcast: Broadcast[T]) extends BCast[T] with Serializable { def value: T = sbcast.value + + override def close(): Unit = sbcast.unpersist() } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala 
b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala index c04b306..0de5ff8 100644 --- a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala +++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala @@ -37,18 +37,19 @@ package object drm { private[drm] final val log = Logger.getLogger("org.apache.mahout.sparkbindings"); - private[sparkbindings] implicit def input2drmRdd[K](input: DrmRddInput[K]): DrmRdd[K] = input.toDrmRdd() + private[sparkbindings] implicit def cpDrm2DrmRddInput[K: ClassTag](cp: CheckpointedDrmSpark[K]): DrmRddInput[K] = + cp.rddInput - private[sparkbindings] implicit def input2blockifiedDrmRdd[K](input: DrmRddInput[K]): BlockifiedDrmRdd[K] = input.toBlockifiedDrmRdd() + private[sparkbindings] implicit def cpDrmGeneric2DrmRddInput[K: ClassTag](cp: CheckpointedDrm[K]): DrmRddInput[K] = + cp.asInstanceOf[CheckpointedDrmSpark[K]] + + private[sparkbindings] implicit def drmRdd2drmRddInput[K: ClassTag](rdd: DrmRdd[K]) = new DrmRddInput[K](Left(rdd)) + + private[sparkbindings] implicit def blockifiedRdd2drmRddInput[K: ClassTag](rdd: BlockifiedDrmRdd[K]) = new + DrmRddInput[K]( + Right(rdd)) - private[sparkbindings] implicit def cpDrm2DrmRddInput[K: ClassTag](cp: CheckpointedDrm[K]): DrmRddInput[K] = - new DrmRddInput(rowWiseSrc = Some(cp.ncol -> cp.rdd)) -// /** Broadcast vector (Mahout vectors are not closure-friendly, use this instead. */ -// private[sparkbindings] def drmBroadcast(x: Vector)(implicit sc: SparkContext): Broadcast[Vector] = sc.broadcast(x) -// -// /** Broadcast in-core Mahout matrix. Use this instead of closure. */ -// private[sparkbindings] def drmBroadcast(m: Matrix)(implicit sc: SparkContext): Broadcast[Matrix] = sc.broadcast(m) /** Implicit broadcast cast for Spark physical op implementations. 
*/ private[sparkbindings] implicit def bcast2val[K](bcast:Broadcast[K]):K = bcast.value @@ -74,7 +75,7 @@ package object drm { } block } else { - new SparseRowMatrix(vectors.size, blockncol, vectors) + new SparseRowMatrix(vectors.size, blockncol, vectors, true, false) } Iterator(keys -> block) @@ -101,7 +102,7 @@ package object drm { blockKeys.ensuring(blockKeys.size == block.nrow) blockKeys.view.zipWithIndex.map { case (key, idx) => - var v = block(idx, ::) // This is just a view! + val v = block(idx, ::) // This is just a view! // If a view rather than a concrete vector, clone into a concrete vector in order not to // attempt to serialize outer matrix when we save it (Although maybe most often this http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/GenericMatrixKryoSerializer.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/GenericMatrixKryoSerializer.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/GenericMatrixKryoSerializer.scala new file mode 100644 index 0000000..da58b35 --- /dev/null +++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/GenericMatrixKryoSerializer.scala @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.sparkbindings.io + + +import com.esotericsoftware.kryo.io.{Output, Input} +import com.esotericsoftware.kryo.{Kryo, Serializer} +import org.apache.log4j.Logger +import org.apache.mahout.logging._ +import org.apache.mahout.math._ +import org.apache.mahout.math.flavor.TraversingStructureEnum +import scalabindings._ +import RLikeOps._ +import collection._ +import JavaConversions._ + +object GenericMatrixKryoSerializer { + + private implicit final val log = Logger.getLogger(classOf[GenericMatrixKryoSerializer]) + +} + +/** Serializes Sparse or Dense in-core generic matrix (row-wise or column-wise backed) */ +class GenericMatrixKryoSerializer extends Serializer[Matrix] { + + import GenericMatrixKryoSerializer._ + + override def write(kryo: Kryo, output: Output, mx: Matrix): Unit = { + + debug(s"Writing mx of type ${mx.getClass.getName}") + + val structure = mx.getFlavor.getStructure + + // Write structure bit + output.writeInt(structure.ordinal(), true) + + // Write geometry + output.writeInt(mx.nrow, true) + output.writeInt(mx.ncol, true) + + // Write in most efficient traversal order (using backing vectors perhaps) + structure match { + case TraversingStructureEnum.COLWISE => writeRowWise(kryo, output, mx.t) + case TraversingStructureEnum.SPARSECOLWISE => writeSparseRowWise(kryo, output, mx.t) + case TraversingStructureEnum.SPARSEROWWISE => writeSparseRowWise(kryo, output, mx) + case TraversingStructureEnum.VECTORBACKED => writeVectorBacked(kryo, output, mx) + case _ => writeRowWise(kryo, output, mx) + } + + } + + private def writeVectorBacked(kryo: Kryo, output: Output, mx: Matrix) { + + require(mx != null) + + // At this point we are just doing some vector-backed classes individually. TODO: create + // api to obtain vector-backed matrix data. 
+ kryo.writeClass(output, mx.getClass) + mx match { + case mxD: DiagonalMatrix => kryo.writeObject(output, mxD.diagv) + case mxS: DenseSymmetricMatrix => kryo.writeObject(output, dvec(mxS.getData)) + case mxT: UpperTriangular => kryo.writeObject(output, dvec(mxT.getData)) + case _ => throw new IllegalArgumentException(s"Unsupported matrix type:${mx.getClass.getName}") + } + } + + private def readVectorBacked(kryo: Kryo, input: Input, nrow: Int, ncol: Int) = { + + // We require vector-backed matrices to have vector-parameterized constructor to construct. + val clazz = kryo.readClass(input).getType + + debug(s"Deserializing vector-backed mx of type ${clazz.getName}.") + + clazz.getConstructor(classOf[Vector]).newInstance(kryo.readObject(input, classOf[Vector])).asInstanceOf[Matrix] + } + + private def writeRowWise(kryo: Kryo, output: Output, mx: Matrix): Unit = { + for (row <- mx) kryo.writeObject(output, row) + } + + private def readRows(kryo: Kryo, input: Input, nrow: Int) = { + Array.tabulate(nrow) { _ => kryo.readObject(input, classOf[Vector])} + } + + private def readSparseRows(kryo: Kryo, input: Input) = { + + // Number of slices + val nslices = input.readInt(true) + + Array.tabulate(nslices) { _ => + input.readInt(true) -> kryo.readObject(input, classOf[Vector]) + } + } + + private def writeSparseRowWise(kryo: Kryo, output: Output, mx: Matrix): Unit = { + + val nslices = mx.numSlices() + + output.writeInt(nslices, true) + + var actualNSlices = 0; + for (row <- mx.iterateNonEmpty()) { + output.writeInt(row.index(), true) + kryo.writeObject(output, row.vector()) + actualNSlices += 1 + } + + require(nslices == actualNSlices, "Number of slices reported by Matrix.numSlices() was different from actual " + + "slice iterator size.") + } + + override def read(kryo: Kryo, input: Input, mxClass: Class[Matrix]): Matrix = { + + // Read structure hint + val structure = TraversingStructureEnum.values()(input.readInt(true)) + + // Read geometry + val nrow = 
input.readInt(true) + val ncol = input.readInt(true) + + debug(s"read matrix geometry: $nrow x $ncol.") + + structure match { + + // Sparse or dense column wise + case TraversingStructureEnum.COLWISE => + val cols = readRows(kryo, input, ncol) + + if (!cols.isEmpty && cols.head.isDense) + dense(cols).t + else { + debug("Deserializing as SparseRowMatrix.t (COLWISE).") + new SparseRowMatrix(ncol, nrow, cols, true, false).t + } + + // transposed SparseMatrix case + case TraversingStructureEnum.SPARSECOLWISE => + val cols = readSparseRows(kryo, input) + val javamap = new java.util.HashMap[Integer, Vector]((cols.size << 1) + 1) + cols.foreach { case (idx, vec) => javamap.put(idx, vec)} + + debug("Deserializing as SparseMatrix.t (SPARSECOLWISE).") + new SparseMatrix(ncol, nrow, javamap, true).t + + // Sparse Row-wise -- this will be created as a SparseMatrix. + case TraversingStructureEnum.SPARSEROWWISE => + val rows = readSparseRows(kryo, input) + val javamap = new java.util.HashMap[Integer, Vector]((rows.size << 1) + 1) + rows.foreach { case (idx, vec) => javamap.put(idx, vec)} + + debug("Deserializing as SparseMatrix (SPARSEROWWISE).") + new SparseMatrix(nrow, ncol, javamap, true) + case TraversingStructureEnum.VECTORBACKED => + + debug("Deserializing vector-backed...") + readVectorBacked(kryo, input, nrow, ncol) + + // By default, read row-wise. + case _ => + val cols = readRows(kryo, input, nrow) + // this still copies a lot of stuff... 
+ if (!cols.isEmpty && cols.head.isDense) { + + debug("Deserializing as DenseMatrix.") + dense(cols) + } else { + + debug("Deserializing as SparseRowMatrix(default).") + new SparseRowMatrix(nrow, ncol, cols, true, false) + } + } + + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/MahoutKryoRegistrator.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/MahoutKryoRegistrator.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/MahoutKryoRegistrator.scala index a8a0bb4..5806ff5 100644 --- a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/MahoutKryoRegistrator.scala +++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/MahoutKryoRegistrator.scala @@ -18,22 +18,28 @@ package org.apache.mahout.sparkbindings.io import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.serializers.JavaSerializer import org.apache.mahout.math._ -import org.apache.mahout.math.indexeddataset.{BiMap, BiDictionary} import org.apache.spark.serializer.KryoRegistrator -import org.apache.mahout.sparkbindings._ -import org.apache.mahout.math.Vector.Element +import org.apache.mahout.logging._ -import scala.collection.immutable.List +object MahoutKryoRegistrator { -/** Kryo serialization registrator for Mahout */ -class MahoutKryoRegistrator extends KryoRegistrator { + private final implicit val log = getLog(this.getClass) + + def registerClasses(kryo: Kryo) = { - override def registerClasses(kryo: Kryo) = { + trace("Registering mahout classes.") + + kryo.register(classOf[SparseColumnMatrix], new UnsupportedSerializer) + kryo.addDefaultSerializer(classOf[Vector], new VectorKryoSerializer()) + kryo.addDefaultSerializer(classOf[Matrix], new GenericMatrixKryoSerializer) - kryo.addDefaultSerializer(classOf[Vector], new WritableKryoSerializer[Vector, VectorWritable]) - 
kryo.addDefaultSerializer(classOf[DenseVector], new WritableKryoSerializer[Vector, VectorWritable]) - kryo.addDefaultSerializer(classOf[Matrix], new WritableKryoSerializer[Matrix, MatrixWritable]) } + +} + +/** Kryo serialization registrator for Mahout */ +class MahoutKryoRegistrator extends KryoRegistrator { + + override def registerClasses(kryo: Kryo) = MahoutKryoRegistrator.registerClasses(kryo) } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/UnsupportedSerializer.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/UnsupportedSerializer.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/UnsupportedSerializer.scala new file mode 100644 index 0000000..66b79f4 --- /dev/null +++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/UnsupportedSerializer.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.sparkbindings.io + +import com.esotericsoftware.kryo.io.{Output, Input} +import com.esotericsoftware.kryo.{Kryo, Serializer} + +class UnsupportedSerializer extends Serializer[Any] { + + override def write(kryo: Kryo, output: Output, obj: Any): Unit = { + throw new IllegalArgumentException(s"I/O of this type(${obj.getClass.getName} is explicitly unsupported for a " + + "good reason.") + } + + override def read(kryo: Kryo, input: Input, `type`: Class[Any]): Any = ??? +} http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/VectorKryoSerializer.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/VectorKryoSerializer.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/VectorKryoSerializer.scala new file mode 100644 index 0000000..175778f --- /dev/null +++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/VectorKryoSerializer.scala @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.sparkbindings.io + +import org.apache.log4j.Logger +import org.apache.mahout.logging._ +import org.apache.mahout.math._ +import org.apache.mahout.math.scalabindings._ +import RLikeOps._ + +import com.esotericsoftware.kryo.io.{OutputChunked, Output, Input} +import com.esotericsoftware.kryo.{Kryo, Serializer} + +import collection._ +import JavaConversions._ + + +object VectorKryoSerializer { + + final val FLAG_DENSE: Int = 0x01 + final val FLAG_SEQUENTIAL: Int = 0x02 + final val FLAG_NAMED: Int = 0x04 + final val FLAG_LAX_PRECISION: Int = 0x08 + + private final implicit val log = getLog(classOf[VectorKryoSerializer]) + +} + +class VectorKryoSerializer(val laxPrecision: Boolean = false) extends Serializer[Vector] { + + import VectorKryoSerializer._ + + override def write(kryo: Kryo, output: Output, vector: Vector): Unit = { + + require(vector != null) + + trace(s"Serializing vector of ${vector.getClass.getName} class.") + + // Write length + val len = vector.length + output.writeInt(len, true) + + // Interrogate vec properties + val dense = vector.isDense + val sequential = vector.isSequentialAccess + val named = vector.isInstanceOf[NamedVector] + + var flag = 0 + + if (dense) { + flag |= FLAG_DENSE + } else if (sequential) { + flag |= FLAG_SEQUENTIAL + } + + if (vector.isInstanceOf[NamedVector]) { + flag |= FLAG_NAMED + } + + if (laxPrecision) flag |= FLAG_LAX_PRECISION + + // Write flags + output.writeByte(flag) + + // Write name if needed + if (named) output.writeString(vector.asInstanceOf[NamedVector].getName) + + dense match { + + // Dense vector. + case true => + + laxPrecision match { + case true => + for (i <- 0 until vector.length) output.writeFloat(vector(i).toFloat) + case _ => + for (i <- 0 until vector.length) output.writeDouble(vector(i)) + } + case _ => + + // Turns out getNumNonZeroElements must check every element if it is indeed non-zero. 
The + // iterateNonZeros() on the other hand doesn't do that, so that's all inconsistent right + // now. so we'll just auto-terminate. + val iter = vector.nonZeroes.toIterator.filter(_.get() != 0.0) + + sequential match { + + // Delta encoding + case true => + + var idx = 0 + laxPrecision match { + case true => + while (iter.hasNext) { + val el = iter.next() + output.writeFloat(el.toFloat) + output.writeInt(el.index() - idx, true) + idx = el.index + } + // Terminate delta encoding. + output.writeFloat(0.0.toFloat) + case _ => + while (iter.hasNext) { + val el = iter.next() + output.writeDouble(el.get()) + output.writeInt(el.index() - idx, true) + idx = el.index + } + // Terminate delta encoding. + output.writeDouble(0.0) + } + + // Random access. + case _ => + + laxPrecision match { + + case true => + iter.foreach { el => + output.writeFloat(el.get().toFloat) + output.writeInt(el.index(), true) + } + // Terminate random access with 0.0 value. + output.writeFloat(0.0.toFloat) + case _ => + iter.foreach { el => + output.writeDouble(el.get()) + output.writeInt(el.index(), true) + } + // Terminate random access with 0.0 value. + output.writeDouble(0.0) + } + + } + + } + } + + override def read(kryo: Kryo, input: Input, vecClass: Class[Vector]): Vector = { + + val len = input.readInt(true) + val flags = input.readByte().toInt + val name = if ((flags & FLAG_NAMED) != 0) Some(input.readString()) else None + + val vec: Vector = flags match { + + // Dense + case _: Int if ((flags & FLAG_DENSE) != 0) => + + trace(s"Deserializing dense vector.") + + if ((flags & FLAG_LAX_PRECISION) != 0) { + new DenseVector(len) := { _ => input.readFloat()} + } else { + new DenseVector(len) := { _ => input.readDouble()} + } + + // Sparse case. + case _ => + + flags match { + + // Sequential. 
+ case _: Int if ((flags & FLAG_SEQUENTIAL) != 0) => + + trace("Deserializing as sequential sparse vector.") + + val v = new SequentialAccessSparseVector(len) + var idx = 0 + var stop = false + + if ((flags & FLAG_LAX_PRECISION) != 0) { + + while (!stop) { + val value = input.readFloat() + if (value == 0.0) { + stop = true + } else { + idx += input.readInt(true) + v(idx) = value + } + } + } else { + while (!stop) { + val value = input.readDouble() + if (value == 0.0) { + stop = true + } else { + idx += input.readInt(true) + v(idx) = value + } + } + } + v + + // Random access + case _ => + + trace("Deserializing as random access vector.") + + // Read pairs until we see 0.0 value. Prone to corruption attacks obviously. + val v = new RandomAccessSparseVector(len) + var stop = false + if ((flags & FLAG_LAX_PRECISION) != 0) { + while (! stop ) { + val value = input.readFloat() + if ( value == 0.0 ) { + stop = true + } else { + val idx = input.readInt(true) + v(idx) = value + } + } + } else { + while (! 
stop ) { + val value = input.readDouble() + if (value == 0.0) { + stop = true + } else { + val idx = input.readInt(true) + v(idx) = value + } + } + } + v + } + } + + name.map{name => + + trace(s"Recovering named vector's name ${name}.") + + new NamedVector(vec, name) + } + .getOrElse(vec) + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala index 02f6b8c..330ae38 100644 --- a/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala +++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala @@ -17,27 +17,27 @@ package org.apache.mahout -import org.apache.mahout.drivers.TextDelimitedIndexedDatasetReader -import org.apache.mahout.math.indexeddataset.Schema -import org.apache.mahout.sparkbindings.indexeddataset.IndexedDatasetSpark -import org.apache.spark.{SparkConf, SparkContext} import java.io._ -import scala.collection.mutable.ArrayBuffer -import org.apache.mahout.common.IOUtils -import org.apache.log4j.Logger + +import org.apache.mahout.logging._ import org.apache.mahout.math.drm._ -import scala.reflect.ClassTag -import org.apache.mahout.sparkbindings.drm.{DrmRddInput, SparkBCast, CheckpointedDrmSparkOps, CheckpointedDrmSpark} -import org.apache.spark.rdd.RDD +import org.apache.mahout.math.{MatrixWritable, VectorWritable, Matrix, Vector} +import org.apache.mahout.sparkbindings.drm.{CheckpointedDrmSpark, CheckpointedDrmSparkOps, SparkBCast} +import org.apache.mahout.util.IOUtilsScala import org.apache.spark.broadcast.Broadcast -import org.apache.mahout.math.{VectorWritable, Vector, MatrixWritable, Matrix} -import org.apache.hadoop.io.Writable -import org.apache.spark.storage.StorageLevel +import org.apache.spark.rdd.RDD +import 
org.apache.spark.{SparkConf, SparkContext} + +import collection._ +import collection.generic.Growable +import scala.reflect.ClassTag + + /** Public api for Spark-specific operators */ package object sparkbindings { - private[sparkbindings] val log = Logger.getLogger("org.apache.mahout.sparkbindings") + private final implicit val log = getLog(`package`.getClass) /** Row-wise organized DRM rdd type */ type DrmRdd[K] = RDD[DrmTuple[K]] @@ -55,15 +55,11 @@ package object sparkbindings { * @param customJars * @return */ - def mahoutSparkContext( - masterUrl: String, - appName: String, - customJars: TraversableOnce[String] = Nil, - sparkConf: SparkConf = new SparkConf(), - addMahoutJars: Boolean = true - ): SparkDistributedContext = { + def mahoutSparkContext(masterUrl: String, appName: String, customJars: TraversableOnce[String] = Nil, + sparkConf: SparkConf = new SparkConf(), addMahoutJars: Boolean = true): + SparkDistributedContext = { - val closeables = new java.util.ArrayDeque[Closeable]() + val closeables = mutable.ListBuffer.empty[Closeable] try { @@ -84,9 +80,9 @@ package object sparkbindings { sparkConf.setJars(customJars.toSeq) } - sparkConf.setAppName(appName).setMaster(masterUrl) - .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .set("spark.kryo.registrator", "org.apache.mahout.sparkbindings.io.MahoutKryoRegistrator") + sparkConf.setAppName(appName).setMaster(masterUrl).set("spark.serializer", + "org.apache.spark.serializer.KryoSerializer").set("spark.kryo.registrator", + "org.apache.mahout.sparkbindings.io.MahoutKryoRegistrator") if (System.getenv("SPARK_HOME") != null) { sparkConf.setSparkHome(System.getenv("SPARK_HOME")) @@ -95,7 +91,7 @@ package object sparkbindings { new SparkDistributedContext(new SparkContext(config = sparkConf)) } finally { - IOUtils.close(closeables) + IOUtilsScala.close(closeables) } } @@ -103,19 +99,19 @@ package object sparkbindings { implicit def sc2sdc(sc: SparkContext): SparkDistributedContext = new 
SparkDistributedContext(sc) - implicit def dc2sc(dc:DistributedContext):SparkContext = { - assert (dc.isInstanceOf[SparkDistributedContext],"distributed context must be Spark-specific.") + implicit def dc2sc(dc: DistributedContext): SparkContext = { + assert(dc.isInstanceOf[SparkDistributedContext], "distributed context must be Spark-specific.") sdc2sc(dc.asInstanceOf[SparkDistributedContext]) } /** Broadcast transforms */ - implicit def sb2bc[T](b:Broadcast[T]):BCast[T] = new SparkBCast(b) + implicit def sb2bc[T](b: Broadcast[T]): BCast[T] = new SparkBCast(b) /** Adding Spark-specific ops */ implicit def cpDrm2cpDrmSparkOps[K: ClassTag](drm: CheckpointedDrm[K]): CheckpointedDrmSparkOps[K] = new CheckpointedDrmSparkOps[K](drm) - implicit def drm2cpDrmSparkOps[K:ClassTag](drm:DrmLike[K]):CheckpointedDrmSparkOps[K] = drm:CheckpointedDrm[K] + implicit def drm2cpDrmSparkOps[K: ClassTag](drm: DrmLike[K]): CheckpointedDrmSparkOps[K] = drm: CheckpointedDrm[K] private[sparkbindings] implicit def m2w(m: Matrix): MatrixWritable = new MatrixWritable(m) @@ -123,7 +119,7 @@ package object sparkbindings { private[sparkbindings] implicit def v2w(v: Vector): VectorWritable = new VectorWritable(v) - private[sparkbindings] implicit def w2v(w:VectorWritable):Vector = w.get() + private[sparkbindings] implicit def w2v(w: VectorWritable): Vector = w.get() /** * ==Wrap existing RDD into a matrix== @@ -141,34 +137,31 @@ package object sparkbindings { * @tparam K row key type * @return wrapped DRM */ - def drmWrap[K: ClassTag]( - rdd: DrmRdd[K], - nrow: Int = -1, - ncol: Int = -1, - cacheHint: CacheHint.CacheHint = CacheHint.NONE, - canHaveMissingRows: Boolean = false - ): CheckpointedDrm[K] = - - new CheckpointedDrmSpark[K]( - rdd = rdd, - _nrow = nrow, - _ncol = ncol, - _cacheStorageLevel = SparkEngine.cacheHint2Spark(cacheHint), - _canHaveMissingRows = canHaveMissingRows - ) + def drmWrap[K: ClassTag](rdd: DrmRdd[K], nrow: Long = -1, ncol: Int = -1, cacheHint: CacheHint.CacheHint = + 
CacheHint.NONE, canHaveMissingRows: Boolean = false): CheckpointedDrm[K] = + + new CheckpointedDrmSpark[K](rddInput = rdd, _nrow = nrow, _ncol = ncol, _cacheStorageLevel = SparkEngine + .cacheHint2Spark(cacheHint), _canHaveMissingRows = canHaveMissingRows) + + + /** Another drmWrap version that takes in vertical block-partitioned input to form the matrix. */ + def drmWrapBlockified[K: ClassTag](blockifiedDrmRdd: BlockifiedDrmRdd[K], nrow: Long = -1, ncol: Int = -1, + cacheHint: CacheHint.CacheHint = CacheHint.NONE, + canHaveMissingRows: Boolean = false): CheckpointedDrm[K] = + + drmWrap(drm.deblockify(blockifiedDrmRdd), nrow, ncol, cacheHint, canHaveMissingRows) private[sparkbindings] def getMahoutHome() = { var mhome = System.getenv("MAHOUT_HOME") if (mhome == null) mhome = System.getProperty("mahout.home") - require(mhome != null, "MAHOUT_HOME is required to spawn mahout-based spark jobs" ) + require(mhome != null, "MAHOUT_HOME is required to spawn mahout-based spark jobs") mhome } /** Acquire proper Mahout jars to be added to task context based on current MAHOUT_HOME. */ - private[sparkbindings] def findMahoutContextJars(closeables:java.util.Deque[Closeable]) = { + private[sparkbindings] def findMahoutContextJars(closeables: Growable[Closeable]) = { // Figure Mahout classpath using $MAHOUT_HOME/mahout classpath command. 
- val fmhome = new File(getMahoutHome()) val bin = new File(fmhome, "bin") val exec = new File(bin, "mahout") @@ -177,26 +170,25 @@ package object sparkbindings { val p = Runtime.getRuntime.exec(Array(exec.getAbsolutePath, "-spark", "classpath")) - closeables.addFirst(new Closeable { + closeables += new Closeable { def close() { p.destroy() } - }) + } val r = new BufferedReader(new InputStreamReader(p.getInputStream)) - closeables.addFirst(r) + closeables += r val w = new StringWriter() - closeables.addFirst(w) + closeables += w var continue = true; - val jars = new ArrayBuffer[String]() + val jars = new mutable.ArrayBuffer[String]() do { val cp = r.readLine() if (cp == null) - throw new IllegalArgumentException( - "Unable to read output from \"mahout -spark classpath\". Is SPARK_HOME defined?" - ) + throw new IllegalArgumentException("Unable to read output from \"mahout -spark classpath\". Is SPARK_HOME " + + "defined?") val j = cp.split(File.pathSeparatorChar) if (j.size > 10) { @@ -206,8 +198,7 @@ package object sparkbindings { } } while (continue) -// jars.foreach(j => log.info(j)) - + // jars.foreach(j => log.info(j)) // context specific jars val mcjars = jars.filter(j => j.matches(".*mahout-math-\\d.*\\.jar") || @@ -233,4 +224,13 @@ package object sparkbindings { mcjars } + private[sparkbindings] def validateBlockifiedDrmRdd[K](rdd: BlockifiedDrmRdd[K]): Boolean = { + // Mostly, here each block must contain exactly one block + val part1Req = rdd.mapPartitions(piter => Iterator(piter.size == 1)).reduce(_ && _) + + if (!part1Req) warn("blockified rdd: condition not met: exactly 1 per partition") + + return part1Req + } + } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/test/scala/org/apache/mahout/sparkbindings/SparkBindingsSuite.scala ---------------------------------------------------------------------- diff --git a/spark/src/test/scala/org/apache/mahout/sparkbindings/SparkBindingsSuite.scala 
b/spark/src/test/scala/org/apache/mahout/sparkbindings/SparkBindingsSuite.scala index fbc31f3..529d13c 100644 --- a/spark/src/test/scala/org/apache/mahout/sparkbindings/SparkBindingsSuite.scala +++ b/spark/src/test/scala/org/apache/mahout/sparkbindings/SparkBindingsSuite.scala @@ -1,10 +1,12 @@ package org.apache.mahout.sparkbindings -import org.scalatest.FunSuite +import java.io.{Closeable, File} import java.util -import java.io.{File, Closeable} -import org.apache.mahout.common.IOUtils + import org.apache.mahout.sparkbindings.test.DistributedSparkSuite +import org.apache.mahout.util.IOUtilsScala +import org.scalatest.FunSuite +import collection._ /** * @author dmitriy @@ -16,7 +18,7 @@ class SparkBindingsSuite extends FunSuite with DistributedSparkSuite { // let it to be ignored. ignore("context jars") { System.setProperty("mahout.home", new File("..").getAbsolutePath/*"/home/dmitriy/projects/github/mahout-commits"*/) - val closeables = new util.ArrayDeque[Closeable]() + val closeables = new mutable.ListBuffer[Closeable]() try { val mahoutJars = findMahoutContextJars(closeables) mahoutJars.foreach { @@ -26,7 +28,7 @@ class SparkBindingsSuite extends FunSuite with DistributedSparkSuite { mahoutJars.size should be > 0 mahoutJars.size shouldBe 4 } finally { - IOUtils.close(closeables) + IOUtilsScala.close(closeables) } } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/test/scala/org/apache/mahout/sparkbindings/blas/BlasSuite.scala ---------------------------------------------------------------------- diff --git a/spark/src/test/scala/org/apache/mahout/sparkbindings/blas/BlasSuite.scala b/spark/src/test/scala/org/apache/mahout/sparkbindings/blas/BlasSuite.scala index 1521cb8..8c8ac3f 100644 --- a/spark/src/test/scala/org/apache/mahout/sparkbindings/blas/BlasSuite.scala +++ b/spark/src/test/scala/org/apache/mahout/sparkbindings/blas/BlasSuite.scala @@ -26,7 +26,7 @@ import scalabindings._ import RLikeOps._ import drm._ import 
org.apache.mahout.sparkbindings._ -import org.apache.mahout.sparkbindings.drm.CheckpointedDrmSpark +import org.apache.mahout.sparkbindings.drm._ import org.apache.mahout.math.drm.logical.{OpAt, OpAtA, OpAewB, OpABt} import org.apache.mahout.sparkbindings.test.DistributedSparkSuite @@ -142,7 +142,7 @@ class BlasSuite extends FunSuite with DistributedSparkSuite { val drmA = drmParallelize(m = inCoreA, numPartitions = 2) val op = new OpAt(drmA) - val drmAt = new CheckpointedDrmSpark(rdd = At.at(op, srcA = drmA), _nrow = op.nrow, _ncol = op.ncol) + val drmAt = new CheckpointedDrmSpark(rddInput = At.at(op, srcA = drmA), _nrow = op.nrow, _ncol = op.ncol) val inCoreAt = drmAt.collect val inCoreControlAt = inCoreA.t http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeOpsSuite.scala ---------------------------------------------------------------------- diff --git a/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeOpsSuite.scala b/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeOpsSuite.scala index 42026ae..7241660 100644 --- a/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeOpsSuite.scala +++ b/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeOpsSuite.scala @@ -23,13 +23,14 @@ import drm._ import RLikeOps._ import RLikeDrmOps._ import org.apache.mahout.sparkbindings._ -import org.scalatest.FunSuite +import org.scalatest.{ConfigMap, BeforeAndAfterAllConfigMap, FunSuite} import org.apache.mahout.sparkbindings.test.DistributedSparkSuite +import scala.reflect.ClassTag + /** Tests for DrmLikeOps */ class DrmLikeOpsSuite extends FunSuite with DistributedSparkSuite with DrmLikeOpsSuiteBase { - test("exact, min and auto ||") { val inCoreA = dense((1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)) val A = drmParallelize(m = inCoreA, numPartitions = 2) @@ -39,18 +40,20 @@ class DrmLikeOpsSuite extends FunSuite with DistributedSparkSuite with DrmLikeOp (A + 
1.0).par(exact = 4).rdd.partitions.size should equal(4) A.par(exact = 2).rdd.partitions.size should equal(2) A.par(exact = 1).rdd.partitions.size should equal(1) - A.par(exact = 0).rdd.partitions.size should equal(2) // No effect for par <= 0 + A.par(min = 4).rdd.partitions.size should equal(4) A.par(min = 2).rdd.partitions.size should equal(2) A.par(min = 1).rdd.partitions.size should equal(2) A.par(auto = true).rdd.partitions.size should equal(10) A.par(exact = 10).par(auto = true).rdd.partitions.size should equal(10) A.par(exact = 11).par(auto = true).rdd.partitions.size should equal(19) - A.par(exact = 20).par(auto = true).rdd.partitions.size should equal(20) + A.par(exact = 20).par(auto = true).rdd.partitions.size should equal(19) + + A.keyClassTag shouldBe ClassTag.Int + A.par(auto = true).keyClassTag shouldBe ClassTag.Int - intercept[AssertionError] { - A.par() - } + an[IllegalArgumentException] shouldBe thrownBy {A.par(exact = 0)} + an[IllegalArgumentException] shouldBe thrownBy {A.par()} } } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/RLikeDrmOpsSuite.scala ---------------------------------------------------------------------- diff --git a/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/RLikeDrmOpsSuite.scala b/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/RLikeDrmOpsSuite.scala index 2a4f213..f422f86 100644 --- a/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/RLikeDrmOpsSuite.scala +++ b/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/RLikeDrmOpsSuite.scala @@ -25,10 +25,16 @@ import drm._ import org.apache.mahout.sparkbindings._ import RLikeDrmOps._ import test.DistributedSparkSuite +import org.apache.mahout.math.drm.logical.{OpAtB, OpAewUnaryFuncFusion} +import org.apache.mahout.logging._ + +import scala.util.Random /** ==R-like DRM DSL operation tests -- Spark== */ class RLikeDrmOpsSuite extends FunSuite with DistributedSparkSuite with 
RLikeDrmOpsSuiteBase { + private final implicit val log = getLog(classOf[RLikeDrmOpsSuite]) + test("C = A + B missing rows") { val sc = mahoutCtx.asInstanceOf[SparkDistributedContext].sc @@ -113,4 +119,61 @@ class RLikeDrmOpsSuite extends FunSuite with DistributedSparkSuite with RLikeDrm } + test("A'B, bigger") { + + val rnd = new Random() + val a = new SparseRowMatrix(200, 1544) := { _ => rnd.nextGaussian() } + val b = new SparseRowMatrix(200, 300) := { _ => rnd.nextGaussian() } + + var ms = System.currentTimeMillis() + val atb = a.t %*% b + ms = System.currentTimeMillis() - ms + + println(s"in-core mul ms: $ms") + + val drmA = drmParallelize(a, numPartitions = 2) + val drmB = drmParallelize(b, numPartitions = 2) + + ms = System.currentTimeMillis() + val drmAtB = drmA.t %*% drmB + val mxAtB = drmAtB.collect + ms = System.currentTimeMillis() - ms + + println(s"a'b plan:${drmAtB.context.engine.optimizerRewrite(drmAtB)}") + println(s"a'b plan contains ${drmAtB.rdd.partitions.size} partitions.") + println(s"distributed mul ms: $ms.") + + (atb - mxAtB).norm should be < 1e-5 + + } + + test("C = At %*% B , zippable") { + + val mxA = dense((1, 2), (3, 4), (-3, -5)) + + val A = drmParallelize(mxA, numPartitions = 2) + .mapBlock()({ + case (keys, block) => keys.map(_.toString) -> block + }) + + val B = (A + 1.0) + + .mapBlock() { case (keys, block) ⇒ + val nblock = new SparseRowMatrix(block.nrow, block.ncol) := block + keys → nblock + } + + B.collect + + val C = A.t %*% B + + mahoutCtx.optimizerRewrite(C) should equal(OpAtB[String](A, B)) + + val inCoreC = C.collect + val inCoreControlC = mxA.t %*% (mxA + 1.0) + + (inCoreC - inCoreControlC).norm should be < 1E-10 + + } + } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/test/scala/org/apache/mahout/sparkbindings/io/IOSuite.scala ---------------------------------------------------------------------- diff --git a/spark/src/test/scala/org/apache/mahout/sparkbindings/io/IOSuite.scala 
b/spark/src/test/scala/org/apache/mahout/sparkbindings/io/IOSuite.scala new file mode 100644 index 0000000..f3a9721 --- /dev/null +++ b/spark/src/test/scala/org/apache/mahout/sparkbindings/io/IOSuite.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.sparkbindings.io + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.{Input, Output} +import com.twitter.chill.AllScalaRegistrar +import org.apache.mahout.math._ +import scalabindings._ +import RLikeOps._ + +import org.apache.mahout.common.RandomUtils +import org.apache.mahout.test.MahoutSuite +import org.scalatest.FunSuite + +import scala.util.Random + +class IOSuite extends FunSuite with MahoutSuite { + + import IOSuite._ + + test("Dense vector kryo") { + + val rnd = RandomUtils.getRandom + val vec = new DenseVector(165) := { _ => rnd.nextDouble()} + + val ret = kryoClone(vec, vec, vec) + val vec2 = ret(2) + + println(s"vec=$vec\nvc2=$vec2") + + vec2 === vec shouldBe true + vec2.isInstanceOf[DenseVector] shouldBe true + } + + test("Random sparse vector kryo") { + + val rnd = RandomUtils.getRandom + val vec = new RandomAccessSparseVector(165) := { _ => if (rnd.nextDouble() < 0.3) rnd.nextDouble() else 0} + val vec1 = new RandomAccessSparseVector(165) + vec1(2) = 2 + vec1(3) = 4 + vec1(3) = 0 + vec1(10) = 30 + + val ret = kryoClone(vec, vec1, vec) + val (vec2, vec3) = (ret(2), ret(1)) + + println(s"vec=$vec\nvc2=$vec2") + + vec2 === vec shouldBe true + vec1 === vec3 shouldBe true + vec2.isInstanceOf[RandomAccessSparseVector] shouldBe true + + } + + test("100% sparse vectors") { + + val vec1 = new SequentialAccessSparseVector(10) + val vec2 = new RandomAccessSparseVector(6) + val ret = kryoClone(vec1, vec2, vec1, vec2) + val vec3 = ret(2) + val vec4 = ret(3) + + vec1 === vec3 shouldBe true + vec2 === vec4 shouldBe true + } + + test("Sequential sparse vector kryo") { + + val rnd = RandomUtils.getRandom + val vec = new SequentialAccessSparseVector(165) := { _ => if (rnd.nextDouble() < 0.3) rnd.nextDouble() else 0} + + val vec1 = new SequentialAccessSparseVector(165) + vec1(2) = 0 + vec1(3) = 3 + vec1(4) = 2 + vec1(3) = 0 + + val ret = 
kryoClone(vec, vec1, vec) + val (vec2, vec3) = (ret(2), ret(1)) + + println(s"vec=$vec\nvc2=$vec2") + + vec2 === vec shouldBe true + vec1 === vec3 shouldBe true + vec2.isInstanceOf[SequentialAccessSparseVector] shouldBe true + } + + test("kryo matrix tests") { + val rnd = new Random() + + val mxA = new DenseMatrix(140, 150) := { _ => rnd.nextDouble()} + + val mxB = new SparseRowMatrix(140, 150) := { _ => if (rnd.nextDouble() < .3) rnd.nextDouble() else 0.0} + + val mxC = new SparseMatrix(140, 150) + for (i <- 0 until mxC.nrow) if (rnd.nextDouble() < .3) + mxC(i, ::) := { _ => if (rnd.nextDouble() < .3) rnd.nextDouble() else 0.0} + + val cnsl = mxC.numSlices() + println(s"Number of slices in mxC: ${cnsl}") + + val ret = kryoClone(mxA, mxA.t, mxB, mxB.t, mxC, mxC.t, mxA) + + val (mxAA, mxAAt, mxBB, mxBBt, mxCC, mxCCt, mxAAA) = (ret(0), ret(1), ret(2), ret(3), ret(4), ret(5), ret(6)) + + // ret.size shouldBe 7 + + mxA === mxAA shouldBe true + mxA === mxAAA shouldBe true + mxA === mxAAt.t shouldBe true + mxAA.isInstanceOf[DenseMatrix] shouldBe true + mxAAt.isInstanceOf[DenseMatrix] shouldBe false + + + mxB === mxBB shouldBe true + mxB === mxBBt.t shouldBe true + mxBB.isInstanceOf[SparseRowMatrix] shouldBe true + mxBBt.isInstanceOf[SparseRowMatrix] shouldBe false + mxBB(0,::).isDense shouldBe false + + + // Assert no persistence operation increased slice sparsity + mxC.numSlices() shouldBe cnsl + + // Assert deserialized product did not experience any empty slice inflation + mxCC.numSlices() shouldBe cnsl + mxCCt.t.numSlices() shouldBe cnsl + + // Incidentally, but not very significantly, iterating thru all rows that happens in equivalence + // operator, inserts empty rows into SparseMatrix so these asserts should not be before numSlices + // asserts. + mxC === mxCC shouldBe true + mxC === mxCCt.t shouldBe true + mxCCt.t.isInstanceOf[SparseMatrix] shouldBe true + + // Column-wise sparse matrix are deprecated and should be explicitly rejected by serializer. 
+ an[IllegalArgumentException] should be thrownBy { + val mxDeprecated = new SparseColumnMatrix(14, 15) + kryoClone(mxDeprecated) + } + + } + + test("diag matrix") { + + val mxD = diagv(dvec(1, 2, 3, 5)) + val mxDD = kryoClone(mxD)(0) + mxD === mxDD shouldBe true + mxDD.isInstanceOf[DiagonalMatrix] shouldBe true + + } +} + +object IOSuite { + + def kryoClone[T](obj: T*): Seq[T] = { + + val kryo = new Kryo() + new AllScalaRegistrar()(kryo) + + MahoutKryoRegistrator.registerClasses(kryo) + + val baos = new ByteArrayOutputStream() + val output = new Output(baos) + obj.foreach(kryo.writeClassAndObject(output, _)) + output.close + + val input = new Input(new ByteArrayInputStream(baos.toByteArray)) + + def outStream: Stream[T] = + if (input.eof) Stream.empty + else kryo.readClassAndObject(input).asInstanceOf[T] #:: outStream + + outStream + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/test/scala/org/apache/mahout/sparkbindings/test/DistributedSparkSuite.scala ---------------------------------------------------------------------- diff --git a/spark/src/test/scala/org/apache/mahout/sparkbindings/test/DistributedSparkSuite.scala b/spark/src/test/scala/org/apache/mahout/sparkbindings/test/DistributedSparkSuite.scala index f18ec70..d917a22 100644 --- a/spark/src/test/scala/org/apache/mahout/sparkbindings/test/DistributedSparkSuite.scala +++ b/spark/src/test/scala/org/apache/mahout/sparkbindings/test/DistributedSparkSuite.scala @@ -17,11 +17,13 @@ package org.apache.mahout.sparkbindings.test +import org.apache.log4j.{Level, Logger} import org.scalatest.{ConfigMap, BeforeAndAfterAllConfigMap, Suite} import org.apache.spark.SparkConf import org.apache.mahout.sparkbindings._ import org.apache.mahout.test.{DistributedMahoutSuite, MahoutSuite} import org.apache.mahout.math.drm.DistributedContext +import collection.JavaConversions._ trait DistributedSparkSuite extends DistributedMahoutSuite with LoggerConfiguration { this: Suite => @@ -30,16 +32,21 @@ 
trait DistributedSparkSuite extends DistributedMahoutSuite with LoggerConfigurat protected var masterUrl = null.asInstanceOf[String] protected def initContext() { - masterUrl = "local[3]" + masterUrl = System.getProperties.getOrElse("test.spark.master", "local[3]") + val isLocal = masterUrl.startsWith("local") mahoutCtx = mahoutSparkContext(masterUrl = this.masterUrl, - appName = "MahoutLocalContext", + appName = "MahoutUnitTests", // Do not run MAHOUT_HOME jars in unit tests. - addMahoutJars = false, + addMahoutJars = !isLocal, sparkConf = new SparkConf() - .set("spark.kryoserializer.buffer.mb", "15") + .set("spark.kryoserializer.buffer.mb", "40") .set("spark.akka.frameSize", "30") .set("spark.default.parallelism", "10") + .set("spark.executor.memory", "2G") ) + // Spark reconfigures logging. Clamp down on it in tests. + Logger.getRootLogger.setLevel(Level.ERROR) + Logger.getLogger("org.apache.spark").setLevel(Level.WARN) } protected def resetContext() { http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/test/scala/org/apache/mahout/sparkbindings/test/LoggerConfiguration.scala ---------------------------------------------------------------------- diff --git a/spark/src/test/scala/org/apache/mahout/sparkbindings/test/LoggerConfiguration.scala b/spark/src/test/scala/org/apache/mahout/sparkbindings/test/LoggerConfiguration.scala index e48e7c7..2a996d7 100644 --- a/spark/src/test/scala/org/apache/mahout/sparkbindings/test/LoggerConfiguration.scala +++ b/spark/src/test/scala/org/apache/mahout/sparkbindings/test/LoggerConfiguration.scala @@ -25,6 +25,6 @@ trait LoggerConfiguration extends org.apache.mahout.test.LoggerConfiguration { override protected def beforeAll(configMap: ConfigMap) { super.beforeAll(configMap) - Logger.getLogger("org.apache.mahout.sparkbindings").setLevel(Level.INFO) + BasicConfigurator.resetConfiguration() } }
