http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOps.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOps.scala b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOps.scala index 97e06cf..7091c53 100644 --- a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOps.scala +++ b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOps.scala @@ -16,18 +16,28 @@ */ package org.apache.mahout.math.scalabindings +import org.apache.mahout.math.function.Functions import org.apache.mahout.math.{Vector, Matrix} import scala.collection.JavaConversions._ import RLikeOps._ class RLikeMatrixOps(m: Matrix) extends MatrixOps(m) { + /** Structure-optimized mmul */ + def %*%(that: Matrix) = MMul(m, that, None) + + def :%*%(that:Matrix) = %*%(that) + + def %*%:(that: Matrix) = that :%*% m + /** - * matrix-matrix multiplication - * @param that - * @return + * The "legacy" matrix-matrix multiplication. + * + * @param that right hand operand + * @return matrix multiplication result + * @deprecated use %*% */ - def %*%(that: Matrix) = m.times(that) + def %***%(that: Matrix) = m.times(that) /** * matrix-vector multiplication @@ -65,13 +75,16 @@ class RLikeMatrixOps(m: Matrix) extends MatrixOps(m) { * @param that */ def *=(that: Matrix) = { - m.zip(that).foreach(t => t._1.vector *= t._2.vector) + m.assign(that, Functions.MULT) m } + /** A *=: B is equivalent to B *= A. Included for completeness. */ + def *=:(that: Matrix) = m *= that + /** Elementwise deletion */ def /=(that: Matrix) = { - m.zip(that).foreach(t => t._1.vector() /= t._2.vector) + m.zip(that).foreach(t ⇒ t._1.vector() /= t._2.vector) m } @@ -80,15 +93,63 @@ class RLikeMatrixOps(m: Matrix) extends MatrixOps(m) { m } + /** 5.0 *=: A is equivalent to A *= 5.0. 
Included for completeness. */ + def *=:(that: Double) = m *= that + def /=(that: Double) = { - m.foreach(_.vector() /= that) + m ::= { x ⇒ x / that } m } /** 1.0 /=: A is equivalent to A = 1.0/A in R */ def /=:(that: Double) = { - m.foreach(that /=: _.vector()) + if (that != 0.0) m := { x ⇒ that / x } m } + + def ^=(that: Double) = { + m ::= { x ⇒ math.pow(x, that) } + m + } + + def ^(that: Double) = m.cloned ^= that + + def cbind(that: Matrix): Matrix = { + require(m.nrow == that.nrow) + if (m.ncol > 0) { + if (that.ncol > 0) { + val mx = m.like(m.nrow, m.ncol + that.ncol) + mx(::, 0 until m.ncol) := m + mx(::, m.ncol until mx.ncol) := that + mx + } else m + } else that + } + + def cbind(that: Double): Matrix = { + val mx = m.like(m.nrow, m.ncol + 1) + mx(::, 0 until m.ncol) := m + if (that != 0.0) mx(::, m.ncol) := that + mx + } + + def rbind(that: Matrix): Matrix = { + require(m.ncol == that.ncol) + if (m.nrow > 0) { + if (that.nrow > 0) { + val mx = m.like(m.nrow + that.nrow, m.ncol) + mx(0 until m.nrow, ::) := m + mx(m.nrow until mx.nrow, ::) := that + mx + } else m + } else that + } + + def rbind(that: Double): Matrix = { + val mx = m.like(m.nrow + 1, m.ncol) + mx(0 until m.nrow, ::) := m + if (that != 0.0) mx(m.nrow, ::) := that + mx + } }
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeOps.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeOps.scala b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeOps.scala index ba32304..e10a01b 100644 --- a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeOps.scala +++ b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeOps.scala @@ -24,13 +24,13 @@ import org.apache.mahout.math.{Vector, MatrixTimesOps, Matrix} */ object RLikeOps { - implicit def double2Scalar(x:Double) = new DoubleScalarOps(x) + implicit def double2Scalar(x:Double) = new RLikeDoubleScalarOps(x) implicit def v2vOps(v: Vector) = new RLikeVectorOps(v) implicit def el2elOps(el: Vector.Element) = new ElementOps(el) - implicit def times2timesOps(m: MatrixTimesOps) = new RLikeTimesOps(m) + implicit def el2Double(el:Vector.Element) = el.get() implicit def m2mOps(m: Matrix) = new RLikeMatrixOps(m) http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeTimesOps.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeTimesOps.scala b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeTimesOps.scala deleted file mode 100644 index 51f0f63..0000000 --- a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeTimesOps.scala +++ /dev/null @@ -1,28 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.math.scalabindings - -import org.apache.mahout.math.{Matrix, MatrixTimesOps} - -class RLikeTimesOps(m: MatrixTimesOps) { - - def :%*%(that: Matrix) = m.timesRight(that) - - def %*%:(that: Matrix) = m.timesLeft(that) - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeVectorOps.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeVectorOps.scala b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeVectorOps.scala index d2198bd..38a55d6 100644 --- a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeVectorOps.scala +++ b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeVectorOps.scala @@ -17,7 +17,7 @@ package org.apache.mahout.math.scalabindings -import org.apache.mahout.math.Vector +import org.apache.mahout.math.{Matrix, Vector} import org.apache.mahout.math.function.Functions import RLikeOps._ @@ -67,5 +67,32 @@ class RLikeVectorOps(_v: Vector) extends VectorOps(_v) { /** Elementwise right-associative / */ def /:(that: Vector) = that.cloned /= v + def ^=(that: Double) = v.assign(Functions.POW, that) + + def ^=(that: Vector) = v.assign(that, Functions.POW) + + def ^(that: Double) = v.cloned ^= that + + def ^(that: Vector) 
= v.cloned ^= that + + def c(that: Vector) = { + if (v.length > 0) { + if (that.length > 0) { + val cv = v.like(v.length + that.length) + cv(0 until v.length) := cv + cv(v.length until cv.length) := that + cv + } else v + } else that + } + + def c(that: Double) = { + val cv = v.like(v.length + 1) + cv(0 until v.length) := v + cv(v.length) = that + cv + } + + def mean = sum / length } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/VectorOps.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/VectorOps.scala b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/VectorOps.scala index c20354d..ef9c494 100644 --- a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/VectorOps.scala +++ b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/VectorOps.scala @@ -38,8 +38,13 @@ class VectorOps(private[scalabindings] val v: Vector) { def update(r: Range, that: Vector) = apply(r) := that + /** R-like synonyms for java methods on vectors */ def sum = v.zSum() + def min = v.minValue() + + def max = v.maxValue() + def :=(that: Vector): Vector = { // assign op in Mahout requires same @@ -58,11 +63,30 @@ class VectorOps(private[scalabindings] val v: Vector) { def :=(that: Double): Vector = v.assign(that) + /** Functional assignment for a function with index and x */ def :=(f: (Int, Double) => Double): Vector = { for (i <- 0 until length) v(i) = f(i, v(i)) v } + /** Functional assignment for a function with just x (e.g. 
v := math.exp _) */ + def :=(f:(Double)=>Double):Vector = { + for (i <- 0 until length) v(i) = f(v(i)) + v + } + + /** Sparse iteration functional assignment using function receiving index and x */ + def ::=(f: (Int, Double) => Double): Vector = { + for (el <- v.nonZeroes) el := f(el.index, el.get) + v + } + + /** Sparse iteration functional assignment using a function receiving just x */ + def ::=(f: (Double) => Double): Vector = { + for (el <- v.nonZeroes) el := f(el.get) + v + } + def equiv(that: Vector) = length == that.length && v.all.view.zip(that.all).forall(t => t._1.get == t._2.get) @@ -121,21 +145,26 @@ class VectorOps(private[scalabindings] val v: Vector) { } class ElementOps(private[scalabindings] val el: Vector.Element) { + import RLikeOps._ + + def update(v: Double): Double = { el.set(v); v } + + def :=(that: Double) = update(that) - def apply = el.get() + def *(that: Vector.Element): Double = this * that - def update(v: Double) = el.set(v) + def *(that: Vector): Vector = el.get * that - def :=(v: Double) = el.set(v) + def +(that: Vector.Element): Double = this + that - def +(that: Double) = el.get() + that + def +(that: Vector) :Vector = el.get + that - def -(that: Double) = el.get() - that + def /(that: Vector.Element): Double = this / that - def :-(that: Double) = that - el.get() + def /(that:Vector):Vector = el.get / that - def /(that: Double) = el.get() / that + def -(that: Vector.Element): Double = this - that - def :/(that: Double) = that / el.get() + def -(that: Vector) :Vector = el.get - that } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/package.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/package.scala b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/package.scala index 36f5103..20dc9cd 100644 --- 
a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/package.scala +++ b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/package.scala @@ -18,12 +18,15 @@ package org.apache.mahout.math import org.apache.mahout.math.solver.EigenDecomposition +import collection._ +import JavaConversions._ /** * Mahout matrices and vectors' scala syntactic sugar */ package object scalabindings { + // Reserved "ALL" range final val `::`: Range = null @@ -125,7 +128,6 @@ package object scalabindings { val data = for (r <- rows) yield { r match { case n: Number => Array(n.doubleValue()) - case t: Product => t.productIterator.map(_.asInstanceOf[Number].doubleValue()).toArray case t: Vector => Array.tabulate(t.length)(t(_)) case t: Array[Double] => t case t: Iterable[_] => @@ -138,6 +140,7 @@ package object scalabindings { } return m } + case t: Product => t.productIterator.map(_.asInstanceOf[Number].doubleValue()).toArray case t: Array[Array[Double]] => if (rows.size == 1) return new DenseMatrix(t) else @@ -164,7 +167,7 @@ package object scalabindings { * (0,5)::(9,3)::Nil, * (2,3.5)::(7,8)::Nil * ) - * + * * }}} * * @param rows @@ -172,11 +175,18 @@ package object scalabindings { */ def sparse(rows: Vector*): SparseRowMatrix = { - import MatrixOps._ + import RLikeOps._ val nrow = rows.size val ncol = rows.map(_.size()).max val m = new SparseRowMatrix(nrow, ncol) - m := rows + m := rows.map { row => + if (row.length < ncol) { + val newRow = row.like(ncol) + newRow(0 until row.length) := row + newRow + } + else row + } m } @@ -249,23 +259,23 @@ package object scalabindings { (qrdec.getQ, qrdec.getR) } - /** - * Solution <tt>X</tt> of <tt>A*X = B</tt> using QR-Decomposition, where <tt>A</tt> is a square, non-singular matrix. + /** + * Solution <tt>X</tt> of <tt>A*X = B</tt> using QR-Decomposition, where <tt>A</tt> is a square, non-singular matrix. 
* * @param a * @param b * @return (X) */ def solve(a: Matrix, b: Matrix): Matrix = { - import MatrixOps._ - if (a.nrow != a.ncol) { - throw new IllegalArgumentException("supplied matrix A is not square") - } - val qr = new QRDecomposition(a cloned) - if (!qr.hasFullRank) { - throw new IllegalArgumentException("supplied matrix A is singular") - } - qr.solve(b) + import MatrixOps._ + if (a.nrow != a.ncol) { + throw new IllegalArgumentException("supplied matrix A is not square") + } + val qr = new QRDecomposition(a cloned) + if (!qr.hasFullRank) { + throw new IllegalArgumentException("supplied matrix A is singular") + } + qr.solve(b) } /** @@ -293,5 +303,46 @@ package object scalabindings { x(::, 0) } + /////////////////////////////////////////////////////////// + // Elementwise unary functions. Actually this requires creating clones to avoid side effects. For + // efficiency reasons one may want to actually do in-place expression assignments instead, e.g. + // + // m := exp _ + + import RLikeOps._ + import scala.math._ + + def mexp(m: Matrix): Matrix = m.cloned := exp _ + + def vexp(v: Vector): Vector = v.cloned := exp _ + + def mlog(m: Matrix): Matrix = m.cloned := log _ + + def vlog(v: Vector): Vector = v.cloned := log _ + + def mabs(m: Matrix): Matrix = m.cloned ::= (abs(_: Double)) + + def vabs(v: Vector): Vector = v.cloned ::= (abs(_: Double)) + + def msqrt(m: Matrix): Matrix = m.cloned ::= sqrt _ + + def vsqrt(v: Vector): Vector = v.cloned ::= sqrt _ + + def msignum(m: Matrix): Matrix = m.cloned ::= (signum(_: Double)) + + def vsignum(v: Vector): Vector = v.cloned ::= (signum(_: Double)) + + ////////////////////////////////////////////////////////// + // operation funcs + + + /** Matrix-matrix unary func */ + type MMUnaryFunc = (Matrix, Option[Matrix]) => Matrix + /** Binary matrix-matrix operations which may save result in-place, optionally */ + type MMBinaryFunc = (Matrix, Matrix, Option[Matrix]) => Matrix + type MVBinaryFunc = (Matrix, Vector, Option[Matrix]) 
=> Matrix + type VMBinaryFunc = (Vector, Matrix, Option[Matrix]) => Matrix + type MDBinaryFunc = (Matrix, Double, Option[Matrix]) => Matrix + } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/main/scala/org/apache/mahout/util/IOUtilsScala.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/main/scala/org/apache/mahout/util/IOUtilsScala.scala b/math-scala/src/main/scala/org/apache/mahout/util/IOUtilsScala.scala new file mode 100644 index 0000000..b61bea4 --- /dev/null +++ b/math-scala/src/main/scala/org/apache/mahout/util/IOUtilsScala.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.util + +import org.apache.mahout.logging._ +import collection._ +import java.io.Closeable + +object IOUtilsScala { + + private final implicit val log = getLog(IOUtilsScala.getClass) + + /** + * Try to close every resource in the sequence, in order of the sequence. + * + * Report all encountered exceptions to logging. 
+ * + * Rethrow last exception only (if any) + * @param closeables + */ + def close(closeables: Seq[Closeable]) = { + + var lastThr: Option[Throwable] = None + closeables.foreach { c => + try { + c.close() + } catch { + case t: Throwable => + error(t.getMessage, t) + lastThr = Some(t) + } + } + + // Rethrow most recent close exception (can throw only one) + lastThr.foreach(throw _) + } + + /** + * Same as [[IOUtilsScala.close( )]] but do not re-throw any exceptions. + * @param closeables + */ + def closeQuietly(closeables: Seq[Closeable]) = { + try { + close(closeables) + } catch { + case t: Throwable => // NOP + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeOpsSuiteBase.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeOpsSuiteBase.scala b/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeOpsSuiteBase.scala index 849db68..bb42121 100644 --- a/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeOpsSuiteBase.scala +++ b/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeOpsSuiteBase.scala @@ -46,6 +46,26 @@ trait DrmLikeOpsSuiteBase extends DistributedMahoutSuite with Matchers { } + test("allReduceBlock") { + + val mxA = dense((1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)) + val drmA = drmParallelize(mxA, numPartitions = 2) + + try { + val mxB = drmA.allreduceBlock { case (keys, block) ⇒ + block(::, 0 until 2).t %*% block(::, 2 until 3) + } + + val mxControl = mxA(::, 0 until 2).t %*% mxA(::, 2 until 3) + + (mxB - mxControl).norm should be < 1e-10 + + } catch { + case e: UnsupportedOperationException ⇒ // Some engines may not support this, so ignore. 
+ } + + } + test("col range") { val inCoreA = dense((1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)) val A = drmParallelize(m = inCoreA, numPartitions = 2) http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeSuiteBase.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeSuiteBase.scala b/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeSuiteBase.scala index 6c9313c..f215fb7 100644 --- a/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeSuiteBase.scala +++ b/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeSuiteBase.scala @@ -68,9 +68,8 @@ trait DrmLikeSuiteBase extends DistributedMahoutSuite with Matchers { inCoreEmpty.nrow shouldBe 100 inCoreEmpty.ncol shouldBe 50 + } - } - } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/test/scala/org/apache/mahout/math/drm/RLikeDrmOpsSuiteBase.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/math/drm/RLikeDrmOpsSuiteBase.scala b/math-scala/src/test/scala/org/apache/mahout/math/drm/RLikeDrmOpsSuiteBase.scala index 2e6204d..b46ee30 100644 --- a/math-scala/src/test/scala/org/apache/mahout/math/drm/RLikeDrmOpsSuiteBase.scala +++ b/math-scala/src/test/scala/org/apache/mahout/math/drm/RLikeDrmOpsSuiteBase.scala @@ -24,7 +24,13 @@ import scalabindings._ import RLikeOps._ import RLikeDrmOps._ import decompositions._ -import org.apache.mahout.math.drm.logical.{OpAtB, OpAtA, OpAtx} +import org.apache.mahout.math.drm.logical._ +import org.apache.mahout.math.drm.logical.OpAtx +import org.apache.mahout.math.drm.logical.OpAtB +import org.apache.mahout.math.drm.logical.OpAtA +import org.apache.mahout.math.drm.logical.OpAewUnaryFuncFusion + +import scala.util.Random /** Common engine tests for distributed R-like DRM operations */ trait 
RLikeDrmOpsSuiteBase extends DistributedMahoutSuite with Matchers { @@ -188,10 +194,13 @@ trait RLikeDrmOpsSuiteBase extends DistributedMahoutSuite with Matchers { val A = drmParallelize(inCoreA, numPartitions = 2) .mapBlock()({ - case (keys, block) => keys.map(_.toString) -> block + case (keys, block) ⇒ keys.map(_.toString) → block }) - val B = A + 1.0 + // Dense-A' x sparse-B used to produce error. We sparsify B here to test this as well. + val B = (A + 1.0).mapBlock() { case (keys, block) ⇒ + keys → (new SparseRowMatrix(block.nrow, block.ncol) := block) + } val C = A.t %*% B @@ -204,6 +213,25 @@ trait RLikeDrmOpsSuiteBase extends DistributedMahoutSuite with Matchers { } + test ("C = A %*% B.t") { + + val inCoreA = dense((1, 2), (3, 4), (-3, -5)) + + val A = drmParallelize(inCoreA, numPartitions = 2) + + val B = A + 1.0 + + val C = A %*% B.t + + mahoutCtx.optimizerRewrite(C) should equal(OpABt[Int](A, B)) + + val inCoreC = C.collect + val inCoreControlC = inCoreA %*% (inCoreA + 1.0).t + + (inCoreC - inCoreControlC).norm should be < 1E-10 + + } + test("C = A %*% inCoreB") { val inCoreA = dense((1, 2, 3), (3, 4, 5), (4, 5, 6), (5, 6, 7)) @@ -503,6 +531,24 @@ trait RLikeDrmOpsSuiteBase extends DistributedMahoutSuite with Matchers { } + test("B = 1 cbind A") { + val inCoreA = dense((1, 2), (3, 4)) + val control = dense((1, 1, 2), (1, 3, 4)) + + val drmA = drmParallelize(inCoreA, numPartitions = 2) + + (control - (1 cbind drmA) ).norm should be < 1e-10 + } + + test("B = A cbind 1") { + val inCoreA = dense((1, 2), (3, 4)) + val control = dense((1, 2, 1), (3, 4, 1)) + + val drmA = drmParallelize(inCoreA, numPartitions = 2) + + (control - (drmA cbind 1) ).norm should be < 1e-10 + } + test("B = A + 1.0") { val inCoreA = dense((1, 2), (2, 3), (3, 4)) val controlB = inCoreA + 1.0 @@ -547,4 +593,46 @@ trait RLikeDrmOpsSuiteBase extends DistributedMahoutSuite with Matchers { (10 * drmA - (10 *: drmA)).norm shouldBe 0 } + + test("A * A -> sqr(A) rewrite ") { + val mxA = dense( 
+ (1, 2, 3), + (3, 4, 5), + (7, 8, 9) + ) + + val mxAAControl = mxA * mxA + + val drmA = drmParallelize(mxA, 2) + val drmAA = drmA * drmA + + val optimized = drmAA.context.engine.optimizerRewrite(drmAA) + println(s"optimized:$optimized") + optimized.isInstanceOf[OpAewUnaryFunc[Int]] shouldBe true + + (mxAAControl -= drmAA).norm should be < 1e-10 + } + + test("B = 1 + 2 * (A * A) ew unary function fusion") { + val mxA = dense( + (1, 2, 3), + (3, 0, 5) + ) + val controlB = mxA.cloned := { (x) => 1 + 2 * x * x} + + val drmA = drmParallelize(mxA, 2) + + // We need to use parenthesis, otherwise optimizer will see it as (2A) * (A) and that would not + // be rewritten as 2 * sqr(A). It is not that clever (yet) to try commutativity optimizations. + val drmB = 1 + 2 * (drmA * drmA) + + val optimized = mahoutCtx.engine.optimizerRewrite(drmB) + println(s"optimizer rewritten:$optimized") + optimized.isInstanceOf[OpAewUnaryFuncFusion[Int]] shouldBe true + + (controlB - drmB).norm should be < 1e-10 + + } + + } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/MatrixOpsSuite.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/MatrixOpsSuite.scala b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/MatrixOpsSuite.scala index d7b22d9..5c8a310 100644 --- a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/MatrixOpsSuite.scala +++ b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/MatrixOpsSuite.scala @@ -24,6 +24,8 @@ import org.apache.mahout.test.MahoutSuite import org.apache.mahout.math.{RandomAccessSparseVector, SequentialAccessSparseVector, Matrices} import org.apache.mahout.common.RandomUtils +import scala.util.Random + class MatrixOpsSuite extends FunSuite with MahoutSuite { @@ -93,12 +95,40 @@ class MatrixOpsSuite extends FunSuite with MahoutSuite { val e = eye(5) - 
printf("I(5)=\n%s\n", e) + println(s"I(5)=\n$e") a(0 to 1, 1 to 2) = dense((3, 2), (2, 3)) a(0 to 1, 1 to 2) := dense((3, 2), (2, 3)) + println(s"a=$a") + + a(0 to 1, 1 to 2) := { _ => 45} + println(s"a=$a") + +// a(0 to 1, 1 to 2) ::= { _ => 44} + println(s"a=$a") + + // Sparse assignment to a sparse block + val c = sparse(0 -> 1 :: Nil, 2 -> 2 :: Nil, 1 -> 5 :: Nil) + val d = c.cloned + + println(s"d=$d") + d.ncol shouldBe 3 + d(::, 1 to 2) ::= { _ => 4} + println(s"d=$d") + d(::, 1 to 2).sum shouldBe 8 + + d ::= {_ => 5} + d.sum shouldBe 15 + + val f = c.cloned.t + f ::= {_ => 6} + f.sum shouldBe 18 + + val g = c.cloned + g(::, 1 until g.nrow) ::= { x => if (x <= 0) 0.0 else 1.0} + g.sum shouldBe 3 } test("sparse") { @@ -182,4 +212,5 @@ class MatrixOpsSuite extends FunSuite with MahoutSuite { } + } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala index a943c5f..79d2899 100644 --- a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala +++ b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala @@ -17,9 +17,16 @@ package org.apache.mahout.math.scalabindings +import java.util + +import org.apache.log4j.Level +import org.apache.mahout.math._ import org.scalatest.FunSuite import RLikeOps._ import org.apache.mahout.test.MahoutSuite +import org.apache.mahout.logging._ +import scala.collection.JavaConversions._ +import scala.util.Random class RLikeMatrixOpsSuite extends FunSuite with MahoutSuite { @@ -63,6 +70,10 @@ class RLikeMatrixOpsSuite extends FunSuite with MahoutSuite { } + test("Uniform view") { + val mxUnif 
= Matrices.symmetricUniformView(5000000, 5000000, 1234) + } + /** Test dsl overloads over scala operations over matrices */ test ("scalarOps") { val a = dense( @@ -77,4 +88,269 @@ class RLikeMatrixOpsSuite extends FunSuite with MahoutSuite { } + test("Multiplication experimental performance") { + + getLog(MMul.getClass).setLevel(Level.DEBUG) + + val d = 300 + val n = 3 + + // Dense row-wise + val mxAd = new DenseMatrix(d, d) := Matrices.gaussianView(d, d, 134) + 1 + val mxBd = new DenseMatrix(d, d) := Matrices.gaussianView(d, d, 134) - 1 + + val rnd = new Random(1234) + + // Sparse rows + val mxAsr = (new SparseRowMatrix(d, + d) := { _ => if (rnd.nextDouble() < 0.1) rnd.nextGaussian() + 1 else 0.0 }) cloned + val mxBsr = (new SparseRowMatrix(d, + d) := { _ => if (rnd.nextDouble() < 0.1) rnd.nextGaussian() - 1 else 0.0 }) cloned + + // Hanging sparse rows + val mxAs = (new SparseMatrix(d, d) := { _ => if (rnd.nextDouble() < 0.1) rnd.nextGaussian() + 1 else 0.0 }) cloned + val mxBs = (new SparseMatrix(d, d) := { _ => if (rnd.nextDouble() < 0.1) rnd.nextGaussian() - 1 else 0.0 }) cloned + + // DIAGONAL + val mxD = diagv(dvec(Array.tabulate(d)(_ => rnd.nextGaussian()))) + + def time(op: => Unit): Long = { + val ms = System.currentTimeMillis() + op + System.currentTimeMillis() - ms + } + + def getMmulAvgs(mxA: Matrix, mxB: Matrix, n: Int) = { + + var control: Matrix = null + var mmulVal: Matrix = null + + val current = Stream.range(0, n).map { _ => time {control = mxA.times(mxB)} }.sum.toDouble / n + val experimental = Stream.range(0, n).map { _ => time {mmulVal = MMul(mxA, mxB, None)} }.sum.toDouble / n + (control - mmulVal).norm should be < 1e-10 + current -> experimental + } + + // Dense matrix tests. 
+ println(s"Ad %*% Bd: ${getMmulAvgs(mxAd, mxBd, n)}") + println(s"Ad' %*% Bd: ${getMmulAvgs(mxAd.t, mxBd, n)}") + println(s"Ad %*% Bd': ${getMmulAvgs(mxAd, mxBd.t, n)}") + println(s"Ad' %*% Bd': ${getMmulAvgs(mxAd.t, mxBd.t, n)}") + println(s"Ad'' %*% Bd'': ${getMmulAvgs(mxAd.t.t, mxBd.t.t, n)}") + println + + // Sparse row matrix tests. + println(s"Asr %*% Bsr: ${getMmulAvgs(mxAsr, mxBsr, n)}") + println(s"Asr' %*% Bsr: ${getMmulAvgs(mxAsr.t, mxBsr, n)}") + println(s"Asr %*% Bsr': ${getMmulAvgs(mxAsr, mxBsr.t, n)}") + println(s"Asr' %*% Bsr': ${getMmulAvgs(mxAsr.t, mxBsr.t, n)}") + println(s"Asr'' %*% Bsr'': ${getMmulAvgs(mxAsr.t.t, mxBsr.t.t, n)}") + println + + // Sparse matrix tests. + println(s"Asm %*% Bsm: ${getMmulAvgs(mxAs, mxBs, n)}") + println(s"Asm' %*% Bsm: ${getMmulAvgs(mxAs.t, mxBs, n)}") + println(s"Asm %*% Bsm': ${getMmulAvgs(mxAs, mxBs.t, n)}") + println(s"Asm' %*% Bsm': ${getMmulAvgs(mxAs.t, mxBs.t, n)}") + println(s"Asm'' %*% Bsm'': ${getMmulAvgs(mxAs.t.t, mxBs.t.t, n)}") + println + + // Mixed sparse matrix tests. 
+ println(s"Asm %*% Bsr: ${getMmulAvgs(mxAs, mxBsr, n)}") + println(s"Asm' %*% Bsr: ${getMmulAvgs(mxAs.t, mxBsr, n)}") + println(s"Asm %*% Bsr': ${getMmulAvgs(mxAs, mxBsr.t, n)}") + println(s"Asm' %*% Bsr': ${getMmulAvgs(mxAs.t, mxBsr.t, n)}") + println(s"Asm'' %*% Bsr'': ${getMmulAvgs(mxAs.t.t, mxBsr.t.t, n)}") + println + + println(s"Asr %*% Bsm: ${getMmulAvgs(mxAsr, mxBs, n)}") + println(s"Asr' %*% Bsm: ${getMmulAvgs(mxAsr.t, mxBs, n)}") + println(s"Asr %*% Bsm': ${getMmulAvgs(mxAsr, mxBs.t, n)}") + println(s"Asr' %*% Bsm': ${getMmulAvgs(mxAsr.t, mxBs.t, n)}") + println(s"Asr'' %*% Bsm'': ${getMmulAvgs(mxAsr.t.t, mxBs.t.t, n)}") + println + + // Mixed dense/sparse + println(s"Ad %*% Bsr: ${getMmulAvgs(mxAd, mxBsr, n)}") + println(s"Ad' %*% Bsr: ${getMmulAvgs(mxAd.t, mxBsr, n)}") + println(s"Ad %*% Bsr': ${getMmulAvgs(mxAd, mxBsr.t, n)}") + println(s"Ad' %*% Bsr': ${getMmulAvgs(mxAd.t, mxBsr.t, n)}") + println(s"Ad'' %*% Bsr'': ${getMmulAvgs(mxAd.t.t, mxBsr.t.t, n)}") + println + + println(s"Asr %*% Bd: ${getMmulAvgs(mxAsr, mxBd, n)}") + println(s"Asr' %*% Bd: ${getMmulAvgs(mxAsr.t, mxBd, n)}") + println(s"Asr %*% Bd': ${getMmulAvgs(mxAsr, mxBd.t, n)}") + println(s"Asr' %*% Bd': ${getMmulAvgs(mxAsr.t, mxBd.t, n)}") + println(s"Asr'' %*% Bd'': ${getMmulAvgs(mxAsr.t.t, mxBd.t.t, n)}") + println + + println(s"Ad %*% Bsm: ${getMmulAvgs(mxAd, mxBs, n)}") + println(s"Ad' %*% Bsm: ${getMmulAvgs(mxAd.t, mxBs, n)}") + println(s"Ad %*% Bsm': ${getMmulAvgs(mxAd, mxBs.t, n)}") + println(s"Ad' %*% Bsm': ${getMmulAvgs(mxAd.t, mxBs.t, n)}") + println(s"Ad'' %*% Bsm'': ${getMmulAvgs(mxAd.t.t, mxBs.t.t, n)}") + println + + println(s"Asm %*% Bd: ${getMmulAvgs(mxAs, mxBd, n)}") + println(s"Asm' %*% Bd: ${getMmulAvgs(mxAs.t, mxBd, n)}") + println(s"Asm %*% Bd': ${getMmulAvgs(mxAs, mxBd.t, n)}") + println(s"Asm' %*% Bd': ${getMmulAvgs(mxAs.t, mxBd.t, n)}") + println(s"Asm'' %*% Bd'': ${getMmulAvgs(mxAs.t.t, mxBd.t.t, n)}") + println + + // Diagonal cases + println(s"Ad %*% D: 
${getMmulAvgs(mxAd, mxD, n)}") + println(s"Asr %*% D: ${getMmulAvgs(mxAsr, mxD, n)}") + println(s"Asm %*% D: ${getMmulAvgs(mxAs, mxD, n)}") + println(s"D %*% Ad: ${getMmulAvgs(mxD, mxAd, n)}") + println(s"D %*% Asr: ${getMmulAvgs(mxD, mxAsr, n)}") + println(s"D %*% Asm: ${getMmulAvgs(mxD, mxAs, n)}") + println + + println(s"Ad' %*% D: ${getMmulAvgs(mxAd.t, mxD, n)}") + println(s"Asr' %*% D: ${getMmulAvgs(mxAsr.t, mxD, n)}") + println(s"Asm' %*% D: ${getMmulAvgs(mxAs.t, mxD, n)}") + println(s"D %*% Ad': ${getMmulAvgs(mxD, mxAd.t, n)}") + println(s"D %*% Asr': ${getMmulAvgs(mxD, mxAsr.t, n)}") + println(s"D %*% Asm': ${getMmulAvgs(mxD, mxAs.t, n)}") + println + + // Self-squared cases + println(s"Ad %*% Ad': ${getMmulAvgs(mxAd, mxAd.t, n)}") + println(s"Ad' %*% Ad: ${getMmulAvgs(mxAd.t, mxAd, n)}") + println(s"Ad' %*% Ad'': ${getMmulAvgs(mxAd.t, mxAd.t.t, n)}") + println(s"Ad'' %*% Ad': ${getMmulAvgs(mxAd.t.t, mxAd.t, n)}") + + } + + + test("elementwise experimental performance") { + + val d = 500 + val n = 3 + + // Dense row-wise + val mxAd = new DenseMatrix(d, d) := Matrices.gaussianView(d, d, 134) + 1 + val mxBd = new DenseMatrix(d, d) := Matrices.gaussianView(d, d, 134) - 1 + + val rnd = new Random(1234) + + // Sparse rows + val mxAsr = (new SparseRowMatrix(d, + d) := { _ => if (rnd.nextDouble() < 0.1) rnd.nextGaussian() + 1 else 0.0 }) cloned + val mxBsr = (new SparseRowMatrix(d, + d) := { _ => if (rnd.nextDouble() < 0.1) rnd.nextGaussian() - 1 else 0.0 }) cloned + + // Hanging sparse rows + val mxAs = (new SparseMatrix(d, d) := { _ => if (rnd.nextDouble() < 0.1) rnd.nextGaussian() + 1 else 0.0 }) cloned + val mxBs = (new SparseMatrix(d, d) := { _ => if (rnd.nextDouble() < 0.1) rnd.nextGaussian() - 1 else 0.0 }) cloned + + // DIAGONAL + val mxD = diagv(dvec(Array.tabulate(d)(_ => rnd.nextGaussian()))) + + def time(op: => Unit): Long = { + val ms = System.currentTimeMillis() + op + System.currentTimeMillis() - ms + } + + def getEWAvgs(mxA: Matrix, mxB: Matrix, n: 
Int) = { + + var control: Matrix = null + var mmulVal: Matrix = null + + val current = Stream.range(0, n).map { _ => time {control = mxA + mxB} }.sum.toDouble / n + val experimental = Stream.range(0, n).map { _ => time {mmulVal = mxA + mxB} }.sum.toDouble / n + (control - mmulVal).norm should be < 1e-10 + current -> experimental + } + + // Dense matrix tests. + println(s"Ad + Bd: ${getEWAvgs(mxAd, mxBd, n)}") + println(s"Ad' + Bd: ${getEWAvgs(mxAd.t, mxBd, n)}") + println(s"Ad + Bd': ${getEWAvgs(mxAd, mxBd.t, n)}") + println(s"Ad' + Bd': ${getEWAvgs(mxAd.t, mxBd.t, n)}") + println(s"Ad'' + Bd'': ${getEWAvgs(mxAd.t.t, mxBd.t.t, n)}") + println + + // Sparse row matrix tests. + println(s"Asr + Bsr: ${getEWAvgs(mxAsr, mxBsr, n)}") + println(s"Asr' + Bsr: ${getEWAvgs(mxAsr.t, mxBsr, n)}") + println(s"Asr + Bsr': ${getEWAvgs(mxAsr, mxBsr.t, n)}") + println(s"Asr' + Bsr': ${getEWAvgs(mxAsr.t, mxBsr.t, n)}") + println(s"Asr'' + Bsr'': ${getEWAvgs(mxAsr.t.t, mxBsr.t.t, n)}") + println + + // Sparse matrix tests. + println(s"Asm + Bsm: ${getEWAvgs(mxAs, mxBs, n)}") + println(s"Asm' + Bsm: ${getEWAvgs(mxAs.t, mxBs, n)}") + println(s"Asm + Bsm': ${getEWAvgs(mxAs, mxBs.t, n)}") + println(s"Asm' + Bsm': ${getEWAvgs(mxAs.t, mxBs.t, n)}") + println(s"Asm'' + Bsm'': ${getEWAvgs(mxAs.t.t, mxBs.t.t, n)}") + println + + // Mixed sparse matrix tests. 
+ println(s"Asm + Bsr: ${getEWAvgs(mxAs, mxBsr, n)}") + println(s"Asm' + Bsr: ${getEWAvgs(mxAs.t, mxBsr, n)}") + println(s"Asm + Bsr': ${getEWAvgs(mxAs, mxBsr.t, n)}") + println(s"Asm' + Bsr': ${getEWAvgs(mxAs.t, mxBsr.t, n)}") + println(s"Asm'' + Bsr'': ${getEWAvgs(mxAs.t.t, mxBsr.t.t, n)}") + println + + println(s"Asr + Bsm: ${getEWAvgs(mxAsr, mxBs, n)}") + println(s"Asr' + Bsm: ${getEWAvgs(mxAsr.t, mxBs, n)}") + println(s"Asr + Bsm': ${getEWAvgs(mxAsr, mxBs.t, n)}") + println(s"Asr' + Bsm': ${getEWAvgs(mxAsr.t, mxBs.t, n)}") + println(s"Asr'' + Bsm'': ${getEWAvgs(mxAsr.t.t, mxBs.t.t, n)}") + println + + // Mixed dense/sparse + println(s"Ad + Bsr: ${getEWAvgs(mxAd, mxBsr, n)}") + println(s"Ad' + Bsr: ${getEWAvgs(mxAd.t, mxBsr, n)}") + println(s"Ad + Bsr': ${getEWAvgs(mxAd, mxBsr.t, n)}") + println(s"Ad' + Bsr': ${getEWAvgs(mxAd.t, mxBsr.t, n)}") + println(s"Ad'' + Bsr'': ${getEWAvgs(mxAd.t.t, mxBsr.t.t, n)}") + println + + println(s"Asr + Bd: ${getEWAvgs(mxAsr, mxBd, n)}") + println(s"Asr' + Bd: ${getEWAvgs(mxAsr.t, mxBd, n)}") + println(s"Asr + Bd': ${getEWAvgs(mxAsr, mxBd.t, n)}") + println(s"Asr' + Bd': ${getEWAvgs(mxAsr.t, mxBd.t, n)}") + println(s"Asr'' + Bd'': ${getEWAvgs(mxAsr.t.t, mxBd.t.t, n)}") + println + + println(s"Ad + Bsm: ${getEWAvgs(mxAd, mxBs, n)}") + println(s"Ad' + Bsm: ${getEWAvgs(mxAd.t, mxBs, n)}") + println(s"Ad + Bsm': ${getEWAvgs(mxAd, mxBs.t, n)}") + println(s"Ad' + Bsm': ${getEWAvgs(mxAd.t, mxBs.t, n)}") + println(s"Ad'' + Bsm'': ${getEWAvgs(mxAd.t.t, mxBs.t.t, n)}") + println + + println(s"Asm + Bd: ${getEWAvgs(mxAs, mxBd, n)}") + println(s"Asm' + Bd: ${getEWAvgs(mxAs.t, mxBd, n)}") + println(s"Asm + Bd': ${getEWAvgs(mxAs, mxBd.t, n)}") + println(s"Asm' + Bd': ${getEWAvgs(mxAs.t, mxBd.t, n)}") + println(s"Asm'' + Bd'': ${getEWAvgs(mxAs.t.t, mxBd.t.t, n)}") + println + + // Diagonal cases + println(s"Ad + D: ${getEWAvgs(mxAd, mxD, n)}") + println(s"Asr + D: ${getEWAvgs(mxAsr, mxD, n)}") + println(s"Asm + D: ${getEWAvgs(mxAs, mxD, n)}") 
+ println(s"D + Ad: ${getEWAvgs(mxD, mxAd, n)}") + println(s"D + Asr: ${getEWAvgs(mxD, mxAsr, n)}") + println(s"D + Asm: ${getEWAvgs(mxD, mxAs, n)}") + println + + println(s"Ad' + D: ${getEWAvgs(mxAd.t, mxD, n)}") + println(s"Asr' + D: ${getEWAvgs(mxAsr.t, mxD, n)}") + println(s"Asm' + D: ${getEWAvgs(mxAs.t, mxD, n)}") + println(s"D + Ad': ${getEWAvgs(mxD, mxAd.t, n)}") + println(s"D + Asr': ${getEWAvgs(mxD, mxAsr.t, n)}") + println(s"D + Asm': ${getEWAvgs(mxD, mxAs.t, n)}") + println + + } + } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/VectorOpsSuite.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/VectorOpsSuite.scala b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/VectorOpsSuite.scala index 037f562..d264514 100644 --- a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/VectorOpsSuite.scala +++ b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/VectorOpsSuite.scala @@ -18,10 +18,12 @@ package org.apache.mahout.math.scalabindings import org.scalatest.FunSuite -import org.apache.mahout.math.{RandomAccessSparseVector, Vector} +import org.apache.mahout.math.{SequentialAccessSparseVector, RandomAccessSparseVector, Vector} import RLikeOps._ import org.apache.mahout.test.MahoutSuite +import scala.util.Random + /** VectorOps Suite */ class VectorOpsSuite extends FunSuite with MahoutSuite { @@ -79,4 +81,19 @@ class VectorOpsSuite extends FunSuite with MahoutSuite { } + test("sparse assignment") { + + val svec = new SequentialAccessSparseVector(30) + svec(1) = -0.5 + svec(3) = 0.5 + println(svec) + + svec(1 until svec.length) ::= ( _ => 0) + println(svec) + + svec.sum shouldBe 0 + + + } + } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java 
---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java b/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java index e752422..a823d0b 100644 --- a/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java +++ b/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java @@ -19,13 +19,16 @@ package org.apache.mahout.math; import com.google.common.collect.AbstractIterator; import com.google.common.collect.Maps; +import org.apache.mahout.math.flavor.BackEnum; +import org.apache.mahout.math.flavor.MatrixFlavor; +import org.apache.mahout.math.flavor.TraversingStructureEnum; import org.apache.mahout.math.function.*; import java.util.Iterator; import java.util.Map; /** - * A few universal implementations of convenience functions + * A few universal implementations of convenience functions for a JVM-backed matrix. */ public abstract class AbstractMatrix implements Matrix { @@ -57,19 +60,24 @@ public abstract class AbstractMatrix implements Matrix { @Override public Iterator<MatrixSlice> iterateAll() { return new AbstractIterator<MatrixSlice>() { - private int slice; + private int row; @Override protected MatrixSlice computeNext() { - if (slice >= numSlices()) { + if (row >= numRows()) { return endOfData(); } - int i = slice++; + int i = row++; return new MatrixSlice(viewRow(i), i); } }; } + @Override + public Iterator<MatrixSlice> iterateNonEmpty() { + return iterator(); + } + /** * Abstracted out for the iterator * @@ -813,4 +821,12 @@ public abstract class AbstractMatrix implements Matrix { return returnString + ("}"); } } + + @Override + public MatrixFlavor getFlavor() { + throw new UnsupportedOperationException("Flavor support not implemented for this matrix."); + } + + ////////////// Matrix flavor trait /////////////////// + } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/ConstantVector.java 
---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/ConstantVector.java b/math/src/main/java/org/apache/mahout/math/ConstantVector.java index 86ab82b..847bf85 100644 --- a/math/src/main/java/org/apache/mahout/math/ConstantVector.java +++ b/math/src/main/java/org/apache/mahout/math/ConstantVector.java @@ -132,6 +132,11 @@ public class ConstantVector extends AbstractVector { return new DenseVector(size()); } + @Override + public Vector like(int cardinality) { + return new DenseVector(cardinality); + } + /** * Set the value at the given index, without checking bounds * http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/DelegatingVector.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/DelegatingVector.java b/math/src/main/java/org/apache/mahout/math/DelegatingVector.java index a1fd291..0b2e36b 100644 --- a/math/src/main/java/org/apache/mahout/math/DelegatingVector.java +++ b/math/src/main/java/org/apache/mahout/math/DelegatingVector.java @@ -310,6 +310,11 @@ public class DelegatingVector implements Vector, LengthCachingVector { } @Override + public Vector like(int cardinality) { + return new DelegatingVector(delegate.like(cardinality)); + } + + @Override public void setQuick(int index, double value) { delegate.setQuick(index, value); } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/DenseMatrix.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/DenseMatrix.java b/math/src/main/java/org/apache/mahout/math/DenseMatrix.java index 7f52c00..5c1ee12 100644 --- a/math/src/main/java/org/apache/mahout/math/DenseMatrix.java +++ b/math/src/main/java/org/apache/mahout/math/DenseMatrix.java @@ -17,6 +17,9 @@ package org.apache.mahout.math; +import 
org.apache.mahout.math.flavor.MatrixFlavor; +import org.apache.mahout.math.flavor.TraversingStructureEnum; + import java.util.Arrays; /** Matrix of doubles implemented using a 2-d array */ @@ -175,5 +178,9 @@ public class DenseMatrix extends AbstractMatrix { } return new DenseVector(values[row], true); } - + + @Override + public MatrixFlavor getFlavor() { + return MatrixFlavor.DENSELIKE; + } } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/DenseSymmetricMatrix.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/DenseSymmetricMatrix.java b/math/src/main/java/org/apache/mahout/math/DenseSymmetricMatrix.java index e9cf3f1..7252b9b 100644 --- a/math/src/main/java/org/apache/mahout/math/DenseSymmetricMatrix.java +++ b/math/src/main/java/org/apache/mahout/math/DenseSymmetricMatrix.java @@ -17,6 +17,8 @@ package org.apache.mahout.math; +import org.apache.mahout.math.flavor.TraversingStructureEnum; + /** * Economy packaging for a dense symmetric in-core matrix. 
*/ http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/DenseVector.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/DenseVector.java b/math/src/main/java/org/apache/mahout/math/DenseVector.java index 5b3dea7..3633e58 100644 --- a/math/src/main/java/org/apache/mahout/math/DenseVector.java +++ b/math/src/main/java/org/apache/mahout/math/DenseVector.java @@ -136,6 +136,11 @@ public class DenseVector extends AbstractVector { } @Override + public Vector like(int cardinality) { + return new DenseVector(cardinality); + } + + @Override public void setQuick(int index, double value) { invalidateCachedLength(); values[index] = value; http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java b/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java index 3e20a4a..070fad2 100644 --- a/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java +++ b/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java @@ -17,6 +17,9 @@ package org.apache.mahout.math; +import org.apache.mahout.math.flavor.MatrixFlavor; +import org.apache.mahout.math.flavor.TraversingStructureEnum; + import java.util.Iterator; import java.util.NoSuchElementException; @@ -223,6 +226,11 @@ public class DiagonalMatrix extends AbstractMatrix implements MatrixTimesOps { } @Override + public Vector like(int cardinality) { + return new DenseVector(cardinality); + } + + @Override public void setQuick(int index, double value) { if (index == this.index) { diagonal.set(this.index, value); @@ -361,4 +369,10 @@ public class DiagonalMatrix extends AbstractMatrix implements MatrixTimesOps { } return m; } + + @Override + public MatrixFlavor getFlavor() { + return 
MatrixFlavor.DIAGONALLIKE; + } + } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/FileBasedSparseBinaryMatrix.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/FileBasedSparseBinaryMatrix.java b/math/src/main/java/org/apache/mahout/math/FileBasedSparseBinaryMatrix.java index ba09aa8..56600cd 100644 --- a/math/src/main/java/org/apache/mahout/math/FileBasedSparseBinaryMatrix.java +++ b/math/src/main/java/org/apache/mahout/math/FileBasedSparseBinaryMatrix.java @@ -437,6 +437,11 @@ public final class FileBasedSparseBinaryMatrix extends AbstractMatrix { return new RandomAccessSparseVector(size()); } + @Override + public Vector like(int cardinality) { + return new RandomAccessSparseVector(cardinality); + } + /** * Copy the vector for fast operations. * http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/FunctionalMatrixView.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/FunctionalMatrixView.java b/math/src/main/java/org/apache/mahout/math/FunctionalMatrixView.java index 2a13611..9028e23 100644 --- a/math/src/main/java/org/apache/mahout/math/FunctionalMatrixView.java +++ b/math/src/main/java/org/apache/mahout/math/FunctionalMatrixView.java @@ -17,6 +17,9 @@ package org.apache.mahout.math; +import org.apache.mahout.math.flavor.BackEnum; +import org.apache.mahout.math.flavor.MatrixFlavor; +import org.apache.mahout.math.flavor.TraversingStructureEnum; import org.apache.mahout.math.function.IntIntFunction; /** @@ -29,6 +32,7 @@ class FunctionalMatrixView extends AbstractMatrix { */ private IntIntFunction gf; private boolean denseLike; + private MatrixFlavor flavor; public FunctionalMatrixView(int rows, int columns, IntIntFunction gf) { this(rows, columns, gf, false); @@ -42,6 +46,7 @@ class FunctionalMatrixView 
extends AbstractMatrix { super(rows, columns); this.gf = gf; this.denseLike = denseLike; + flavor = new MatrixFlavor.FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.BLOCKIFIED, denseLike); } @Override @@ -87,4 +92,8 @@ class FunctionalMatrixView extends AbstractMatrix { return new MatrixVectorView(this, 0, column, 1, 0, denseLike); } + @Override + public MatrixFlavor getFlavor() { + return flavor; + } } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/Matrices.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/Matrices.java b/math/src/main/java/org/apache/mahout/math/Matrices.java index 4a0c50c..fc45a16 100644 --- a/math/src/main/java/org/apache/mahout/math/Matrices.java +++ b/math/src/main/java/org/apache/mahout/math/Matrices.java @@ -17,7 +17,9 @@ package org.apache.mahout.math; +import com.google.common.base.Preconditions; import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.flavor.TraversingStructureEnum; import org.apache.mahout.math.function.DoubleFunction; import org.apache.mahout.math.function.Functions; import org.apache.mahout.math.function.IntIntFunction; @@ -63,16 +65,14 @@ public final class Matrices { * @return transposed view of original matrix */ public static final Matrix transposedView(final Matrix m) { - IntIntFunction tf = new IntIntFunction() { - @Override - public double apply(int row, int col) { - return m.getQuick(col, row); - } - }; - // TODO: Matrix api does not support denseLike() interrogation. - // so our guess has to be rough here. 
- return functionalMatrixView(m.numCols(), m.numRows(), tf, m instanceof DenseMatrix); + Preconditions.checkArgument(!(m instanceof SparseColumnMatrix)); + + if (m instanceof TransposedMatrixView) { + return ((TransposedMatrixView) m).getDelegate(); + } else { + return new TransposedMatrixView(m); + } } /** http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/Matrix.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/Matrix.java b/math/src/main/java/org/apache/mahout/math/Matrix.java index afdbac5..47ba5cf 100644 --- a/math/src/main/java/org/apache/mahout/math/Matrix.java +++ b/math/src/main/java/org/apache/mahout/math/Matrix.java @@ -17,6 +17,7 @@ package org.apache.mahout.math; +import org.apache.mahout.math.flavor.MatrixFlavor; import org.apache.mahout.math.function.DoubleDoubleFunction; import org.apache.mahout.math.function.DoubleFunction; import org.apache.mahout.math.function.VectorFunction; @@ -403,4 +404,10 @@ public interface Matrix extends Cloneable, VectorIterable { * @return A vector that shares storage with the original matrix. */ Vector viewDiagonal(); + + /** + * Get the matrix's structural flavor (operation performance hints). This is an optional operation and may + * throw {@link java.lang.UnsupportedOperationException}.
+ */ + MatrixFlavor getFlavor(); } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java b/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java index 074d7a6..52ae722 100644 --- a/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java +++ b/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java @@ -211,6 +211,11 @@ public class MatrixVectorView extends AbstractVector { return matrix.like(size(), 1).viewColumn(0); } + @Override + public Vector like(int cardinality) { + return matrix.like(cardinality, 1).viewColumn(0); + } + /** * Set the value at the given index, without checking bounds * http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/MatrixView.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/MatrixView.java b/math/src/main/java/org/apache/mahout/math/MatrixView.java index e2f7f48..86760d5 100644 --- a/math/src/main/java/org/apache/mahout/math/MatrixView.java +++ b/math/src/main/java/org/apache/mahout/math/MatrixView.java @@ -17,6 +17,8 @@ package org.apache.mahout.math; +import org.apache.mahout.math.flavor.MatrixFlavor; + /** Implements subset view of a Matrix */ public class MatrixView extends AbstractMatrix { @@ -151,4 +153,8 @@ public class MatrixView extends AbstractMatrix { return new VectorView(matrix.viewRow(row + offset[ROW]), offset[COL], columnSize()); } + @Override + public MatrixFlavor getFlavor() { + return matrix.getFlavor(); + } } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/NamedVector.java ---------------------------------------------------------------------- diff --git 
a/math/src/main/java/org/apache/mahout/math/NamedVector.java b/math/src/main/java/org/apache/mahout/math/NamedVector.java index 0bf49c8..d4fa609 100644 --- a/math/src/main/java/org/apache/mahout/math/NamedVector.java +++ b/math/src/main/java/org/apache/mahout/math/NamedVector.java @@ -177,6 +177,11 @@ public class NamedVector implements Vector { } @Override + public Vector like(int cardinality) { + return new NamedVector(delegate.like(cardinality), name); + } + + @Override public Vector minus(Vector x) { return delegate.minus(x); } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/PermutedVectorView.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/PermutedVectorView.java b/math/src/main/java/org/apache/mahout/math/PermutedVectorView.java index f34f2b0..a76f78c 100644 --- a/math/src/main/java/org/apache/mahout/math/PermutedVectorView.java +++ b/math/src/main/java/org/apache/mahout/math/PermutedVectorView.java @@ -204,6 +204,11 @@ public class PermutedVectorView extends AbstractVector { return vector.like(); } + @Override + public Vector like(int cardinality) { + return vector.like(cardinality); + } + /** * Set the value at the given index, without checking bounds * http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java b/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java index dbe5d3a..3efac7e 100644 --- a/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java +++ b/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java @@ -142,6 +142,11 @@ public class RandomAccessSparseVector extends AbstractVector { } @Override + public Vector like(int cardinality) 
{ + return new RandomAccessSparseVector(cardinality, values.size()); + } + + @Override public int getNumNondefaultElements() { return values.size(); } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java b/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java index 331662c..f7d67a7 100644 --- a/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java +++ b/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java @@ -180,6 +180,11 @@ public class SequentialAccessSparseVector extends AbstractVector { } @Override + public Vector like(int cardinality) { + return new SequentialAccessSparseVector(cardinality); + } + + @Override public int getNumNondefaultElements() { return values.getNumMappings(); } @@ -214,6 +219,8 @@ public class SequentialAccessSparseVector extends AbstractVector { @Override public Iterator<Element> iterateNonZero() { + + // TODO: this is a bug, since nonDefaultIterator doesn't hold to non-zero contract. return new NonDefaultIterator(); } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/SparseColumnMatrix.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/SparseColumnMatrix.java b/math/src/main/java/org/apache/mahout/math/SparseColumnMatrix.java index f62d553..eeffc78 100644 --- a/math/src/main/java/org/apache/mahout/math/SparseColumnMatrix.java +++ b/math/src/main/java/org/apache/mahout/math/SparseColumnMatrix.java @@ -17,9 +17,13 @@ package org.apache.mahout.math; +import org.apache.mahout.math.flavor.TraversingStructureEnum; + /** * sparse matrix with general element values whose columns are accessible quickly. 
Implemented as a column array of * SparseVectors. + * + * @deprecated tons of inconsistencies. Use a transposed view of SparseRowMatrix for fast column-wise iteration. */ public class SparseColumnMatrix extends AbstractMatrix { @@ -31,11 +35,19 @@ public class SparseColumnMatrix extends AbstractMatrix { * @param columns a RandomAccessSparseVector[] array of columns * @param columnVectors */ - public SparseColumnMatrix(int rows, int columns, RandomAccessSparseVector[] columnVectors) { + public SparseColumnMatrix(int rows, int columns, Vector[] columnVectors) { + this(rows, columns, columnVectors, false); + } + + public SparseColumnMatrix(int rows, int columns, Vector[] columnVectors, boolean shallow) { super(rows, columns); - this.columnVectors = columnVectors.clone(); - for (int col = 0; col < columnSize(); col++) { - this.columnVectors[col] = this.columnVectors[col].clone(); + if (shallow) { + this.columnVectors = columnVectors; + } else { + this.columnVectors = columnVectors.clone(); + for (int col = 0; col < columnSize(); col++) { + this.columnVectors[col] = this.columnVectors[col].clone(); + } } } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/SparseMatrix.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/SparseMatrix.java b/math/src/main/java/org/apache/mahout/math/SparseMatrix.java index 88e15a0..bf4f1a0 100644 --- a/math/src/main/java/org/apache/mahout/math/SparseMatrix.java +++ b/math/src/main/java/org/apache/mahout/math/SparseMatrix.java @@ -18,6 +18,8 @@ package org.apache.mahout.math; import com.google.common.collect.AbstractIterator; +import org.apache.mahout.math.flavor.MatrixFlavor; +import org.apache.mahout.math.flavor.TraversingStructureEnum; import org.apache.mahout.math.function.DoubleDoubleFunction; import org.apache.mahout.math.function.Functions; import org.apache.mahout.math.function.IntObjectProcedure; @@ -40,11
+42,23 @@ public class SparseMatrix extends AbstractMatrix { * @param columns * @param rowVectors */ - public SparseMatrix(int rows, int columns, Map<Integer, RandomAccessSparseVector> rowVectors) { + public SparseMatrix(int rows, int columns, Map<Integer, Vector> rowVectors) { + this(rows, columns, rowVectors, false); + } + + public SparseMatrix(int rows, int columns, Map<Integer, Vector> rowVectors, boolean shallow) { + + // Why is this passing in a map? Iterating it is pretty inefficient compared to a simple list... super(rows, columns); this.rowVectors = new OpenIntObjectHashMap<Vector>(); - for (Map.Entry<Integer, RandomAccessSparseVector> entry : rowVectors.entrySet()) { - this.rowVectors.put(entry.getKey(), entry.getValue().clone()); + if (shallow) { + for (Map.Entry<Integer, Vector> entry : rowVectors.entrySet()) { + this.rowVectors.put(entry.getKey(), entry.getValue()); + } + } else { + for (Map.Entry<Integer, Vector> entry : rowVectors.entrySet()) { + this.rowVectors.put(entry.getKey(), entry.getValue().clone()); + } } } @@ -66,7 +80,11 @@ public class SparseMatrix extends AbstractMatrix { } @Override - public Iterator<MatrixSlice> iterator() { + public int numSlices() { + return rowVectors.size(); + } + + public Iterator<MatrixSlice> iterateNonEmpty() { final IntArrayList keys = new IntArrayList(rowVectors.size()); rowVectors.keys(keys); return new AbstractIterator<MatrixSlice>() { @@ -221,4 +239,8 @@ public class SparseMatrix extends AbstractMatrix { return rowVectors.keys(); } + @Override + public MatrixFlavor getFlavor() { + return MatrixFlavor.SPARSEROWLIKE; + } } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java b/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java index 3021f3b..6e06769 100644 ---
a/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java +++ b/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java @@ -17,6 +17,8 @@ package org.apache.mahout.math; +import org.apache.mahout.math.flavor.MatrixFlavor; +import org.apache.mahout.math.flavor.TraversingStructureEnum; import org.apache.mahout.math.function.Functions; /** @@ -226,4 +228,9 @@ public class SparseRowMatrix extends AbstractMatrix { } } } + + @Override + public MatrixFlavor getFlavor() { + return MatrixFlavor.SPARSELIKE; + } } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/TransposedMatrixView.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/TransposedMatrixView.java b/math/src/main/java/org/apache/mahout/math/TransposedMatrixView.java new file mode 100644 index 0000000..c67cb47 --- /dev/null +++ b/math/src/main/java/org/apache/mahout/math/TransposedMatrixView.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math; + +import org.apache.mahout.math.flavor.BackEnum; +import org.apache.mahout.math.flavor.MatrixFlavor; +import org.apache.mahout.math.flavor.TraversingStructureEnum; +import org.apache.mahout.math.function.DoubleDoubleFunction; +import org.apache.mahout.math.function.DoubleFunction; + +/** + * Transposed view of a delegate {@link Matrix}: element (row, column) of the view is element (column, row) of the delegate. + */ +class TransposedMatrixView extends AbstractMatrix { + + private Matrix m; + + public TransposedMatrixView(Matrix m) { + super(m.numCols(), m.numRows()); + this.m = m; + } + + @Override + public Matrix assignColumn(int column, Vector other) { + m.assignRow(column,other); + return this; + } + + @Override + public Matrix assignRow(int row, Vector other) { + m.assignColumn(row,other); + return this; + } + + @Override + public double getQuick(int row, int column) { + return m.getQuick(column,row); + } + + @Override + public Matrix like() { + return m.like(rows, columns); + } + + @Override + public Matrix like(int rows, int columns) { + return m.like(rows,columns); + } + + @Override + public void setQuick(int row, int column, double value) { + m.setQuick(column, row, value); + } + + @Override + public Vector viewRow(int row) { + return m.viewColumn(row); + } + + @Override + public Vector viewColumn(int column) { + return m.viewRow(column); + } + + @Override + public Matrix assign(double value) { + return m.assign(value); + } + + @Override + public Matrix assign(Matrix other, DoubleDoubleFunction function) { + if (other instanceof TransposedMatrixView) { + m.assign(((TransposedMatrixView) other).m, function); + } else { + m.assign(new TransposedMatrixView(other), function); + } + return this; + } + + @Override + public Matrix assign(Matrix other) { + if (other instanceof TransposedMatrixView) { + return m.assign(((TransposedMatrixView) other).m); + } else { + return m.assign(new TransposedMatrixView(other)); + } + } + + @Override + public Matrix
assign(DoubleFunction function) { + return m.assign(function); + } + + @Override + public MatrixFlavor getFlavor() { + return flavor; + } + + private MatrixFlavor flavor = new MatrixFlavor() { + @Override + public BackEnum getBacking() { + return m.getFlavor().getBacking(); + } + + @Override + public TraversingStructureEnum getStructure() { + TraversingStructureEnum flavor = m.getFlavor().getStructure(); + switch (flavor) { + case COLWISE: + return TraversingStructureEnum.ROWWISE; + case SPARSECOLWISE: + return TraversingStructureEnum.SPARSEROWWISE; + case ROWWISE: + return TraversingStructureEnum.COLWISE; + case SPARSEROWWISE: + return TraversingStructureEnum.SPARSECOLWISE; + default: + return flavor; + } + } + + @Override + public boolean isDense() { + return m.getFlavor().isDense(); + } + }; + + Matrix getDelegate() { + return m; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/UpperTriangular.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/UpperTriangular.java b/math/src/main/java/org/apache/mahout/math/UpperTriangular.java index a0cb3cd..29fa6a0 100644 --- a/math/src/main/java/org/apache/mahout/math/UpperTriangular.java +++ b/math/src/main/java/org/apache/mahout/math/UpperTriangular.java @@ -17,6 +17,10 @@ package org.apache.mahout.math; +import org.apache.mahout.math.flavor.BackEnum; +import org.apache.mahout.math.flavor.MatrixFlavor; +import org.apache.mahout.math.flavor.TraversingStructureEnum; + /** * * Quick and dirty implementation of some {@link org.apache.mahout.math.Matrix} methods @@ -148,4 +152,9 @@ public class UpperTriangular extends AbstractMatrix { return values; } + @Override + public MatrixFlavor getFlavor() { + // We kind of consider ourselves a vector-backed but dense matrix for mmul, etc. purposes. 
+ return new MatrixFlavor.FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.VECTORBACKED, true); + } } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/Vector.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/Vector.java b/math/src/main/java/org/apache/mahout/math/Vector.java index 0d1a003..4480b0a 100644 --- a/math/src/main/java/org/apache/mahout/math/Vector.java +++ b/math/src/main/java/org/apache/mahout/math/Vector.java @@ -190,6 +190,14 @@ public interface Vector extends Cloneable { Vector like(); /** + * Return a new empty vector of the same underlying class as the receiver with the given cardinality + * + * @param cardinality the cardinality of the new vector + * @return a new empty vector of the given cardinality + */ + Vector like(int cardinality); + + /** * Return a new vector containing the element by element difference of the recipient and the argument * * @param x a Vector http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/VectorIterable.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/VectorIterable.java b/math/src/main/java/org/apache/mahout/math/VectorIterable.java index 451c589..8414fdb 100644 --- a/math/src/main/java/org/apache/mahout/math/VectorIterable.java +++ b/math/src/main/java/org/apache/mahout/math/VectorIterable.java @@ -21,8 +21,12 @@ import java.util.Iterator; public interface VectorIterable extends Iterable<MatrixSlice> { + /* Iterate all rows in order */ Iterator<MatrixSlice> iterateAll(); + /* Iterate all non empty rows in arbitrary order */ + Iterator<MatrixSlice> iterateNonEmpty(); + int numSlices(); int numRows(); http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/VectorView.java ---------------------------------------------------------------------- diff --git 
a/math/src/main/java/org/apache/mahout/math/VectorView.java b/math/src/main/java/org/apache/mahout/math/VectorView.java index b503712..d61a038 100644 --- a/math/src/main/java/org/apache/mahout/math/VectorView.java +++ b/math/src/main/java/org/apache/mahout/math/VectorView.java @@ -69,6 +69,11 @@ public class VectorView extends AbstractVector { } @Override + public Vector like(int cardinality) { + return vector.like(cardinality); + } + + @Override public double getQuick(int index) { return vector.getQuick(offset + index); } @@ -122,7 +127,7 @@ public class VectorView extends AbstractVector { while (it.hasNext()) { Element el = it.next(); if (isInView(el.index()) && el.get() != 0) { - Element decorated = vector.getElement(el.index()); + Element decorated = el; /* vector.getElement(el.index()); */ return new DecoratorElement(decorated); } } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/flavor/BackEnum.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/flavor/BackEnum.java b/math/src/main/java/org/apache/mahout/math/flavor/BackEnum.java new file mode 100644 index 0000000..1782f04 --- /dev/null +++ b/math/src/main/java/org/apache/mahout/math/flavor/BackEnum.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.flavor;
+
+/**
+ * Matrix backends: the value a matrix reports through {@link MatrixFlavor#getBacking()}.
+ */
+public enum BackEnum {
+  /**
+   * Backing used by all predefined flavors in {@link MatrixFlavor}.
+   * NOTE(review): presumably plain JVM (heap) memory -- confirm.
+   */
+  JVMMEM,
+  /** NOTE(review): presumably data handled through netlib BLAS bindings -- confirm. */
+  NETLIB_BLAS
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/flavor/MatrixFlavor.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/flavor/MatrixFlavor.java b/math/src/main/java/org/apache/mahout/math/flavor/MatrixFlavor.java new file mode 100644 index 0000000..2b5c444 --- /dev/null +++ b/math/src/main/java/org/apache/mahout/math/flavor/MatrixFlavor.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.mahout.math.flavor;
+
+/** A set of matrix structure properties that I denote as "flavor" (by analogy to quarks) */
+public interface MatrixFlavor {
+
+  /**
+   * Which system backs the matrix -- such as java memory, lapack/atlas, Magma etc.
+   */
+  BackEnum getBacking();
+
+  /**
+   * Structure flavors
+   */
+  TraversingStructureEnum getStructure();
+
+  /**
+   * Density flag of the flavor: {@code true} for the dense-like default below, {@code false} for
+   * the sparse-like and diagonal-like ones.
+   */
+  boolean isDense();
+
+  /**
+   * This is the default flavor for {@link org.apache.mahout.math.DenseMatrix}-like structures.
+   */
+  static final MatrixFlavor DENSELIKE = new FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.ROWWISE, true);
+
+  /**
+   * This is the default flavor for {@link org.apache.mahout.math.SparseRowMatrix}-like structures.
+   */
+  static final MatrixFlavor SPARSELIKE = new FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.ROWWISE, false);
+
+  /**
+   * This is the default flavor for {@link org.apache.mahout.math.SparseMatrix}-like structures, i.e. sparse matrix
+   * blocks, where few, perhaps most, rows may be missing entirely.
+   */
+  static final MatrixFlavor SPARSEROWLIKE = new FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.SPARSEROWWISE, false);
+
+  /**
+   * This is the default flavor for {@link org.apache.mahout.math.DiagonalMatrix} and the like.
+   */
+  static final MatrixFlavor DIAGONALLIKE = new FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.VECTORBACKED, false);
+
+  /**
+   * Plain value holder for the three flavor properties. Instances are shared through the constants
+   * above, so the fields are final: a flavor cannot change after construction.
+   */
+  static final class FlavorImpl implements MatrixFlavor {
+    private final BackEnum pBacking;
+    private final TraversingStructureEnum pStructure;
+    private final boolean pDense;
+
+    /**
+     * @param backing   value returned by {@link #getBacking()}
+     * @param structure value returned by {@link #getStructure()}
+     * @param dense     value returned by {@link #isDense()}
+     */
+    public FlavorImpl(BackEnum backing, TraversingStructureEnum structure, boolean dense) {
+      pBacking = backing;
+      pStructure = structure;
+      pDense = dense;
+    }
+
+    @Override
+    public BackEnum getBacking() {
+      return pBacking;
+    }
+
+    @Override
+    public TraversingStructureEnum getStructure() {
+      return pStructure;
+    }
+
+    @Override
+    public boolean isDense() {
+      return pDense;
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/flavor/TraversingStructureEnum.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/flavor/TraversingStructureEnum.java b/math/src/main/java/org/apache/mahout/math/flavor/TraversingStructureEnum.java new file mode 100644 index 0000000..13c2cf4 --- /dev/null +++ b/math/src/main/java/org/apache/mahout/math/flavor/TraversingStructureEnum.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.flavor;
+
+/**
+ * Structure hint: how a matrix's elements can be traversed most directly. Reported through
+ * {@link MatrixFlavor#getStructure()}.
+ */
+public enum TraversingStructureEnum {
+
+  /** NOTE(review): presumably "no hint available" -- inferred from the name only, confirm. */
+  UNKNOWN,
+
+  /**
+   * Backing vectors are directly available as row views.
+   */
+  ROWWISE,
+
+  /**
+   * Column vectors are directly available as column views.
+   */
+  COLWISE,
+
+  /**
+   * Only some row-wise vectors are really present (can use iterateNonEmpty). Corresponds to
+   * {@link org.apache.mahout.math.SparseMatrix}.
+   */
+  SPARSEROWWISE,
+
+  /**
+   * Column-wise counterpart of {@link #SPARSEROWWISE}; a transposed view of a SPARSEROWWISE matrix
+   * reports this structure.
+   */
+  SPARSECOLWISE,
+
+  /** NOTE(review): undocumented; presumably hash-based sparse storage -- confirm. */
+  SPARSEHASH,
+
+  /**
+   * Structure reported by {@link MatrixFlavor#DIAGONALLIKE} and by UpperTriangular, which describes
+   * itself as vector-backed.
+   */
+  VECTORBACKED,
+
+  /** NOTE(review): undocumented; presumably block-partitioned storage -- confirm. */
+  BLOCKIFIED
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/test/java/org/apache/mahout/math/MatricesTest.java ---------------------------------------------------------------------- diff --git a/math/src/test/java/org/apache/mahout/math/MatricesTest.java b/math/src/test/java/org/apache/mahout/math/MatricesTest.java index 1b6169e..9405429 100644 --- a/math/src/test/java/org/apache/mahout/math/MatricesTest.java +++ b/math/src/test/java/org/apache/mahout/math/MatricesTest.java @@ -65,8 +65,8 @@ public class MatricesTest extends MahoutTestCase { m.set(1, 1, 33.0); Matrix mt = Matrices.transposedView(m); - assertTrue(!mt.viewColumn(0).isDense()); - assertTrue(!mt.viewRow(0).isDense()); + assertTrue(mt.viewColumn(0).isDense() == m.viewRow(0).isDense()); + assertTrue(mt.viewRow(0).isDense() == m.viewColumn(0).isDense()); m = new DenseMatrix(10,10); m.set(1, 1, 33.0); http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/mr/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java b/mr/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java index 1a6ff16..de5e216 100644 --- a/mr/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java +++ 
b/mr/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java @@ -133,6 +133,11 @@ public class DistributedRowMatrix implements VectorIterable, Configurable { } @Override + public Iterator<MatrixSlice> iterateNonEmpty() { + return iterator(); + } + + @Override public Iterator<MatrixSlice> iterateAll() { try { Path pathPattern = rowPath; http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java index 7033efe..af79cb4 100644 --- a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java +++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java @@ -586,6 +586,11 @@ public class GivensThinSolver { } @Override + public Vector like(int cardinality) { + throw new UnsupportedOperationException(); + } + + @Override public void setQuick(int index, double value) { viewed.setQuick(rowNum, index, value); http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala ---------------------------------------------------------------------- diff --git a/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala b/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala index 5ffc18c..4d0615a 100644 --- a/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala +++ b/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala @@ -12,13 +12,14 @@ class MahoutSparkILoop extends SparkILoop { private val postInitScript = "import org.apache.mahout.math._" :: - "import 
scalabindings._" :: - "import RLikeOps._" :: - "import drm._" :: - "import RLikeDrmOps._" :: - "import org.apache.mahout.sparkbindings._" :: - "import collection.JavaConversions._" :: - Nil + "import scalabindings._" :: + "import RLikeOps._" :: + "import drm._" :: + "import RLikeDrmOps._" :: + "import decompositions._" :: + "import org.apache.mahout.sparkbindings._" :: + "import collection.JavaConversions._" :: + Nil override protected def postInitialization() { super.postInitialization() http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/pom.xml ---------------------------------------------------------------------- diff --git a/spark/pom.xml b/spark/pom.xml index 33e0d1b..7155115 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -119,6 +119,22 @@ </executions> </plugin> + <!-- create test jar so other modules can reuse the math test utility classes. + DO NOT REMOVE! Testing framework is useful in subordinate/contrib projects! + Please contact @dlyubimov. + --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <executions> + <execution> + <goals> + <goal>test-jar</goal> + </goals> + <phase>package</phase> + </execution> + </executions> + </plugin> </plugins> </build> http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/common/DrmMetadata.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/common/DrmMetadata.scala b/spark/src/main/scala/org/apache/mahout/common/DrmMetadata.scala index 5bbccb1..0aba319 100644 --- a/spark/src/main/scala/org/apache/mahout/common/DrmMetadata.scala +++ b/spark/src/main/scala/org/apache/mahout/common/DrmMetadata.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.mahout.common import scala.reflect.ClassTag http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/common/HDFSUtil.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/common/HDFSUtil.scala b/spark/src/main/scala/org/apache/mahout/common/HDFSUtil.scala index f5f87d7..c949f92 100644 --- a/spark/src/main/scala/org/apache/mahout/common/HDFSUtil.scala +++ b/spark/src/main/scala/org/apache/mahout/common/HDFSUtil.scala @@ -17,10 +17,12 @@ package org.apache.mahout.common +import org.apache.spark.SparkContext + /** High level Hadoop version-specific hdfs manipulations we need in context of our operations. */ trait HDFSUtil { /** Read DRM header information off (H)DFS. 
*/ - def readDrmHeader(path:String):DrmMetadata + def readDrmHeader(path:String)(implicit sc:SparkContext):DrmMetadata } http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/common/Hadoop1HDFSUtil.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/common/Hadoop1HDFSUtil.scala b/spark/src/main/scala/org/apache/mahout/common/Hadoop1HDFSUtil.scala index 047104a..399508d 100644 --- a/spark/src/main/scala/org/apache/mahout/common/Hadoop1HDFSUtil.scala +++ b/spark/src/main/scala/org/apache/mahout/common/Hadoop1HDFSUtil.scala @@ -17,10 +17,10 @@ package org.apache.mahout.common - import org.apache.hadoop.io.{Writable, SequenceFile} import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.conf.Configuration +import org.apache.spark.SparkContext import collection._ import JavaConversions._ @@ -30,14 +30,16 @@ import JavaConversions._ */ object Hadoop1HDFSUtil extends HDFSUtil { - /** - * Read the header of a sequence file and determine the Key and Value type - * @param path - * @return - */ - def readDrmHeader(path: String): DrmMetadata = { + + /** Read DRM header information off (H)DFS. */ + override def readDrmHeader(path: String)(implicit sc: SparkContext): DrmMetadata = { + val dfsPath = new Path(path) - val fs = dfsPath.getFileSystem(new Configuration()) + + val fs = dfsPath.getFileSystem(sc.hadoopConfiguration) + + // Apparently getFileSystem() doesn't set conf?? + fs.setConf(sc.hadoopConfiguration) val partFilePath:Path = fs.listStatus(dfsPath)
