http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSpark.scala
----------------------------------------------------------------------
diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSpark.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSpark.scala
index e5a2b2a..41efc27 100644
--- a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSpark.scala
+++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSpark.scala
@@ -33,7 +33,7 @@ import org.apache.spark.SparkContext._
 
 /** ==Spark-specific optimizer-checkpointed DRM.==
   *
-  * @param rdd underlying rdd to wrap over.
+  * @param rddInput underlying rdd to wrap over.
   * @param _nrow number of rows; if unspecified, we will compute with an inexpensive traversal.
   * @param _ncol number of columns; if unspecified, we will try to guess with an inexpensive traversal.
   * @param _cacheStorageLevel storage level
@@ -44,9 +44,9 @@ import org.apache.spark.SparkContext._
   * @tparam K matrix key type (e.g. the keys of sequence files once persisted)
   */
 class CheckpointedDrmSpark[K: ClassTag](
-    val rdd: DrmRdd[K],
-    private var _nrow: Long = -1L,
-    private var _ncol: Int = -1,
+    private[sparkbindings] val rddInput: DrmRddInput[K],
+    private[sparkbindings] var _nrow: Long = -1L,
+    private[sparkbindings] var _ncol: Int = -1,
     private val _cacheStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY,
     override protected[mahout] val partitioningTag: Long = Random.nextLong(),
     private var _canHaveMissingRows: Boolean = false
@@ -63,7 +63,7 @@ class CheckpointedDrmSpark[K: ClassTag](
   private[mahout] var intFixExtra: Long = 0L
 
   private var cached: Boolean = false
-  override val context: DistributedContext = rdd.context
+  override val context: DistributedContext = rddInput.backingRdd.context
 
   /** Explicit extraction of key class Tag   */
   def keyClassTag: ClassTag[K] = implicitly[ClassTag[K]]
@@ -78,8 +78,8 @@ class CheckpointedDrmSpark[K: ClassTag](
   }
 
   def cache() = {
-    if (!cached) {
-      rdd.persist(_cacheStorageLevel)
+    if (!cached && _cacheStorageLevel != StorageLevel.NONE) {
+      rddInput.backingRdd.persist(_cacheStorageLevel)
       cached = true
     }
     this
@@ -92,7 +92,7 @@ class CheckpointedDrmSpark[K: ClassTag](
    */
   def uncache(): this.type = {
     if (cached) {
-      rdd.unpersist(blocking = false)
+      rddInput.backingRdd.unpersist(blocking = false)
       cached = false
     }
     this
@@ -115,7 +115,7 @@ class CheckpointedDrmSpark[K: ClassTag](
    */
   def collect: Matrix = {
 
-    val intRowIndices = implicitly[ClassTag[K]] == implicitly[ClassTag[Int]]
+    val intRowIndices = classTag[K] == ClassTag.Int
 
     val cols = ncol
     val rows = safeToNonNegInt(nrow)
@@ -124,7 +124,7 @@ class CheckpointedDrmSpark[K: ClassTag](
     // since currently spark #collect() requires Serializeable support,
     // we serialize DRM vectors into byte arrays on backend and restore Vector
     // instances on the front end:
-    val data = rdd.map(t => (t._1, t._2)).collect()
+    val data = rddInput.toDrmRdd().map(t => (t._1, t._2)).collect()
 
 
     val m = if (data.forall(_._2.isDense))
@@ -165,7 +165,7 @@ class CheckpointedDrmSpark[K: ClassTag](
       else if (classOf[Writable].isAssignableFrom(ktag.runtimeClass)) (x: K) => x.asInstanceOf[Writable]
       else throw new IllegalArgumentException("Do not know how to convert class tag %s to Writable.".format(ktag))
 
-    rdd.saveAsSequenceFile(path)
+    rddInput.toDrmRdd().saveAsSequenceFile(path)
   }
 
   protected def computeNRow = {
@@ -173,7 +173,7 @@ class CheckpointedDrmSpark[K: ClassTag](
     val intRowIndex = classTag[K] == classTag[Int]
 
     if (intRowIndex) {
-      val rdd = cache().rdd.asInstanceOf[DrmRdd[Int]]
+      val rdd = cache().rddInput.toDrmRdd().asInstanceOf[DrmRdd[Int]]
 
       // I guess it is a suitable place to compute int keys consistency test here because we know
       // that nrow can be computed lazily, which always happens when rdd is already available, cached,
@@ -186,16 +186,21 @@ class CheckpointedDrmSpark[K: ClassTag](
       intFixExtra = (maxPlus1 - rowCount) max 0L
       maxPlus1
     } else
-      cache().rdd.count()
+      cache().rddInput.toDrmRdd().count()
   }
 
 
 
-  protected def computeNCol =
-    cache().rdd.map(_._2.length).fold(-1)(max(_, _))
+  protected def computeNCol = {
+    rddInput.isBlockified match {
+      case true ⇒ rddInput.toBlockifiedDrmRdd(throw new AssertionError("not reached"))
+        .map(_._2.ncol).fold(-1)(max(_, _))
+      case false ⇒ cache().rddInput.toDrmRdd().map(_._2.length).fold(-1)(max(_, _))
+    }
+  }
 
   protected def computeNNonZero =
-    cache().rdd.map(_._2.getNumNonZeroElements.toLong).sum().toLong
+    cache().rddInput.toDrmRdd().map(_._2.getNumNonZeroElements.toLong).sum().toLong
 
   /** Changes the number of rows in the DRM without actually touching the underlying data. Used to
     * redimension a DRM after it has been created, which implies some blank, non-existent rows.
@@ -205,8 +210,8 @@ class CheckpointedDrmSpark[K: ClassTag](
   override def newRowCardinality(n: Int): CheckpointedDrm[K] = {
     assert(n > -1)
     assert( n >= nrow)
-    val newCheckpointedDrm = drmWrap[K](rdd, n, ncol)
-    newCheckpointedDrm
+    new CheckpointedDrmSpark(rddInput = rddInput, _nrow = n, _ncol = _ncol, _cacheStorageLevel = _cacheStorageLevel,
+      partitioningTag = partitioningTag, _canHaveMissingRows = _canHaveMissingRows)
   }
 
 }

http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSparkOps.scala
----------------------------------------------------------------------
diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSparkOps.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSparkOps.scala
index 7cf6bd6..abcfc64 100644
--- a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSparkOps.scala
+++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSparkOps.scala
@@ -11,6 +11,6 @@ class CheckpointedDrmSparkOps[K: ClassTag](drm: CheckpointedDrm[K]) {
   private[sparkbindings] val sparkDrm = drm.asInstanceOf[CheckpointedDrmSpark[K]]
 
   /** Spark matrix customization exposure */
-  def rdd = sparkDrm.rdd
+  def rdd = sparkDrm.rddInput.toDrmRdd()
 
 }

http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/DrmRddInput.scala
----------------------------------------------------------------------
diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/DrmRddInput.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/DrmRddInput.scala
index b72818c..d9dbada 100644
--- a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/DrmRddInput.scala
+++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/DrmRddInput.scala
@@ -23,22 +23,18 @@ import org.apache.spark.storage.StorageLevel
 import org.apache.mahout.sparkbindings._
 
 /** Encapsulates either DrmRdd[K] or BlockifiedDrmRdd[K] */
-class DrmRddInput[K: ClassTag](
-    private val rowWiseSrc: Option[( /*ncol*/ Int, /*rdd*/ DrmRdd[K])] = None,
-    private val blockifiedSrc: Option[BlockifiedDrmRdd[K]] = None
-    ) {
+class DrmRddInput[K: ClassTag](private val input: Either[DrmRdd[K], BlockifiedDrmRdd[K]]) {
 
-  assert(rowWiseSrc.isDefined || blockifiedSrc.isDefined, "Undefined input")
+  private[sparkbindings] lazy val backingRdd = input.left.getOrElse(input.right.get)
 
-  private lazy val backingRdd = rowWiseSrc.map(_._2).getOrElse(blockifiedSrc.get)
+  def isBlockified: Boolean = input.isRight
 
-  def isBlockified:Boolean = blockifiedSrc.isDefined
+  def isRowWise: Boolean = input.isLeft
 
-  def isRowWise:Boolean = rowWiseSrc.isDefined
+  def toDrmRdd(): DrmRdd[K] = input.left.getOrElse(deblockify(rdd = input.right.get))
 
-  def toDrmRdd(): DrmRdd[K] = rowWiseSrc.map(_._2).getOrElse(deblockify(rdd = blockifiedSrc.get))
-
-  def toBlockifiedDrmRdd() = blockifiedSrc.getOrElse(blockify(rdd = rowWiseSrc.get._2, blockncol = rowWiseSrc.get._1))
+  /** Use late binding for this. It may or may not be needed, depending on current config. */
+  def toBlockifiedDrmRdd(ncol: ⇒ Int) = input.right.getOrElse(blockify(rdd = input.left.get, blockncol = ncol))
 
   def sparkContext: SparkContext = backingRdd.sparkContext
 

http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/SparkBCast.scala
----------------------------------------------------------------------
diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/SparkBCast.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/SparkBCast.scala
index ac36f60..0371f9b 100644
--- a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/SparkBCast.scala
+++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/SparkBCast.scala
@@ -22,4 +22,6 @@ import org.apache.spark.broadcast.Broadcast
 
 class SparkBCast[T](val sbcast: Broadcast[T]) extends BCast[T] with Serializable {
   def value: T = sbcast.value
+
+  override def close(): Unit = sbcast.unpersist()
 }

http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala
----------------------------------------------------------------------
diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala
index c04b306..0de5ff8 100644
--- a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala
+++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala
@@ -37,18 +37,19 @@ package object drm {
 
   private[drm] final val log = Logger.getLogger("org.apache.mahout.sparkbindings");
 
-  private[sparkbindings] implicit def input2drmRdd[K](input: DrmRddInput[K]): DrmRdd[K] = input.toDrmRdd()
+  private[sparkbindings] implicit def cpDrm2DrmRddInput[K: ClassTag](cp: CheckpointedDrmSpark[K]): DrmRddInput[K] =
+    cp.rddInput
 
-  private[sparkbindings] implicit def input2blockifiedDrmRdd[K](input: DrmRddInput[K]): BlockifiedDrmRdd[K] = input.toBlockifiedDrmRdd()
+  private[sparkbindings] implicit def cpDrmGeneric2DrmRddInput[K: ClassTag](cp: CheckpointedDrm[K]): DrmRddInput[K] =
+    cp.asInstanceOf[CheckpointedDrmSpark[K]]
+
+  private[sparkbindings] implicit def drmRdd2drmRddInput[K: ClassTag](rdd: DrmRdd[K]): DrmRddInput[K] = new DrmRddInput[K](Left(rdd))
+
+  private[sparkbindings] implicit def blockifiedRdd2drmRddInput[K: ClassTag](rdd: BlockifiedDrmRdd[K]): DrmRddInput[K] = new
+      DrmRddInput[K](
+    Right(rdd))
 
-  private[sparkbindings] implicit def cpDrm2DrmRddInput[K: ClassTag](cp: CheckpointedDrm[K]): DrmRddInput[K] =
-    new DrmRddInput(rowWiseSrc = Some(cp.ncol -> cp.rdd))
 
-//  /** Broadcast vector (Mahout vectors are not closure-friendly, use this instead. */
-//  private[sparkbindings] def drmBroadcast(x: Vector)(implicit sc: SparkContext): Broadcast[Vector] = sc.broadcast(x)
-//
-//  /** Broadcast in-core Mahout matrix. Use this instead of closure. */
-//  private[sparkbindings] def drmBroadcast(m: Matrix)(implicit sc: SparkContext): Broadcast[Matrix] = sc.broadcast(m)
 
   /** Implicit broadcast cast for Spark physical op implementations. */
   private[sparkbindings] implicit def bcast2val[K](bcast:Broadcast[K]):K = bcast.value
@@ -74,7 +75,7 @@ package object drm {
           }
           block
         } else {
-          new SparseRowMatrix(vectors.size, blockncol, vectors)
+          new SparseRowMatrix(vectors.size, blockncol, vectors, true, false)
         }
 
         Iterator(keys -> block)
@@ -101,7 +102,7 @@ package object drm {
         blockKeys.ensuring(blockKeys.size == block.nrow)
         blockKeys.view.zipWithIndex.map {
           case (key, idx) =>
-            var v = block(idx, ::) // This is just a view!
+            val v = block(idx, ::) // This is just a view!
 
             // If a view rather than a concrete vector, clone into a concrete vector in order not to
             // attempt to serialize outer matrix when we save it (Although maybe most often this

http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/GenericMatrixKryoSerializer.scala
----------------------------------------------------------------------
diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/GenericMatrixKryoSerializer.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/GenericMatrixKryoSerializer.scala
new file mode 100644
index 0000000..da58b35
--- /dev/null
+++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/GenericMatrixKryoSerializer.scala
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.sparkbindings.io
+
+
+import com.esotericsoftware.kryo.io.{Output, Input}
+import com.esotericsoftware.kryo.{Kryo, Serializer}
+import org.apache.log4j.Logger
+import org.apache.mahout.logging._
+import org.apache.mahout.math._
+import org.apache.mahout.math.flavor.TraversingStructureEnum
+import scalabindings._
+import RLikeOps._
+import collection._
+import JavaConversions._
+
+object GenericMatrixKryoSerializer {
+
+  private implicit final val log = Logger.getLogger(classOf[GenericMatrixKryoSerializer])
+
+}
+
+/** Serializes Sparse or Dense in-core generic matrix (row-wise or column-wise backed) */
+class GenericMatrixKryoSerializer extends Serializer[Matrix] {
+
+  import GenericMatrixKryoSerializer._
+
+  override def write(kryo: Kryo, output: Output, mx: Matrix): Unit = {
+
+    debug(s"Writing mx of type ${mx.getClass.getName}")
+
+    val structure = mx.getFlavor.getStructure
+
+    // Write structure bit
+    output.writeInt(structure.ordinal(), true)
+
+    // Write geometry
+    output.writeInt(mx.nrow, true)
+    output.writeInt(mx.ncol, true)
+
+    // Write in most efficient traversal order (using backing vectors perhaps)
+    structure match {
+      case TraversingStructureEnum.COLWISE => writeRowWise(kryo, output, mx.t)
+      case TraversingStructureEnum.SPARSECOLWISE => writeSparseRowWise(kryo, output, mx.t)
+      case TraversingStructureEnum.SPARSEROWWISE => writeSparseRowWise(kryo, output, mx)
+      case TraversingStructureEnum.VECTORBACKED => writeVectorBacked(kryo, output, mx)
+      case _ => writeRowWise(kryo, output, mx)
+    }
+
+  }
+
+  private def writeVectorBacked(kryo: Kryo, output: Output, mx: Matrix) {
+
+    require(mx != null)
+
+    // At this point we are just doing some vector-backed classes individually. TODO: create
+    // api to obtain vector-backed matrix data.
+    kryo.writeClass(output, mx.getClass)
+    mx match {
+      case mxD: DiagonalMatrix => kryo.writeObject(output, mxD.diagv)
+      case mxS: DenseSymmetricMatrix => kryo.writeObject(output, dvec(mxS.getData))
+      case mxT: UpperTriangular => kryo.writeObject(output, dvec(mxT.getData))
+      case _ => throw new IllegalArgumentException(s"Unsupported matrix type:${mx.getClass.getName}")
+    }
+  }
+
+  private def readVectorBacked(kryo: Kryo, input: Input, nrow: Int, ncol: Int) = {
+
+    // We require vector-backed matrices to have vector-parameterized constructor to construct.
+    val clazz = kryo.readClass(input).getType
+
+    debug(s"Deserializing vector-backed mx of type ${clazz.getName}.")
+
+    clazz.getConstructor(classOf[Vector]).newInstance(kryo.readObject(input, classOf[Vector])).asInstanceOf[Matrix]
+  }
+
+  private def writeRowWise(kryo: Kryo, output: Output, mx: Matrix): Unit = {
+    for (row <- mx) kryo.writeObject(output, row)
+  }
+
+  private def readRows(kryo: Kryo, input: Input, nrow: Int) = {
+    Array.tabulate(nrow) { _ => kryo.readObject(input, classOf[Vector])}
+  }
+
+  private def readSparseRows(kryo: Kryo, input: Input) = {
+
+    // Number of slices
+    val nslices = input.readInt(true)
+
+    Array.tabulate(nslices) { _ =>
+      input.readInt(true) -> kryo.readObject(input, classOf[Vector])
+    }
+  }
+
+  private def writeSparseRowWise(kryo: Kryo, output: Output, mx: Matrix): Unit = {
+
+    val nslices = mx.numSlices()
+
+    output.writeInt(nslices, true)
+
+    var actualNSlices = 0;
+    for (row <- mx.iterateNonEmpty()) {
+      output.writeInt(row.index(), true)
+      kryo.writeObject(output, row.vector())
+      actualNSlices += 1
+    }
+
+    require(nslices == actualNSlices, "Number of slices reported by Matrix.numSlices() was different from actual " +
+      "slice iterator size.")
+  }
+
+  override def read(kryo: Kryo, input: Input, mxClass: Class[Matrix]): Matrix = {
+
+    // Read structure hint
+    val structure = TraversingStructureEnum.values()(input.readInt(true))
+
+    // Read geometry
+    val nrow = input.readInt(true)
+    val ncol = input.readInt(true)
+
+    debug(s"read matrix geometry: $nrow x $ncol.")
+
+    structure match {
+
+      // Sparse or dense column wise
+      case TraversingStructureEnum.COLWISE =>
+        val cols = readRows(kryo, input, ncol)
+
+        if (!cols.isEmpty && cols.head.isDense)
+          dense(cols).t
+        else {
+          debug("Deserializing as SparseRowMatrix.t (COLWISE).")
+          new SparseRowMatrix(ncol, nrow, cols, true, false).t
+        }
+
+      // transposed SparseMatrix case
+      case TraversingStructureEnum.SPARSECOLWISE =>
+        val cols = readSparseRows(kryo, input)
+        val javamap = new java.util.HashMap[Integer, Vector]((cols.size << 1) + 1)
+        cols.foreach { case (idx, vec) => javamap.put(idx, vec)}
+
+        debug("Deserializing as SparseMatrix.t (SPARSECOLWISE).")
+        new SparseMatrix(ncol, nrow, javamap, true).t
+
+      // Sparse Row-wise -- this will be created as a SparseMatrix.
+      case TraversingStructureEnum.SPARSEROWWISE =>
+        val rows = readSparseRows(kryo, input)
+        val javamap = new java.util.HashMap[Integer, Vector]((rows.size << 1) + 1)
+        rows.foreach { case (idx, vec) => javamap.put(idx, vec)}
+
+        debug("Deserializing as SparseMatrix (SPARSEROWWISE).")
+        new SparseMatrix(nrow, ncol, javamap, true)
+      case TraversingStructureEnum.VECTORBACKED =>
+
+        debug("Deserializing vector-backed...")
+        readVectorBacked(kryo, input, nrow, ncol)
+
+      // By default, read row-wise.
+      case _ =>
+        val cols = readRows(kryo, input, nrow)
+        // this still copies a lot of stuff...
+        if (!cols.isEmpty && cols.head.isDense) {
+
+          debug("Deserializing as DenseMatrix.")
+          dense(cols)
+        } else {
+
+          debug("Deserializing as SparseRowMatrix(default).")
+          new SparseRowMatrix(nrow, ncol, cols, true, false)
+        }
+    }
+
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/MahoutKryoRegistrator.scala
----------------------------------------------------------------------
diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/MahoutKryoRegistrator.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/MahoutKryoRegistrator.scala
index a8a0bb4..5806ff5 100644
--- a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/MahoutKryoRegistrator.scala
+++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/MahoutKryoRegistrator.scala
@@ -18,22 +18,28 @@
 package org.apache.mahout.sparkbindings.io
 
 import com.esotericsoftware.kryo.Kryo
-import com.esotericsoftware.kryo.serializers.JavaSerializer
 import org.apache.mahout.math._
-import org.apache.mahout.math.indexeddataset.{BiMap, BiDictionary}
 import org.apache.spark.serializer.KryoRegistrator
-import org.apache.mahout.sparkbindings._
-import org.apache.mahout.math.Vector.Element
+import org.apache.mahout.logging._
 
-import scala.collection.immutable.List
+object MahoutKryoRegistrator {
 
-/** Kryo serialization registrator for Mahout */
-class MahoutKryoRegistrator extends KryoRegistrator {
+  private final implicit val log = getLog(this.getClass)
+
+  def registerClasses(kryo: Kryo): Unit = {
 
-  override def registerClasses(kryo: Kryo) = {
+    trace("Registering mahout classes.")
+
+    kryo.register(classOf[SparseColumnMatrix], new UnsupportedSerializer)
+    kryo.addDefaultSerializer(classOf[Vector], new VectorKryoSerializer())
+    kryo.addDefaultSerializer(classOf[Matrix], new GenericMatrixKryoSerializer)
 
-    kryo.addDefaultSerializer(classOf[Vector], new WritableKryoSerializer[Vector, VectorWritable])
-    kryo.addDefaultSerializer(classOf[DenseVector], new WritableKryoSerializer[Vector, VectorWritable])
-    kryo.addDefaultSerializer(classOf[Matrix], new WritableKryoSerializer[Matrix, MatrixWritable])
   }
+
+}
+
+/** Kryo serialization registrator for Mahout */
+class MahoutKryoRegistrator extends KryoRegistrator {
+
+  override def registerClasses(kryo: Kryo): Unit = MahoutKryoRegistrator.registerClasses(kryo)
 }

http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/UnsupportedSerializer.scala
----------------------------------------------------------------------
diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/UnsupportedSerializer.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/UnsupportedSerializer.scala
new file mode 100644
index 0000000..66b79f4
--- /dev/null
+++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/UnsupportedSerializer.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.sparkbindings.io
+
+import com.esotericsoftware.kryo.io.{Output, Input}
+import com.esotericsoftware.kryo.{Kryo, Serializer}
+
+class UnsupportedSerializer extends Serializer[Any] {
+
+  override def write(kryo: Kryo, output: Output, obj: Any): Unit = {
+    throw new IllegalArgumentException(s"I/O of this type (${obj.getClass.getName}) is explicitly unsupported for a " +
+      "good reason.")
+  }
+
+  override def read(kryo: Kryo, input: Input, `type`: Class[Any]): Any = throw new IllegalArgumentException(s"I/O of this type (${`type`.getName}) is explicitly unsupported for a good reason.")
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/VectorKryoSerializer.scala
----------------------------------------------------------------------
diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/io/VectorKryoSerializer.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/VectorKryoSerializer.scala
new file mode 100644
index 0000000..175778f
--- /dev/null
+++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/io/VectorKryoSerializer.scala
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.sparkbindings.io
+
+import org.apache.log4j.Logger
+import org.apache.mahout.logging._
+import org.apache.mahout.math._
+import org.apache.mahout.math.scalabindings._
+import RLikeOps._
+
+import com.esotericsoftware.kryo.io.{OutputChunked, Output, Input}
+import com.esotericsoftware.kryo.{Kryo, Serializer}
+
+import collection._
+import JavaConversions._
+
+
+object VectorKryoSerializer {
+
+  final val FLAG_DENSE: Int = 0x01
+  final val FLAG_SEQUENTIAL: Int = 0x02
+  final val FLAG_NAMED: Int = 0x04
+  final val FLAG_LAX_PRECISION: Int = 0x08
+
+  private final implicit val log = getLog(classOf[VectorKryoSerializer])
+
+}
+
+class VectorKryoSerializer(val laxPrecision: Boolean = false) extends Serializer[Vector] {
+
+  import VectorKryoSerializer._
+
+  override def write(kryo: Kryo, output: Output, vector: Vector): Unit = {
+
+    require(vector != null)
+
+    trace(s"Serializing vector of ${vector.getClass.getName} class.")
+
+    // Write length
+    val len = vector.length
+    output.writeInt(len, true)
+
+    // Interrogate vec properties
+    val dense = vector.isDense
+    val sequential = vector.isSequentialAccess
+    val named = vector.isInstanceOf[NamedVector]
+
+    var flag = 0
+
+    if (dense) {
+      flag |= FLAG_DENSE
+    } else if (sequential) {
+      flag |= FLAG_SEQUENTIAL
+    }
+
+    if (named) {
+      flag |= FLAG_NAMED
+    }
+
+    if (laxPrecision) flag |= FLAG_LAX_PRECISION
+
+    // Write flags
+    output.writeByte(flag)
+
+    // Write name if needed
+    if (named) output.writeString(vector.asInstanceOf[NamedVector].getName)
+
+    dense match {
+
+      // Dense vector.
+      case true =>
+
+        laxPrecision match {
+          case true =>
+            for (i <- 0 until vector.length) output.writeFloat(vector(i).toFloat)
+          case _ =>
+            for (i <- 0 until vector.length) output.writeDouble(vector(i))
+        }
+      case _ =>
+
+        // Turns out getNumNonZeroElements must check every element if it is indeed non-zero. The
+        // iterateNonZeros() on the other hand doesn't do that, so that's all inconsistent right
+        // now, so we filter zeros ourselves. In lax mode also drop values that narrow to 0.0f: a written 0.0 is the end marker.
+        val iter = vector.nonZeroes.toIterator.filter(el => if (laxPrecision) el.get().toFloat != 0.0f else el.get() != 0.0)
+
+        sequential match {
+
+          // Delta encoding
+          case true =>
+
+            var idx = 0
+            laxPrecision match {
+              case true =>
+                while (iter.hasNext) {
+                  val el = iter.next()
+                  output.writeFloat(el.get().toFloat)
+                  output.writeInt(el.index() - idx, true)
+                  idx = el.index
+                }
+                // Terminate delta encoding.
+                output.writeFloat(0.0.toFloat)
+              case _ =>
+                while (iter.hasNext) {
+                  val el = iter.next()
+                  output.writeDouble(el.get())
+                  output.writeInt(el.index() - idx, true)
+                  idx = el.index
+                }
+                // Terminate delta encoding.
+                output.writeDouble(0.0)
+            }
+
+          // Random access.
+          case _ =>
+
+            laxPrecision match {
+
+              case true =>
+                iter.foreach { el =>
+                  output.writeFloat(el.get().toFloat)
+                  output.writeInt(el.index(), true)
+                }
+                // Terminate random access with 0.0 value.
+                output.writeFloat(0.0.toFloat)
+              case _ =>
+                iter.foreach { el =>
+                  output.writeDouble(el.get())
+                  output.writeInt(el.index(), true)
+                }
+                // Terminate random access with 0.0 value.
+                output.writeDouble(0.0)
+            }
+
+        }
+
+    }
+  }
+
+  override def read(kryo: Kryo, input: Input, vecClass: Class[Vector]): Vector = {
+
+    val len = input.readInt(true)
+    val flags = input.readByte().toInt
+    val name = if ((flags & FLAG_NAMED) != 0) Some(input.readString()) else None
+
+    val vec: Vector = flags match {
+
+      // Dense
+      case _: Int if ((flags & FLAG_DENSE) != 0) =>
+
+        trace(s"Deserializing dense vector.")
+
+        if ((flags & FLAG_LAX_PRECISION) != 0) {
+          new DenseVector(len) := { _ => input.readFloat()}
+        } else {
+          new DenseVector(len) := { _ => input.readDouble()}
+        }
+
+      // Sparse case.
+      case _ =>
+
+        flags match {
+
+          // Sequential.
+          case _: Int if ((flags & FLAG_SEQUENTIAL) != 0) =>
+
+            trace("Deserializing as sequential sparse vector.")
+
+            val v = new SequentialAccessSparseVector(len)
+            var idx = 0
+            var stop = false
+
+            if ((flags & FLAG_LAX_PRECISION) != 0) {
+
+              while (!stop) {
+                val value = input.readFloat()
+                if (value == 0.0) {
+                  stop = true
+                } else {
+                  idx += input.readInt(true)
+                  v(idx) = value
+                }
+              }
+            } else {
+              while (!stop) {
+                val value = input.readDouble()
+                if (value == 0.0) {
+                  stop = true
+                } else {
+                  idx += input.readInt(true)
+                  v(idx) = value
+                }
+              }
+            }
+            v
+
+          // Random access
+          case _ =>
+
+            trace("Deserializing as random access vector.")
+
+            // Read pairs until we see 0.0 value. Prone to corruption attacks obviously.
+            val v = new RandomAccessSparseVector(len)
+            var stop = false
+            if ((flags & FLAG_LAX_PRECISION) != 0) {
+              while (! stop ) {
+                val value = input.readFloat()
+                if ( value == 0.0 ) {
+                  stop = true
+                } else {
+                  val idx = input.readInt(true)
+                  v(idx) = value
+                }
+              }
+            } else {
+              while (! stop ) {
+                val value = input.readDouble()
+                if (value == 0.0) {
+                  stop = true
+                } else {
+                  val idx = input.readInt(true)
+                  v(idx) = value
+                }
+              }
+            }
+            v
+        }
+    }
+
+    name.map{name =>
+
+      trace(s"Recovering named vector's name ${name}.")
+
+      new NamedVector(vec, name)
+    }
+      .getOrElse(vec)
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala
----------------------------------------------------------------------
diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala
index 02f6b8c..330ae38 100644
--- a/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala
+++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala
@@ -17,27 +17,27 @@
 
 package org.apache.mahout
 
-import org.apache.mahout.drivers.TextDelimitedIndexedDatasetReader
-import org.apache.mahout.math.indexeddataset.Schema
-import org.apache.mahout.sparkbindings.indexeddataset.IndexedDatasetSpark
-import org.apache.spark.{SparkConf, SparkContext}
 import java.io._
-import scala.collection.mutable.ArrayBuffer
-import org.apache.mahout.common.IOUtils
-import org.apache.log4j.Logger
+
+import org.apache.mahout.logging._
 import org.apache.mahout.math.drm._
-import scala.reflect.ClassTag
-import org.apache.mahout.sparkbindings.drm.{DrmRddInput, SparkBCast, CheckpointedDrmSparkOps, CheckpointedDrmSpark}
-import org.apache.spark.rdd.RDD
+import org.apache.mahout.math.{MatrixWritable, VectorWritable, Matrix, Vector}
+import org.apache.mahout.sparkbindings.drm.{CheckpointedDrmSpark, CheckpointedDrmSparkOps, SparkBCast}
+import org.apache.mahout.util.IOUtilsScala
 import org.apache.spark.broadcast.Broadcast
-import org.apache.mahout.math.{VectorWritable, Vector, MatrixWritable, Matrix}
-import org.apache.hadoop.io.Writable
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.rdd.RDD
+import org.apache.spark.{SparkConf, SparkContext}
+
+import collection._
+import collection.generic.Growable
+import scala.reflect.ClassTag
+
+
 
 /** Public api for Spark-specific operators */
 package object sparkbindings {
 
-  private[sparkbindings] val log = Logger.getLogger("org.apache.mahout.sparkbindings")
+  private final implicit val log = getLog(`package`.getClass)
 
   /** Row-wise organized DRM rdd type */
   type DrmRdd[K] = RDD[DrmTuple[K]]
@@ -55,15 +55,11 @@ package object sparkbindings {
    * @param customJars
    * @return
    */
-  def mahoutSparkContext(
-      masterUrl: String,
-      appName: String,
-      customJars: TraversableOnce[String] = Nil,
-      sparkConf: SparkConf = new SparkConf(),
-      addMahoutJars: Boolean = true
-      ): SparkDistributedContext = {
+  def mahoutSparkContext(masterUrl: String, appName: String, customJars: 
TraversableOnce[String] = Nil,
+                         sparkConf: SparkConf = new SparkConf(), 
addMahoutJars: Boolean = true):
+  SparkDistributedContext = {
 
-    val closeables = new java.util.ArrayDeque[Closeable]()
+    val closeables = mutable.ListBuffer.empty[Closeable]
 
     try {
 
@@ -84,9 +80,9 @@ package object sparkbindings {
         sparkConf.setJars(customJars.toSeq)
       }
 
-      sparkConf.setAppName(appName).setMaster(masterUrl)
-          .set("spark.serializer", 
"org.apache.spark.serializer.KryoSerializer")
-          .set("spark.kryo.registrator", 
"org.apache.mahout.sparkbindings.io.MahoutKryoRegistrator")
+      
sparkConf.setAppName(appName).setMaster(masterUrl).set("spark.serializer",
+        
"org.apache.spark.serializer.KryoSerializer").set("spark.kryo.registrator",
+          "org.apache.mahout.sparkbindings.io.MahoutKryoRegistrator")
 
       if (System.getenv("SPARK_HOME") != null) {
         sparkConf.setSparkHome(System.getenv("SPARK_HOME"))
@@ -95,7 +91,7 @@ package object sparkbindings {
       new SparkDistributedContext(new SparkContext(config = sparkConf))
 
     } finally {
-      IOUtils.close(closeables)
+      IOUtilsScala.close(closeables)
     }
   }
 
@@ -103,19 +99,19 @@ package object sparkbindings {
 
   implicit def sc2sdc(sc: SparkContext): SparkDistributedContext = new 
SparkDistributedContext(sc)
 
-  implicit def dc2sc(dc:DistributedContext):SparkContext = {
-    assert (dc.isInstanceOf[SparkDistributedContext],"distributed context must 
be Spark-specific.")
+  implicit def dc2sc(dc: DistributedContext): SparkContext = {
+    assert(dc.isInstanceOf[SparkDistributedContext], "distributed context must 
be Spark-specific.")
     sdc2sc(dc.asInstanceOf[SparkDistributedContext])
   }
 
   /** Broadcast transforms */
-  implicit def sb2bc[T](b:Broadcast[T]):BCast[T] = new SparkBCast(b)
+  implicit def sb2bc[T](b: Broadcast[T]): BCast[T] = new SparkBCast(b)
 
   /** Adding Spark-specific ops */
   implicit def cpDrm2cpDrmSparkOps[K: ClassTag](drm: CheckpointedDrm[K]): 
CheckpointedDrmSparkOps[K] =
     new CheckpointedDrmSparkOps[K](drm)
 
-  implicit def 
drm2cpDrmSparkOps[K:ClassTag](drm:DrmLike[K]):CheckpointedDrmSparkOps[K] = 
drm:CheckpointedDrm[K]
+  implicit def drm2cpDrmSparkOps[K: ClassTag](drm: DrmLike[K]): 
CheckpointedDrmSparkOps[K] = drm: CheckpointedDrm[K]
 
   private[sparkbindings] implicit def m2w(m: Matrix): MatrixWritable = new 
MatrixWritable(m)
 
@@ -123,7 +119,7 @@ package object sparkbindings {
 
   private[sparkbindings] implicit def v2w(v: Vector): VectorWritable = new 
VectorWritable(v)
 
-  private[sparkbindings] implicit def w2v(w:VectorWritable):Vector = w.get()
+  private[sparkbindings] implicit def w2v(w: VectorWritable): Vector = w.get()
 
   /**
    * ==Wrap existing RDD into a matrix==
@@ -141,34 +137,31 @@ package object sparkbindings {
    * @tparam K row key type
    * @return wrapped DRM
    */
-  def drmWrap[K: ClassTag](
-      rdd: DrmRdd[K],
-      nrow: Int = -1,
-      ncol: Int = -1,
-      cacheHint: CacheHint.CacheHint = CacheHint.NONE,
-      canHaveMissingRows: Boolean = false
-      ): CheckpointedDrm[K] =
-
-    new CheckpointedDrmSpark[K](
-      rdd = rdd,
-      _nrow = nrow,
-      _ncol = ncol,
-      _cacheStorageLevel = SparkEngine.cacheHint2Spark(cacheHint),
-      _canHaveMissingRows = canHaveMissingRows
-    )
+  def drmWrap[K: ClassTag](
+      rdd: DrmRdd[K],
+      nrow: Long = -1,
+      ncol: Int = -1,
+      cacheHint: CacheHint.CacheHint = CacheHint.NONE,
+      canHaveMissingRows: Boolean = false): CheckpointedDrm[K] = {
+    // The engine-neutral cache hint is mapped to a Spark storage level before wrapping.
+    val storageLevel = SparkEngine.cacheHint2Spark(cacheHint)
+    new CheckpointedDrmSpark[K](
+      rddInput = rdd,
+      _nrow = nrow,
+      _ncol = ncol,
+      _cacheStorageLevel = storageLevel,
+      _canHaveMissingRows = canHaveMissingRows
+    )
+  }
+
+
+  /**
+   * Variant of [[drmWrap]] that takes vertically block-partitioned input to form the matrix.
+   *
+   * The blocks are first turned back into a row-wise DRM RDD with `drm.deblockify`; the result and
+   * all remaining arguments are then handed to [[drmWrap]] unchanged.
+   *
+   * @param blockifiedDrmRdd block-wise organized input
+   * @param nrow number of rows; -1 (the default) leaves it unspecified
+   * @param ncol number of columns; -1 (the default) leaves it unspecified
+   * @param cacheHint cache hint, translated to a Spark storage level by [[drmWrap]]
+   * @param canHaveMissingRows passed through to [[drmWrap]] unchanged
+   * @tparam K row key type
+   * @return wrapped, checkpointed DRM
+   */
+  def drmWrapBlockified[K: ClassTag](
+      blockifiedDrmRdd: BlockifiedDrmRdd[K],
+      nrow: Long = -1,
+      ncol: Int = -1,
+      cacheHint: CacheHint.CacheHint = CacheHint.NONE,
+      canHaveMissingRows: Boolean = false): CheckpointedDrm[K] = {
+    val rowWise = drm.deblockify(blockifiedDrmRdd)
+    drmWrap(rdd = rowWise, nrow = nrow, ncol = ncol, cacheHint = cacheHint,
+      canHaveMissingRows = canHaveMissingRows)
+  }
 
   private[sparkbindings] def getMahoutHome() = {
     var mhome = System.getenv("MAHOUT_HOME")
     if (mhome == null) mhome = System.getProperty("mahout.home")
-    require(mhome != null, "MAHOUT_HOME is required to spawn mahout-based 
spark jobs" )
+    require(mhome != null, "MAHOUT_HOME is required to spawn mahout-based 
spark jobs")
     mhome
   }
 
   /** Acquire proper Mahout jars to be added to task context based on current 
MAHOUT_HOME. */
-  private[sparkbindings] def 
findMahoutContextJars(closeables:java.util.Deque[Closeable]) = {
+  private[sparkbindings] def findMahoutContextJars(closeables: 
Growable[Closeable]) = {
 
     // Figure Mahout classpath using $MAHOUT_HOME/mahout classpath command.
-
     val fmhome = new File(getMahoutHome())
     val bin = new File(fmhome, "bin")
     val exec = new File(bin, "mahout")
@@ -177,26 +170,25 @@ package object sparkbindings {
 
     val p = Runtime.getRuntime.exec(Array(exec.getAbsolutePath, "-spark", 
"classpath"))
 
-    closeables.addFirst(new Closeable {
+    closeables += new Closeable {
       def close() {
         p.destroy()
       }
-    })
+    }
 
     val r = new BufferedReader(new InputStreamReader(p.getInputStream))
-    closeables.addFirst(r)
+    closeables += r
 
     val w = new StringWriter()
-    closeables.addFirst(w)
+    closeables += w
 
     var continue = true;
-    val jars = new ArrayBuffer[String]()
+    val jars = new mutable.ArrayBuffer[String]()
     do {
       val cp = r.readLine()
       if (cp == null)
-        throw new IllegalArgumentException(
-          "Unable to read output from \"mahout -spark classpath\". Is 
SPARK_HOME defined?"
-        )
+        throw new IllegalArgumentException("Unable to read output from 
\"mahout -spark classpath\". Is SPARK_HOME " +
+          "defined?")
 
       val j = cp.split(File.pathSeparatorChar)
       if (j.size > 10) {
@@ -206,8 +198,7 @@ package object sparkbindings {
       }
     } while (continue)
 
-//    jars.foreach(j => log.info(j))
-
+    //    jars.foreach(j => log.info(j))
     // context specific jars
     val mcjars = jars.filter(j =>
       j.matches(".*mahout-math-\\d.*\\.jar") ||
@@ -233,4 +224,13 @@ package object sparkbindings {
     mcjars
   }
 
+  /**
+   * Checks that a blockified DRM RDD holds exactly one block per partition.
+   *
+   * This runs a Spark job: every partition's block iterator is consumed to count its blocks.
+   *
+   * @param rdd blockified RDD to validate
+   * @tparam K row key type
+   * @return true if every partition contains exactly one block (vacuously true when the RDD has no
+   *         partitions); false otherwise, in which case a warning is logged
+   */
+  private[sparkbindings] def validateBlockifiedDrmRdd[K](rdd: BlockifiedDrmRdd[K]): Boolean = {
+    // Each partition must contain exactly one block. `fold` with the identity `true` is used instead
+    // of `reduce`, which throws on an RDD without partitions.
+    val part1Req = rdd.mapPartitions(piter => Iterator(piter.size == 1)).fold(true)(_ && _)
+
+    if (!part1Req) warn("blockified rdd: condition not met: exactly 1 per partition")
+
+    part1Req
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/test/scala/org/apache/mahout/sparkbindings/SparkBindingsSuite.scala
----------------------------------------------------------------------
diff --git 
a/spark/src/test/scala/org/apache/mahout/sparkbindings/SparkBindingsSuite.scala 
b/spark/src/test/scala/org/apache/mahout/sparkbindings/SparkBindingsSuite.scala
index fbc31f3..529d13c 100644
--- 
a/spark/src/test/scala/org/apache/mahout/sparkbindings/SparkBindingsSuite.scala
+++ 
b/spark/src/test/scala/org/apache/mahout/sparkbindings/SparkBindingsSuite.scala
@@ -1,10 +1,12 @@
 package org.apache.mahout.sparkbindings
 
-import org.scalatest.FunSuite
+import java.io.{Closeable, File}
 import java.util
-import java.io.{File, Closeable}
-import org.apache.mahout.common.IOUtils
+
 import org.apache.mahout.sparkbindings.test.DistributedSparkSuite
+import org.apache.mahout.util.IOUtilsScala
+import org.scalatest.FunSuite
+import collection._
 
 /**
  * @author dmitriy
@@ -16,7 +18,7 @@ class SparkBindingsSuite extends FunSuite with 
DistributedSparkSuite {
   // let it to be ignored.
   ignore("context jars") {
     System.setProperty("mahout.home", new 
File("..").getAbsolutePath/*"/home/dmitriy/projects/github/mahout-commits"*/)
-    val closeables = new util.ArrayDeque[Closeable]()
+    val closeables = new mutable.ListBuffer[Closeable]()
     try {
       val mahoutJars = findMahoutContextJars(closeables)
       mahoutJars.foreach {
@@ -26,7 +28,7 @@ class SparkBindingsSuite extends FunSuite with 
DistributedSparkSuite {
       mahoutJars.size should be > 0
       mahoutJars.size shouldBe 4
     } finally {
-      IOUtils.close(closeables)
+      IOUtilsScala.close(closeables)
     }
 
   }

http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/test/scala/org/apache/mahout/sparkbindings/blas/BlasSuite.scala
----------------------------------------------------------------------
diff --git 
a/spark/src/test/scala/org/apache/mahout/sparkbindings/blas/BlasSuite.scala 
b/spark/src/test/scala/org/apache/mahout/sparkbindings/blas/BlasSuite.scala
index 1521cb8..8c8ac3f 100644
--- a/spark/src/test/scala/org/apache/mahout/sparkbindings/blas/BlasSuite.scala
+++ b/spark/src/test/scala/org/apache/mahout/sparkbindings/blas/BlasSuite.scala
@@ -26,7 +26,7 @@ import scalabindings._
 import RLikeOps._
 import drm._
 import org.apache.mahout.sparkbindings._
-import org.apache.mahout.sparkbindings.drm.CheckpointedDrmSpark
+import org.apache.mahout.sparkbindings.drm._
 import org.apache.mahout.math.drm.logical.{OpAt, OpAtA, OpAewB, OpABt}
 import org.apache.mahout.sparkbindings.test.DistributedSparkSuite
 
@@ -142,7 +142,7 @@ class BlasSuite extends FunSuite with DistributedSparkSuite 
{
     val drmA = drmParallelize(m = inCoreA, numPartitions = 2)
 
     val op = new OpAt(drmA)
-    val drmAt = new CheckpointedDrmSpark(rdd = At.at(op, srcA = drmA), _nrow = 
op.nrow, _ncol = op.ncol)
+    val drmAt = new CheckpointedDrmSpark(rddInput = At.at(op, srcA = drmA), 
_nrow = op.nrow, _ncol = op.ncol)
     val inCoreAt = drmAt.collect
     val inCoreControlAt = inCoreA.t
 

http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeOpsSuite.scala
----------------------------------------------------------------------
diff --git 
a/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeOpsSuite.scala
 
b/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeOpsSuite.scala
index 42026ae..7241660 100644
--- 
a/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeOpsSuite.scala
+++ 
b/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeOpsSuite.scala
@@ -23,13 +23,14 @@ import drm._
 import RLikeOps._
 import RLikeDrmOps._
 import org.apache.mahout.sparkbindings._
-import org.scalatest.FunSuite
+import org.scalatest.{ConfigMap, BeforeAndAfterAllConfigMap, FunSuite}
 import org.apache.mahout.sparkbindings.test.DistributedSparkSuite
 
+import scala.reflect.ClassTag
+
 /** Tests for DrmLikeOps */
 class DrmLikeOpsSuite extends FunSuite with DistributedSparkSuite with 
DrmLikeOpsSuiteBase {
 
-
   test("exact, min and auto ||") {
     val inCoreA = dense((1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6))
     val A = drmParallelize(m = inCoreA, numPartitions = 2)
@@ -39,18 +40,20 @@ class DrmLikeOpsSuite extends FunSuite with 
DistributedSparkSuite with DrmLikeOp
     (A + 1.0).par(exact = 4).rdd.partitions.size should equal(4)
     A.par(exact = 2).rdd.partitions.size should equal(2)
     A.par(exact = 1).rdd.partitions.size should equal(1)
-    A.par(exact = 0).rdd.partitions.size should equal(2) // No effect for par 
<= 0
+
     A.par(min = 4).rdd.partitions.size should equal(4)
     A.par(min = 2).rdd.partitions.size should equal(2)
     A.par(min = 1).rdd.partitions.size should equal(2)
     A.par(auto = true).rdd.partitions.size should equal(10)
     A.par(exact = 10).par(auto = true).rdd.partitions.size should equal(10)
     A.par(exact = 11).par(auto = true).rdd.partitions.size should equal(19)
-    A.par(exact = 20).par(auto = true).rdd.partitions.size should equal(20)
+    A.par(exact = 20).par(auto = true).rdd.partitions.size should equal(19)
+
+    A.keyClassTag shouldBe ClassTag.Int
+    A.par(auto = true).keyClassTag shouldBe ClassTag.Int
 
-    intercept[AssertionError] {
-      A.par()
-    }
+    an[IllegalArgumentException] shouldBe thrownBy {A.par(exact = 0)}
+    an[IllegalArgumentException] shouldBe thrownBy {A.par()}
   }
 
 }

http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/RLikeDrmOpsSuite.scala
----------------------------------------------------------------------
diff --git 
a/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/RLikeDrmOpsSuite.scala
 
b/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/RLikeDrmOpsSuite.scala
index 2a4f213..f422f86 100644
--- 
a/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/RLikeDrmOpsSuite.scala
+++ 
b/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/RLikeDrmOpsSuite.scala
@@ -25,10 +25,16 @@ import drm._
 import org.apache.mahout.sparkbindings._
 import RLikeDrmOps._
 import test.DistributedSparkSuite
+import org.apache.mahout.math.drm.logical.{OpAtB, OpAewUnaryFuncFusion}
+import org.apache.mahout.logging._
+
+import scala.util.Random
 
 /** ==R-like DRM DSL operation tests -- Spark== */
 class RLikeDrmOpsSuite extends FunSuite with DistributedSparkSuite with 
RLikeDrmOpsSuiteBase {
 
+  private final implicit val log = getLog(classOf[RLikeDrmOpsSuite])
+
   test("C = A + B missing rows") {
     val sc = mahoutCtx.asInstanceOf[SparkDistributedContext].sc
 
@@ -113,4 +119,61 @@ class RLikeDrmOpsSuite extends FunSuite with 
DistributedSparkSuite with RLikeDrm
 
   }
 
+  test("A'B, bigger") {
+
+    val rnd = new Random()
+    val a = new SparseRowMatrix(200, 1544) := { _ => rnd.nextGaussian() }
+    val b = new SparseRowMatrix(200, 300) := { _ => rnd.nextGaussian() }
+
+    // In-core reference product, timed for comparison with the distributed run.
+    val inCoreStart = System.currentTimeMillis()
+    val atb = a.t %*% b
+    val inCoreMs = System.currentTimeMillis() - inCoreStart
+
+    println(s"in-core mul ms: $inCoreMs")
+
+    val drmA = drmParallelize(a, numPartitions = 2)
+    val drmB = drmParallelize(b, numPartitions = 2)
+
+    // Distributed product; the measured interval includes `collect` so the result is actually produced.
+    val distStart = System.currentTimeMillis()
+    val drmAtB = drmA.t %*% drmB
+    val mxAtB = drmAtB.collect
+    val distMs = System.currentTimeMillis() - distStart
+
+    println(s"a'b plan:${drmAtB.context.engine.optimizerRewrite(drmAtB)}")
+    println(s"a'b plan contains ${drmAtB.rdd.partitions.size} partitions.")
+    println(s"distributed mul ms: $distMs.")
+
+    // Distributed and in-core results must agree up to floating-point noise.
+    (atb - mxAtB).norm should be < 1e-5
+
+  }
+
+  test("C = At %*% B , zippable") {
+
+    val mxA = dense((1, 2), (3, 4), (-3, -5))
+
+    // Re-key the rows with String keys, so the product runs on a non-Int-keyed DRM.
+    val A = drmParallelize(mxA, numPartitions = 2).mapBlock() {
+      case (keys, block) => keys.map(_.toString) -> block
+    }
+
+    // B is derived from A (presumably why the pair is "zippable"); its blocks are copied into sparse rows.
+    val B = (A + 1.0).mapBlock() {
+      case (keys, block) =>
+        val sparseBlock = new SparseRowMatrix(block.nrow, block.ncol) := block
+        keys -> sparseBlock
+    }
+
+    // NOTE(review): result discarded; appears to be a smoke check that B materializes.
+    B.collect
+
+    val C = A.t %*% B
+
+    // The optimizer should rewrite A' %*% B into the fused A'B operator.
+    mahoutCtx.optimizerRewrite(C) should equal(OpAtB[String](A, B))
+
+    val inCoreC = C.collect
+    val inCoreControlC = mxA.t %*% (mxA + 1.0)
+
+    (inCoreC - inCoreControlC).norm should be < 1E-10
+
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/test/scala/org/apache/mahout/sparkbindings/io/IOSuite.scala
----------------------------------------------------------------------
diff --git 
a/spark/src/test/scala/org/apache/mahout/sparkbindings/io/IOSuite.scala 
b/spark/src/test/scala/org/apache/mahout/sparkbindings/io/IOSuite.scala
new file mode 100644
index 0000000..f3a9721
--- /dev/null
+++ b/spark/src/test/scala/org/apache/mahout/sparkbindings/io/IOSuite.scala
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.sparkbindings.io
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
+import com.esotericsoftware.kryo.Kryo
+import com.esotericsoftware.kryo.io.{Input, Output}
+import com.twitter.chill.AllScalaRegistrar
+import org.apache.mahout.math._
+import scalabindings._
+import RLikeOps._
+
+import org.apache.mahout.common.RandomUtils
+import org.apache.mahout.test.MahoutSuite
+import org.scalatest.FunSuite
+
+import scala.util.Random
+
+/**
+ * Kryo round-trip tests for the Mahout vector and matrix serializers registered by
+ * [[MahoutKryoRegistrator]].
+ *
+ * Each test serializes objects with [[IOSuite.kryoClone]], then checks that the copies are equal
+ * to the originals and keep the expected concrete class.
+ */
+class IOSuite extends FunSuite with MahoutSuite {
+
+  import IOSuite._
+
+  test("Dense vector kryo") {
+
+    val rnd = RandomUtils.getRandom
+    val vec = new DenseVector(165) := { _ => rnd.nextDouble()}
+
+    // The same instance is written three times; the last copy is checked.
+    val ret = kryoClone(vec, vec, vec)
+    val vec2 = ret(2)
+
+    println(s"vec=$vec\nvc2=$vec2")
+
+    vec2 === vec shouldBe true
+    vec2.isInstanceOf[DenseVector] shouldBe true
+  }
+
+  test("Random sparse vector kryo") {
+
+    val rnd = RandomUtils.getRandom
+    // Roughly 30% of the elements end up non-zero.
+    val vec = new RandomAccessSparseVector(165) := { _ => if (rnd.nextDouble() < 0.3) rnd.nextDouble() else 0}
+
+    // Index 3 is set and then reset to zero before serialization.
+    val vec1 = new RandomAccessSparseVector(165)
+    vec1(2) = 2
+    vec1(3) = 4
+    vec1(3) = 0
+    vec1(10) = 30
+
+    val ret = kryoClone(vec, vec1, vec)
+    val (vec2, vec3) = (ret(2), ret(1))
+
+    println(s"vec=$vec\nvc2=$vec2")
+
+    vec2 === vec shouldBe true
+    vec1 === vec3 shouldBe true
+    vec2.isInstanceOf[RandomAccessSparseVector] shouldBe true
+
+  }
+
+  test("100% sparse vectors") {
+
+    // Vectors without any non-zero element, in both sparse layouts.
+    val vec1 = new SequentialAccessSparseVector(10)
+    val vec2 = new RandomAccessSparseVector(6)
+    val ret = kryoClone(vec1, vec2, vec1, vec2)
+    val vec3 = ret(2)
+    val vec4 = ret(3)
+
+    vec1 === vec3 shouldBe true
+    vec2 === vec4 shouldBe true
+  }
+
+  test("Sequential sparse vector kryo") {
+
+    val rnd = RandomUtils.getRandom
+    val vec = new SequentialAccessSparseVector(165) := { _ => if (rnd.nextDouble() < 0.3) rnd.nextDouble() else 0}
+
+    // Explicit zero writes (index 2, and index 3 after being set) before serialization.
+    val vec1 = new SequentialAccessSparseVector(165)
+    vec1(2) = 0
+    vec1(3) = 3
+    vec1(4) = 2
+    vec1(3) = 0
+
+    val ret = kryoClone(vec, vec1, vec)
+    val (vec2, vec3) = (ret(2), ret(1))
+
+    println(s"vec=$vec\nvc2=$vec2")
+
+    vec2 === vec shouldBe true
+    vec1 === vec3 shouldBe true
+    vec2.isInstanceOf[SequentialAccessSparseVector] shouldBe true
+  }
+
+  test("kryo matrix tests") {
+    // Use RandomUtils like the other tests instead of an unseeded scala.util.Random, so failures can
+    // be reproduced (assumes the Mahout test harness seeds RandomUtils -- confirm in MahoutSuite).
+    val rnd = RandomUtils.getRandom
+
+    val mxA = new DenseMatrix(140, 150) := { _ => rnd.nextDouble()}
+
+    val mxB = new SparseRowMatrix(140, 150) := { _ => if (rnd.nextDouble() < .3) rnd.nextDouble() else 0.0}
+
+    // Only about 30% of mxC's rows are populated; its slice count is checked after the round trip.
+    val mxC = new SparseMatrix(140, 150)
+    for (i <- 0 until mxC.nrow) {
+      if (rnd.nextDouble() < .3) {
+        mxC(i, ::) := { _ => if (rnd.nextDouble() < .3) rnd.nextDouble() else 0.0}
+      }
+    }
+
+    val cnsl = mxC.numSlices()
+    println(s"Number of slices in mxC: ${cnsl}")
+
+    // Each matrix is serialized both as-is and as its transposed view.
+    val ret = kryoClone(mxA, mxA.t, mxB, mxB.t, mxC, mxC.t, mxA)
+
+    val (mxAA, mxAAt, mxBB, mxBBt, mxCC, mxCCt, mxAAA) = (ret(0), ret(1), ret(2), ret(3), ret(4), ret(5), ret(6))
+
+    // ret.size shouldBe 7
+
+    mxA === mxAA shouldBe true
+    mxA === mxAAA shouldBe true
+    mxA === mxAAt.t shouldBe true
+    mxAA.isInstanceOf[DenseMatrix] shouldBe true
+    mxAAt.isInstanceOf[DenseMatrix] shouldBe false
+
+
+    mxB === mxBB shouldBe true
+    mxB === mxBBt.t shouldBe true
+    mxBB.isInstanceOf[SparseRowMatrix] shouldBe true
+    mxBBt.isInstanceOf[SparseRowMatrix] shouldBe false
+    mxBB(0,::).isDense shouldBe false
+
+
+    // Assert no persistence operation increased slice sparsity
+    mxC.numSlices() shouldBe cnsl
+
+    // Assert deserialized product did not experience any empty slice inflation
+    mxCC.numSlices() shouldBe cnsl
+    mxCCt.t.numSlices() shouldBe cnsl
+
+    // Incidentally, but not very significantly, iterating through all rows (which happens in the
+    // equivalence operator) inserts empty rows into SparseMatrix, so these asserts must not come
+    // before the numSlices asserts.
+    mxC === mxCC shouldBe true
+    mxC === mxCCt.t shouldBe true
+    mxCCt.t.isInstanceOf[SparseMatrix] shouldBe true
+
+    // Column-wise sparse matrices are deprecated and should be explicitly rejected by the serializer.
+    an[IllegalArgumentException] should be thrownBy {
+      val mxDeprecated = new SparseColumnMatrix(14, 15)
+      kryoClone(mxDeprecated)
+    }
+
+  }
+
+  test("diag matrix") {
+
+    val mxD = diagv(dvec(1, 2, 3, 5))
+    val mxDD = kryoClone(mxD)(0)
+    mxD === mxDD shouldBe true
+    mxDD.isInstanceOf[DiagonalMatrix] shouldBe true
+
+  }
+}
+
+object IOSuite {
+
+  /**
+   * Round-trips objects through a Kryo instance configured with chill's Scala registrations and
+   * Mahout's [[MahoutKryoRegistrator]].
+   *
+   * All objects are written to one in-memory buffer with `writeClassAndObject`. They are then read
+   * back eagerly, in order, until the buffer is exhausted. Any serialization or deserialization
+   * error therefore propagates from this call, not from later element access.
+   *
+   * @param obj objects to serialize, in order
+   * @tparam T static type the deserialized objects are cast to
+   * @return the deserialized copies, in the same order as `obj`
+   */
+  def kryoClone[T](obj: T*): Seq[T] = {
+
+    val kryo = new Kryo()
+    new AllScalaRegistrar()(kryo)
+
+    MahoutKryoRegistrator.registerClasses(kryo)
+
+    val baos = new ByteArrayOutputStream()
+    val output = new Output(baos)
+    try obj.foreach(kryo.writeClassAndObject(output, _))
+    finally output.close()
+
+    // Materialize the copies immediately so the input can be closed before returning.
+    val input = new Input(new ByteArrayInputStream(baos.toByteArray))
+    try {
+      val copies = List.newBuilder[T]
+      while (!input.eof()) copies += kryo.readClassAndObject(input).asInstanceOf[T]
+      copies.result()
+    } finally input.close()
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/test/scala/org/apache/mahout/sparkbindings/test/DistributedSparkSuite.scala
----------------------------------------------------------------------
diff --git 
a/spark/src/test/scala/org/apache/mahout/sparkbindings/test/DistributedSparkSuite.scala
 
b/spark/src/test/scala/org/apache/mahout/sparkbindings/test/DistributedSparkSuite.scala
index f18ec70..d917a22 100644
--- 
a/spark/src/test/scala/org/apache/mahout/sparkbindings/test/DistributedSparkSuite.scala
+++ 
b/spark/src/test/scala/org/apache/mahout/sparkbindings/test/DistributedSparkSuite.scala
@@ -17,11 +17,13 @@
 
 package org.apache.mahout.sparkbindings.test
 
+import org.apache.log4j.{Level, Logger}
 import org.scalatest.{ConfigMap, BeforeAndAfterAllConfigMap, Suite}
 import org.apache.spark.SparkConf
 import org.apache.mahout.sparkbindings._
 import org.apache.mahout.test.{DistributedMahoutSuite, MahoutSuite}
 import org.apache.mahout.math.drm.DistributedContext
+import collection.JavaConversions._
 
 trait DistributedSparkSuite extends DistributedMahoutSuite with 
LoggerConfiguration {
   this: Suite =>
@@ -30,16 +32,21 @@ trait DistributedSparkSuite extends DistributedMahoutSuite 
with LoggerConfigurat
   protected var masterUrl = null.asInstanceOf[String]
 
   protected def initContext() {
-    masterUrl = "local[3]"
+    masterUrl = System.getProperty("test.spark.master", "local[3]") // override via -Dtest.spark.master=<url>
+    val isLocal = masterUrl.startsWith("local")
     mahoutCtx = mahoutSparkContext(masterUrl = this.masterUrl,
-      appName = "MahoutLocalContext",
+      appName = "MahoutUnitTests",
-      // Do not run MAHOUT_HOME jars in unit tests.
+      // Add MAHOUT_HOME jars only when running against a non-local master.
-      addMahoutJars = false,
+      addMahoutJars = !isLocal,
       sparkConf = new SparkConf()
-          .set("spark.kryoserializer.buffer.mb", "15")
+          .set("spark.kryoserializer.buffer.mb", "40")
           .set("spark.akka.frameSize", "30")
           .set("spark.default.parallelism", "10")
+          .set("spark.executor.memory", "2G")
     )
+    // Spark reconfigures logging. Clamp down on it in tests.
+    Logger.getRootLogger.setLevel(Level.ERROR)
+    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
   }
 
   protected def resetContext() {

http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/test/scala/org/apache/mahout/sparkbindings/test/LoggerConfiguration.scala
----------------------------------------------------------------------
diff --git 
a/spark/src/test/scala/org/apache/mahout/sparkbindings/test/LoggerConfiguration.scala
 
b/spark/src/test/scala/org/apache/mahout/sparkbindings/test/LoggerConfiguration.scala
index e48e7c7..2a996d7 100644
--- 
a/spark/src/test/scala/org/apache/mahout/sparkbindings/test/LoggerConfiguration.scala
+++ 
b/spark/src/test/scala/org/apache/mahout/sparkbindings/test/LoggerConfiguration.scala
@@ -25,6 +25,6 @@ trait LoggerConfiguration extends 
org.apache.mahout.test.LoggerConfiguration {
 
   override protected def beforeAll(configMap: ConfigMap) {
     super.beforeAll(configMap)
-    Logger.getLogger("org.apache.mahout.sparkbindings").setLevel(Level.INFO)
+    BasicConfigurator.resetConfiguration()
   }
 }

Reply via email to