Repository: mahout Updated Branches: refs/heads/flink-binding 7039f4c5f -> f5a4a9762
MAHOUT-1776 Refactor common Engine agnostic classes to Math-Scala module closes apache/mahout#163 Project: http://git-wip-us.apache.org/repos/asf/mahout/repo Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/f5a4a976 Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/f5a4a976 Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/f5a4a976 Branch: refs/heads/flink-binding Commit: f5a4a976288a3ec10942f1b28ea793bacac33955 Parents: 7039f4c Author: smarthi <[email protected]> Authored: Fri Oct 23 18:23:43 2015 -0400 Committer: smarthi <[email protected]> Committed: Fri Oct 23 18:23:43 2015 -0400 ---------------------------------------------------------------------- math-scala/pom.xml | 6 + .../common/io/GenericMatrixKryoSerializer.scala | 188 ++++++++++++++ .../mahout/common/io/VectorKryoSerializer.scala | 248 ++++++++++++++++++ .../apache/mahout/common/Hadoop1HDFSUtil.scala | 8 +- .../io/GenericMatrixKryoSerializer.scala | 189 -------------- .../io/MahoutKryoRegistrator.scala | 1 + .../sparkbindings/io/VectorKryoSerializer.scala | 252 ------------------- .../sparkbindings/SparkBindingsSuite.scala | 6 +- .../mahout/sparkbindings/io/IOSuite.scala | 6 +- 9 files changed, 452 insertions(+), 452 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mahout/blob/f5a4a976/math-scala/pom.xml ---------------------------------------------------------------------- diff --git a/math-scala/pom.xml b/math-scala/pom.xml index e8c0357..0124612 100644 --- a/math-scala/pom.xml +++ b/math-scala/pom.xml @@ -122,6 +122,12 @@ <artifactId>mahout-math</artifactId> </dependency> + <dependency> + <groupId>com.esotericsoftware.kryo</groupId> + <artifactId>kryo</artifactId> + <version>2.21</version> + </dependency> + <!-- 3rd-party --> <dependency> <groupId>log4j</groupId> http://git-wip-us.apache.org/repos/asf/mahout/blob/f5a4a976/math-scala/src/main/scala/org/apache/mahout/common/io/GenericMatrixKryoSerializer.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/main/scala/org/apache/mahout/common/io/GenericMatrixKryoSerializer.scala b/math-scala/src/main/scala/org/apache/mahout/common/io/GenericMatrixKryoSerializer.scala new file mode 100644 index 0000000..534d37c --- /dev/null +++ b/math-scala/src/main/scala/org/apache/mahout/common/io/GenericMatrixKryoSerializer.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.io + +import com.esotericsoftware.kryo.io.{Input, Output} +import com.esotericsoftware.kryo.{Kryo, Serializer} +import org.apache.log4j.Logger +import org.apache.mahout.logging._ +import org.apache.mahout.math._ +import org.apache.mahout.math.flavor.TraversingStructureEnum +import org.apache.mahout.math.scalabindings.RLikeOps._ +import org.apache.mahout.math.scalabindings._ + +import scala.collection.JavaConversions._ + +object GenericMatrixKryoSerializer { + + private implicit final val log = Logger.getLogger(classOf[GenericMatrixKryoSerializer]) + +} + +/** Serializes Sparse or Dense in-core generic matrix (row-wise or column-wise backed) */ +class GenericMatrixKryoSerializer extends Serializer[Matrix] { + + import GenericMatrixKryoSerializer._ + + override def write(kryo: Kryo, output: Output, mx: Matrix): Unit = { + + debug(s"Writing mx of type ${mx.getClass.getName}") + + val structure = mx.getFlavor.getStructure + + // Write structure bit + output.writeInt(structure.ordinal(), true) + + // Write geometry + output.writeInt(mx.nrow, true) + output.writeInt(mx.ncol, true) + + // Write in most efficient traversal order (using backing vectors perhaps) + structure match { + case TraversingStructureEnum.COLWISE => writeRowWise(kryo, output, mx.t) + case TraversingStructureEnum.SPARSECOLWISE => writeSparseRowWise(kryo, output, mx.t) + case TraversingStructureEnum.SPARSEROWWISE => writeSparseRowWise(kryo, output, mx) + case TraversingStructureEnum.VECTORBACKED => writeVectorBacked(kryo, output, mx) + case _ => writeRowWise(kryo, output, mx) + } + + } + + private def writeVectorBacked(kryo: Kryo, output: Output, mx: Matrix) { + + require(mx != null) + + // At this point we are just doing some vector-backed classes individually. TODO: create + // api to obtain vector-backed matrix data. + kryo.writeClass(output, mx.getClass) + mx match { + case mxD: DiagonalMatrix => kryo.writeObject(output, mxD.diagv) + case mxS: DenseSymmetricMatrix => kryo.writeObject(output, dvec(mxS.getData)) + case mxT: UpperTriangular => kryo.writeObject(output, dvec(mxT.getData)) + case _ => throw new IllegalArgumentException(s"Unsupported matrix type:${mx.getClass.getName}") + } + } + + private def readVectorBacked(kryo: Kryo, input: Input, nrow: Int, ncol: Int) = { + + // We require vector-backed matrices to have vector-parameterized constructor to construct. + val clazz = kryo.readClass(input).getType + + debug(s"Deserializing vector-backed mx of type ${clazz.getName}.") + + clazz.getConstructor(classOf[Vector]).newInstance(kryo.readObject(input, classOf[Vector])).asInstanceOf[Matrix] + } + + private def writeRowWise(kryo: Kryo, output: Output, mx: Matrix): Unit = { + for (row <- mx) kryo.writeObject(output, row) + } + + private def readRows(kryo: Kryo, input: Input, nrow: Int) = { + Array.tabulate(nrow) { _ => kryo.readObject(input, classOf[Vector])} + } + + private def readSparseRows(kryo: Kryo, input: Input) = { + + // Number of slices + val nslices = input.readInt(true) + + Array.tabulate(nslices) { _ => + input.readInt(true) -> kryo.readObject(input, classOf[Vector]) + } + } + + private def writeSparseRowWise(kryo: Kryo, output: Output, mx: Matrix): Unit = { + + val nslices = mx.numSlices() + + output.writeInt(nslices, true) + + var actualNSlices = 0 + for (row <- mx.iterateNonEmpty()) { + output.writeInt(row.index(), true) + kryo.writeObject(output, row.vector()) + actualNSlices += 1 + } + + require(nslices == actualNSlices, "Number of slices reported by Matrix.numSlices() was different from actual " + + "slice iterator size.") + } + + override def read(kryo: Kryo, input: Input, mxClass: Class[Matrix]): Matrix = { + + // Read structure hint + val structure = TraversingStructureEnum.values()(input.readInt(true)) + + // Read geometry + val nrow = input.readInt(true) + val ncol = input.readInt(true) + + debug(s"read matrix geometry: $nrow x $ncol.") + + structure match { + + // Sparse or dense column wise + case TraversingStructureEnum.COLWISE => + val cols = readRows(kryo, input, ncol) + + if (!cols.isEmpty && cols.head.isDense) + dense(cols).t + else { + debug("Deserializing as SparseRowMatrix.t (COLWISE).") + new SparseRowMatrix(ncol, nrow, cols, true, false).t + } + + // transposed SparseMatrix case + case TraversingStructureEnum.SPARSECOLWISE => + val cols = readSparseRows(kryo, input) + val javamap = new java.util.HashMap[Integer, Vector]((cols.size << 1) + 1) + cols.foreach { case (idx, vec) => javamap.put(idx, vec)} + + debug("Deserializing as SparseMatrix.t (SPARSECOLWISE).") + new SparseMatrix(ncol, nrow, javamap, true).t + + // Sparse Row-wise -- this will be created as a SparseMatrix. + case TraversingStructureEnum.SPARSEROWWISE => + val rows = readSparseRows(kryo, input) + val javamap = new java.util.HashMap[Integer, Vector]((rows.size << 1) + 1) + rows.foreach { case (idx, vec) => javamap.put(idx, vec)} + + debug("Deserializing as SparseMatrix (SPARSEROWWISE).") + new SparseMatrix(nrow, ncol, javamap, true) + case TraversingStructureEnum.VECTORBACKED => + + debug("Deserializing vector-backed...") + readVectorBacked(kryo, input, nrow, ncol) + + // By default, read row-wise. + case _ => + val cols = readRows(kryo, input, nrow) + // this still copies a lot of stuff... + if (!cols.isEmpty && cols.head.isDense) { + + debug("Deserializing as DenseMatrix.") + dense(cols) + } else { + + debug("Deserializing as SparseRowMatrix(default).") + new SparseRowMatrix(nrow, ncol, cols, true, false) + } + } + + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/f5a4a976/math-scala/src/main/scala/org/apache/mahout/common/io/VectorKryoSerializer.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/main/scala/org/apache/mahout/common/io/VectorKryoSerializer.scala b/math-scala/src/main/scala/org/apache/mahout/common/io/VectorKryoSerializer.scala new file mode 100644 index 0000000..3cc537c --- /dev/null +++ b/math-scala/src/main/scala/org/apache/mahout/common/io/VectorKryoSerializer.scala @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.io + +import com.esotericsoftware.kryo.io.{Input, Output} +import com.esotericsoftware.kryo.{Kryo, Serializer} +import org.apache.mahout.logging._ +import org.apache.mahout.math._ +import org.apache.mahout.math.scalabindings.RLikeOps._ + +import scala.collection.JavaConversions._ + + +object VectorKryoSerializer { + + final val FLAG_DENSE: Int = 0x01 + final val FLAG_SEQUENTIAL: Int = 0x02 + final val FLAG_NAMED: Int = 0x04 + final val FLAG_LAX_PRECISION: Int = 0x08 + + private final implicit val log = getLog(classOf[VectorKryoSerializer]) + +} + +class VectorKryoSerializer(val laxPrecision: Boolean = false) extends Serializer[Vector] { + + import VectorKryoSerializer._ + + override def write(kryo: Kryo, output: Output, vector: Vector): Unit = { + + require(vector != null) + + trace(s"Serializing vector of ${vector.getClass.getName} class.") + + // Write length + val len = vector.length + output.writeInt(len, true) + + // Interrogate vec properties + val dense = vector.isDense + val sequential = vector.isSequentialAccess + val named = vector.isInstanceOf[NamedVector] + + var flag = 0 + + if (dense) { + flag |= FLAG_DENSE + } else if (sequential) { + flag |= FLAG_SEQUENTIAL + } + + if (vector.isInstanceOf[NamedVector]) { + flag |= FLAG_NAMED + } + + if (laxPrecision) flag |= FLAG_LAX_PRECISION + + // Write flags + output.writeByte(flag) + + // Write name if needed + if (named) output.writeString(vector.asInstanceOf[NamedVector].getName) + + dense match { + + // Dense vector. + case true => + + laxPrecision match { + case true => + for (i <- 0 until vector.length) output.writeFloat(vector(i).toFloat) + case _ => + for (i <- 0 until vector.length) output.writeDouble(vector(i)) + } + case _ => + + // Turns out getNumNonZeroElements must check every element if it is indeed non-zero. The + // iterateNonZeros() on the other hand doesn't do that, so that's all inconsistent right + // now. so we'll just auto-terminate. + val iter = vector.nonZeroes.toIterator.filter(_.get() != 0.0) + + sequential match { + + // Delta encoding + case true => + + var idx = 0 + laxPrecision match { + case true => + while (iter.hasNext) { + val el = iter.next() + output.writeFloat(el.toFloat) + output.writeInt(el.index() - idx, true) + idx = el.index + } + // Terminate delta encoding. + output.writeFloat(0.0.toFloat) + case _ => + while (iter.hasNext) { + val el = iter.next() + output.writeDouble(el.get()) + output.writeInt(el.index() - idx, true) + idx = el.index + } + // Terminate delta encoding. + output.writeDouble(0.0) + } + + // Random access. + case _ => + + laxPrecision match { + + case true => + iter.foreach { el => + output.writeFloat(el.get().toFloat) + output.writeInt(el.index(), true) + } + // Terminate random access with 0.0 value. + output.writeFloat(0.0.toFloat) + case _ => + iter.foreach { el => + output.writeDouble(el.get()) + output.writeInt(el.index(), true) + } + // Terminate random access with 0.0 value. + output.writeDouble(0.0) + } + + } + + } + } + + override def read(kryo: Kryo, input: Input, vecClass: Class[Vector]): Vector = { + + val len = input.readInt(true) + val flags = input.readByte().toInt + val name = if ((flags & FLAG_NAMED) != 0) Some(input.readString()) else None + + val vec: Vector = flags match { + + // Dense + case _: Int if (flags & FLAG_DENSE) != 0 => + + trace(s"Deserializing dense vector.") + + if ((flags & FLAG_LAX_PRECISION) != 0) { + new DenseVector(len) := { _ => input.readFloat()} + } else { + new DenseVector(len) := { _ => input.readDouble()} + } + + // Sparse case. + case _ => + + flags match { + + // Sequential. + case _: Int if (flags & FLAG_SEQUENTIAL) != 0 => + + trace("Deserializing as sequential sparse vector.") + + val v = new SequentialAccessSparseVector(len) + var idx = 0 + var stop = false + + if ((flags & FLAG_LAX_PRECISION) != 0) { + + while (!stop) { + val value = input.readFloat() + if (value == 0.0) { + stop = true + } else { + idx += input.readInt(true) + v(idx) = value + } + } + } else { + while (!stop) { + val value = input.readDouble() + if (value == 0.0) { + stop = true + } else { + idx += input.readInt(true) + v(idx) = value + } + } + } + v + + // Random access + case _ => + + trace("Deserializing as random access vector.") + + // Read pairs until we see 0.0 value. Prone to corruption attacks obviously. + val v = new RandomAccessSparseVector(len) + var stop = false + if ((flags & FLAG_LAX_PRECISION) != 0) { + while (! stop ) { + val value = input.readFloat() + if ( value == 0.0 ) { + stop = true + } else { + val idx = input.readInt(true) + v(idx) = value + } + } + } else { + while (! stop ) { + val value = input.readDouble() + if (value == 0.0) { + stop = true + } else { + val idx = input.readInt(true) + v(idx) = value + } + } + } + v + } + } + + name.map{name => + + trace(s"Recovering named vector's name $name.") + + new NamedVector(vec, name) + } + .getOrElse(vec) + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/f5a4a976/spark/src/main/scala/org/apache/mahout/common/Hadoop1HDFSUtil.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/common/Hadoop1HDFSUtil.scala b/spark/src/main/scala/org/apache/mahout/common/Hadoop1HDFSUtil.scala index 399508d..29599b8 100644 --- a/spark/src/main/scala/org/apache/mahout/common/Hadoop1HDFSUtil.scala +++ b/spark/src/main/scala/org/apache/mahout/common/Hadoop1HDFSUtil.scala @@ -17,12 +17,10 @@ package org.apache.mahout.common -import org.apache.hadoop.io.{Writable, SequenceFile} -import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.{SequenceFile, Writable} import org.apache.spark.SparkContext -import collection._ -import JavaConversions._ /** * Deprecated Hadoop 1 api which we currently explicitly import via Mahout dependencies. May not work @@ -44,7 +42,7 @@ object Hadoop1HDFSUtil extends HDFSUtil { val partFilePath:Path = fs.listStatus(dfsPath) // Filter out anything starting with . - .filter { s => (!s.getPath.getName.startsWith("\\.") && !s.getPath.getName.startsWith("_") && !s.isDir)} + .filter { s => !s.getPath.getName.startsWith("\\.") && !s.getPath.getName.startsWith("_") && !s.isDir } // Take path .map(_.getPath) http://git-wip-us.apache.org/repos/asf/mahout/blob/f5a4a976/spark/src/main/scala/org/apache/mahout/sparkbindings/io/GenericMatrixKryoSerializer.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/GenericMatrixKryoSerializer.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/GenericMatrixKryoSerializer.scala deleted file mode 100644 index da58b35..0000000 --- a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/GenericMatrixKryoSerializer.scala +++ /dev/null @@ -1,189 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.sparkbindings.io - - -import com.esotericsoftware.kryo.io.{Output, Input} -import com.esotericsoftware.kryo.{Kryo, Serializer} -import org.apache.log4j.Logger -import org.apache.mahout.logging._ -import org.apache.mahout.math._ -import org.apache.mahout.math.flavor.TraversingStructureEnum -import scalabindings._ -import RLikeOps._ -import collection._ -import JavaConversions._ - -object GenericMatrixKryoSerializer { - - private implicit final val log = Logger.getLogger(classOf[GenericMatrixKryoSerializer]) - -} - -/** Serializes Sparse or Dense in-core generic matrix (row-wise or column-wise backed) */ -class GenericMatrixKryoSerializer extends Serializer[Matrix] { - - import GenericMatrixKryoSerializer._ - - override def write(kryo: Kryo, output: Output, mx: Matrix): Unit = { - - debug(s"Writing mx of type ${mx.getClass.getName}") - - val structure = mx.getFlavor.getStructure - - // Write structure bit - output.writeInt(structure.ordinal(), true) - - // Write geometry - output.writeInt(mx.nrow, true) - output.writeInt(mx.ncol, true) - - // Write in most efficient traversal order (using backing vectors perhaps) - structure match { - case TraversingStructureEnum.COLWISE => writeRowWise(kryo, output, mx.t) - case TraversingStructureEnum.SPARSECOLWISE => writeSparseRowWise(kryo, output, mx.t) - case TraversingStructureEnum.SPARSEROWWISE => writeSparseRowWise(kryo, output, mx) - case TraversingStructureEnum.VECTORBACKED => writeVectorBacked(kryo, output, mx) - case _ => writeRowWise(kryo, output, mx) - } - - } - - private def writeVectorBacked(kryo: Kryo, output: Output, mx: Matrix) { - - require(mx != null) - - // At this point we are just doing some vector-backed classes individually. TODO: create - // api to obtain vector-backed matrix data. - kryo.writeClass(output, mx.getClass) - mx match { - case mxD: DiagonalMatrix => kryo.writeObject(output, mxD.diagv) - case mxS: DenseSymmetricMatrix => kryo.writeObject(output, dvec(mxS.getData)) - case mxT: UpperTriangular => kryo.writeObject(output, dvec(mxT.getData)) - case _ => throw new IllegalArgumentException(s"Unsupported matrix type:${mx.getClass.getName}") - } - } - - private def readVectorBacked(kryo: Kryo, input: Input, nrow: Int, ncol: Int) = { - - // We require vector-backed matrices to have vector-parameterized constructor to construct. - val clazz = kryo.readClass(input).getType - - debug(s"Deserializing vector-backed mx of type ${clazz.getName}.") - - clazz.getConstructor(classOf[Vector]).newInstance(kryo.readObject(input, classOf[Vector])).asInstanceOf[Matrix] - } - - private def writeRowWise(kryo: Kryo, output: Output, mx: Matrix): Unit = { - for (row <- mx) kryo.writeObject(output, row) - } - - private def readRows(kryo: Kryo, input: Input, nrow: Int) = { - Array.tabulate(nrow) { _ => kryo.readObject(input, classOf[Vector])} - } - - private def readSparseRows(kryo: Kryo, input: Input) = { - - // Number of slices - val nslices = input.readInt(true) - - Array.tabulate(nslices) { _ => - input.readInt(true) -> kryo.readObject(input, classOf[Vector]) - } - } - - private def writeSparseRowWise(kryo: Kryo, output: Output, mx: Matrix): Unit = { - - val nslices = mx.numSlices() - - output.writeInt(nslices, true) - - var actualNSlices = 0; - for (row <- mx.iterateNonEmpty()) { - output.writeInt(row.index(), true) - kryo.writeObject(output, row.vector()) - actualNSlices += 1 - } - - require(nslices == actualNSlices, "Number of slices reported by Matrix.numSlices() was different from actual " + - "slice iterator size.") - } - - override def read(kryo: Kryo, input: Input, mxClass: Class[Matrix]): Matrix = { - - // Read structure hint - val structure = TraversingStructureEnum.values()(input.readInt(true)) - - // Read geometry - val nrow = input.readInt(true) - val ncol = input.readInt(true) - - debug(s"read matrix geometry: $nrow x $ncol.") - - structure match { - - // Sparse or dense column wise - case TraversingStructureEnum.COLWISE => - val cols = readRows(kryo, input, ncol) - - if (!cols.isEmpty && cols.head.isDense) - dense(cols).t - else { - debug("Deserializing as SparseRowMatrix.t (COLWISE).") - new SparseRowMatrix(ncol, nrow, cols, true, false).t - } - - // transposed SparseMatrix case - case TraversingStructureEnum.SPARSECOLWISE => - val cols = readSparseRows(kryo, input) - val javamap = new java.util.HashMap[Integer, Vector]((cols.size << 1) + 1) - cols.foreach { case (idx, vec) => javamap.put(idx, vec)} - - debug("Deserializing as SparseMatrix.t (SPARSECOLWISE).") - new SparseMatrix(ncol, nrow, javamap, true).t - - // Sparse Row-wise -- this will be created as a SparseMatrix. - case TraversingStructureEnum.SPARSEROWWISE => - val rows = readSparseRows(kryo, input) - val javamap = new java.util.HashMap[Integer, Vector]((rows.size << 1) + 1) - rows.foreach { case (idx, vec) => javamap.put(idx, vec)} - - debug("Deserializing as SparseMatrix (SPARSEROWWISE).") - new SparseMatrix(nrow, ncol, javamap, true) - case TraversingStructureEnum.VECTORBACKED => - - debug("Deserializing vector-backed...") - readVectorBacked(kryo, input, nrow, ncol) - - // By default, read row-wise. - case _ => - val cols = readRows(kryo, input, nrow) - // this still copies a lot of stuff... - if (!cols.isEmpty && cols.head.isDense) { - - debug("Deserializing as DenseMatrix.") - dense(cols) - } else { - - debug("Deserializing as SparseRowMatrix(default).") - new SparseRowMatrix(nrow, ncol, cols, true, false) - } - } - - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/f5a4a976/spark/src/main/scala/org/apache/mahout/sparkbindings/io/MahoutKryoRegistrator.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/MahoutKryoRegistrator.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/MahoutKryoRegistrator.scala index 5806ff5..4e0e061 100644 --- a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/MahoutKryoRegistrator.scala +++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/MahoutKryoRegistrator.scala @@ -18,6 +18,7 @@ package org.apache.mahout.sparkbindings.io import com.esotericsoftware.kryo.Kryo +import org.apache.mahout.common.io.{VectorKryoSerializer, GenericMatrixKryoSerializer} import org.apache.mahout.math._ import org.apache.spark.serializer.KryoRegistrator import org.apache.mahout.logging._ http://git-wip-us.apache.org/repos/asf/mahout/blob/f5a4a976/spark/src/main/scala/org/apache/mahout/sparkbindings/io/VectorKryoSerializer.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/VectorKryoSerializer.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/VectorKryoSerializer.scala deleted file mode 100644 index 175778f..0000000 --- a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/VectorKryoSerializer.scala +++ /dev/null @@ -1,252 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.sparkbindings.io - -import org.apache.log4j.Logger -import org.apache.mahout.logging._ -import org.apache.mahout.math._ -import org.apache.mahout.math.scalabindings._ -import RLikeOps._ - -import com.esotericsoftware.kryo.io.{OutputChunked, Output, Input} -import com.esotericsoftware.kryo.{Kryo, Serializer} - -import collection._ -import JavaConversions._ - - -object VectorKryoSerializer { - - final val FLAG_DENSE: Int = 0x01 - final val FLAG_SEQUENTIAL: Int = 0x02 - final val FLAG_NAMED: Int = 0x04 - final val FLAG_LAX_PRECISION: Int = 0x08 - - private final implicit val log = getLog(classOf[VectorKryoSerializer]) - -} - -class VectorKryoSerializer(val laxPrecision: Boolean = false) extends Serializer[Vector] { - - import VectorKryoSerializer._ - - override def write(kryo: Kryo, output: Output, vector: Vector): Unit = { - - require(vector != null) - - trace(s"Serializing vector of ${vector.getClass.getName} class.") - - // Write length - val len = vector.length - output.writeInt(len, true) - - // Interrogate vec properties - val dense = vector.isDense - val sequential = vector.isSequentialAccess - val named = vector.isInstanceOf[NamedVector] - - var flag = 0 - - if (dense) { - flag |= FLAG_DENSE - } else if (sequential) { - flag |= FLAG_SEQUENTIAL - } - - if (vector.isInstanceOf[NamedVector]) { - flag |= FLAG_NAMED - } - - if (laxPrecision) flag |= FLAG_LAX_PRECISION - - // Write flags - output.writeByte(flag) - - // Write name if needed - if (named) output.writeString(vector.asInstanceOf[NamedVector].getName) - - dense match { - - // Dense vector. - case true => - - laxPrecision match { - case true => - for (i <- 0 until vector.length) output.writeFloat(vector(i).toFloat) - case _ => - for (i <- 0 until vector.length) output.writeDouble(vector(i)) - } - case _ => - - // Turns out getNumNonZeroElements must check every element if it is indeed non-zero. The - // iterateNonZeros() on the other hand doesn't do that, so that's all inconsistent right - // now. so we'll just auto-terminate. - val iter = vector.nonZeroes.toIterator.filter(_.get() != 0.0) - - sequential match { - - // Delta encoding - case true => - - var idx = 0 - laxPrecision match { - case true => - while (iter.hasNext) { - val el = iter.next() - output.writeFloat(el.toFloat) - output.writeInt(el.index() - idx, true) - idx = el.index - } - // Terminate delta encoding. - output.writeFloat(0.0.toFloat) - case _ => - while (iter.hasNext) { - val el = iter.next() - output.writeDouble(el.get()) - output.writeInt(el.index() - idx, true) - idx = el.index - } - // Terminate delta encoding. - output.writeDouble(0.0) - } - - // Random access. - case _ => - - laxPrecision match { - - case true => - iter.foreach { el => - output.writeFloat(el.get().toFloat) - output.writeInt(el.index(), true) - } - // Terminate random access with 0.0 value. - output.writeFloat(0.0.toFloat) - case _ => - iter.foreach { el => - output.writeDouble(el.get()) - output.writeInt(el.index(), true) - } - // Terminate random access with 0.0 value. - output.writeDouble(0.0) - } - - } - - } - } - - override def read(kryo: Kryo, input: Input, vecClass: Class[Vector]): Vector = { - - val len = input.readInt(true) - val flags = input.readByte().toInt - val name = if ((flags & FLAG_NAMED) != 0) Some(input.readString()) else None - - val vec: Vector = flags match { - - // Dense - case _: Int if ((flags & FLAG_DENSE) != 0) => - - trace(s"Deserializing dense vector.") - - if ((flags & FLAG_LAX_PRECISION) != 0) { - new DenseVector(len) := { _ => input.readFloat()} - } else { - new DenseVector(len) := { _ => input.readDouble()} - } - - // Sparse case. - case _ => - - flags match { - - // Sequential. - case _: Int if ((flags & FLAG_SEQUENTIAL) != 0) => - - trace("Deserializing as sequential sparse vector.") - - val v = new SequentialAccessSparseVector(len) - var idx = 0 - var stop = false - - if ((flags & FLAG_LAX_PRECISION) != 0) { - - while (!stop) { - val value = input.readFloat() - if (value == 0.0) { - stop = true - } else { - idx += input.readInt(true) - v(idx) = value - } - } - } else { - while (!stop) { - val value = input.readDouble() - if (value == 0.0) { - stop = true - } else { - idx += input.readInt(true) - v(idx) = value - } - } - } - v - - // Random access - case _ => - - trace("Deserializing as random access vector.") - - // Read pairs until we see 0.0 value. Prone to corruption attacks obviously. - val v = new RandomAccessSparseVector(len) - var stop = false - if ((flags & FLAG_LAX_PRECISION) != 0) { - while (! stop ) { - val value = input.readFloat() - if ( value == 0.0 ) { - stop = true - } else { - val idx = input.readInt(true) - v(idx) = value - } - } - } else { - while (! stop ) { - val value = input.readDouble() - if (value == 0.0) { - stop = true - } else { - val idx = input.readInt(true) - v(idx) = value - } - } - } - v - } - } - - name.map{name => - - trace(s"Recovering named vector's name ${name}.") - - new NamedVector(vec, name) - } - .getOrElse(vec) - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/f5a4a976/spark/src/test/scala/org/apache/mahout/sparkbindings/SparkBindingsSuite.scala ---------------------------------------------------------------------- diff --git a/spark/src/test/scala/org/apache/mahout/sparkbindings/SparkBindingsSuite.scala b/spark/src/test/scala/org/apache/mahout/sparkbindings/SparkBindingsSuite.scala index 529d13c..61244a1 100644 --- a/spark/src/test/scala/org/apache/mahout/sparkbindings/SparkBindingsSuite.scala +++ b/spark/src/test/scala/org/apache/mahout/sparkbindings/SparkBindingsSuite.scala @@ -1,12 +1,12 @@ package org.apache.mahout.sparkbindings import java.io.{Closeable, File} -import java.util import org.apache.mahout.sparkbindings.test.DistributedSparkSuite import org.apache.mahout.util.IOUtilsScala import org.scalatest.FunSuite -import collection._ + +import scala.collection._ /** * @author dmitriy @@ -14,7 +14,7 @@ import collection._ class SparkBindingsSuite extends FunSuite with DistributedSparkSuite { // This test will succeed only when MAHOUT_HOME is set in the environment. So we keep it for - // diagnorstic purposes around, but we probably don't want it to run in the Jenkins, so we'd + // diagnostic purposes around, but we probably don't want it to run in the Jenkins, so we'd // let it to be ignored. ignore("context jars") { System.setProperty("mahout.home", new File("..").getAbsolutePath/*"/home/dmitriy/projects/github/mahout-commits"*/) http://git-wip-us.apache.org/repos/asf/mahout/blob/f5a4a976/spark/src/test/scala/org/apache/mahout/sparkbindings/io/IOSuite.scala ---------------------------------------------------------------------- diff --git a/spark/src/test/scala/org/apache/mahout/sparkbindings/io/IOSuite.scala b/spark/src/test/scala/org/apache/mahout/sparkbindings/io/IOSuite.scala index f3a9721..1814f17 100644 --- a/spark/src/test/scala/org/apache/mahout/sparkbindings/io/IOSuite.scala +++ b/spark/src/test/scala/org/apache/mahout/sparkbindings/io/IOSuite.scala @@ -116,11 +116,11 @@ class IOSuite extends FunSuite with MahoutSuite { mxC(i, ::) := { _ => if (rnd.nextDouble() < .3) rnd.nextDouble() else 0.0} val cnsl = mxC.numSlices() - println(s"Number of slices in mxC: ${cnsl}") + println(s"Number of slices in mxC: $cnsl") val ret = kryoClone(mxA, mxA.t, mxB, mxB.t, mxC, mxC.t, mxA) - val (mxAA, mxAAt, mxBB, mxBBt, mxCC, mxCCt, mxAAA) = (ret(0), ret(1), ret(2), ret(3), ret(4), ret(5), ret(6)) + val (mxAA, mxAAt, mxBB, mxBBt, mxCC, mxCCt, mxAAA) = (ret.head, ret(1), ret(2), ret(3), ret(4), ret(5), ret(6)) // ret.size shouldBe 7 @@ -163,7 +163,7 @@ class IOSuite extends FunSuite with MahoutSuite { test("diag matrix") { val mxD = diagv(dvec(1, 2, 3, 5)) - val mxDD = kryoClone(mxD)(0) + val mxDD = kryoClone(mxD).head mxD === mxDD shouldBe true mxDD.isInstanceOf[DiagonalMatrix] shouldBe true
