[FLINK-1718] [ml] Adds sparse matrix and sparse vector types
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/9219af7b Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/9219af7b Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/9219af7b Branch: refs/heads/master Commit: 9219af7b63321ea78af67579bdc68eecd895acaa Parents: c635802 Author: Till Rohrmann <trohrm...@apache.org> Authored: Wed Mar 25 15:27:58 2015 +0100 Committer: Till Rohrmann <trohrm...@apache.org> Committed: Wed Apr 1 10:56:47 2015 +0200 ---------------------------------------------------------------------- flink-staging/flink-ml/pom.xml | 6 +- .../org/apache/flink/ml/math/DenseMatrix.scala | 125 +++++++++- .../org/apache/flink/ml/math/DenseVector.scala | 42 ++-- .../scala/org/apache/flink/ml/math/Matrix.scala | 58 +++-- .../org/apache/flink/ml/math/SparseMatrix.scala | 235 +++++++++++++++++++ .../org/apache/flink/ml/math/SparseVector.scala | 156 ++++++++++++ .../scala/org/apache/flink/ml/math/Vector.scala | 52 ++-- .../org/apache/flink/ml/math/package.scala | 6 +- .../apache/flink/ml/math/DenseMatrixSuite.scala | 69 ------ .../apache/flink/ml/math/DenseMatrixTest.scala | 89 +++++++ .../apache/flink/ml/math/DenseVectorSuite.scala | 50 ---- .../apache/flink/ml/math/DenseVectorTest.scala | 52 ++++ .../apache/flink/ml/math/SparseMatrixTest.scala | 94 ++++++++ .../apache/flink/ml/math/SparseVectorTest.scala | 79 +++++++ 14 files changed, 926 insertions(+), 187 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/pom.xml ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/pom.xml b/flink-staging/flink-ml/pom.xml index 4f251e5..899d266 100644 --- a/flink-staging/flink-ml/pom.xml +++ b/flink-staging/flink-ml/pom.xml @@ -41,9 +41,9 @@ </dependency> <dependency> - <groupId>com.github.fommil.netlib</groupId> - <artifactId>core</artifactId> - <version>1.1.2</version> + <groupId>org.scalanlp</groupId> + <artifactId>breeze_2.10</artifactId> + <version>0.11.1</version> </dependency> <dependency> http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala index f3bd630..72eae05 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala @@ -24,13 +24,15 @@ package org.apache.flink.ml.math * * @param numRows Number of rows * @param numCols Number of columns - * @param values Array of matrix elements in column major order + * @param data Array of matrix elements in column major order */ case class DenseMatrix(val numRows: Int, val numCols: Int, - val values: Array[Double]) extends Matrix { + val data: Array[Double]) extends Matrix { - require(numRows * numCols == values.length, s"The number of values ${values.length} does " + + import DenseMatrix._ + + require(numRows * numCols == data.length, s"The number of values ${data.length} does " + s"not correspond to its dimensions ($numRows, $numCols).") /** @@ -41,32 +43,129 @@ case class DenseMatrix(val numRows: Int, * @return matrix entry at (row, col) */ override def apply(row: Int, col: Int): Double = { - require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).") - require(0 <= col && col < numCols, s"Col $col is out of bounds [0, $numCols).") + val index = locate(row, col) - val index = col * numRows + row - - values(index) + data(index) } override def toString: String = { - s"DenseMatrix($numRows, $numCols, ${values.mkString(", ")})" + val result = StringBuilder.newBuilder + result.append(s"DenseMatrix($numRows, $numCols)\n") + + val linewidth = LINE_WIDTH + + val columnsFieldWidths = for(row <- 0 until math.min(numRows, MAX_ROWS)) yield { + var column = 0 + var maxFieldWidth = 0 + + while(column * maxFieldWidth < linewidth && column < numCols) { + val fieldWidth = printEntry(row, column).length + 2 + + if(fieldWidth > maxFieldWidth) { + maxFieldWidth = fieldWidth + } + + if(column * maxFieldWidth < linewidth) { + column += 1 + } + } + + (column, maxFieldWidth) + } + + val (columns, fieldWidths) = columnsFieldWidths.unzip + + val maxColumns = columns.min + val fieldWidth = fieldWidths.max + + for(row <- 0 until math.min(numRows, MAX_ROWS)) { + for(col <- 0 until maxColumns) { + val str = printEntry(row, col) + + result.append(" " * (fieldWidth - str.length) + str) + } + + if(maxColumns < numCols) { + result.append("...") + } + + result.append("\n") + } + + if(numRows > MAX_ROWS) { + result.append("...\n") + } + + result.toString() } override def equals(obj: Any): Boolean = { obj match { case dense: DenseMatrix => - numRows == dense.numRows && numCols == dense.numCols && values.zip(dense.values).forall { - case (a, b) => a == b - } - case _ => false + numRows == dense.numRows && numCols == dense.numCols && data.sameElements(dense.data) + case _ => super.equals(obj) + } + } + + /** Element wise update function + * + * @param row row index + * @param col column index + * @param value value to set at (row, col) + */ + override def update(row: Int, col: Int, value: Double): Unit = { + val index = locate(row, col) + + data(index) = value + } + + def toSparseMatrix: SparseMatrix = { + val entries = for(row <- 0 until numRows; col <- 0 until numCols) yield { + (row, col, apply(row, col)) } + + SparseMatrix.fromCOO(numRows, numCols, entries.filter(_._3 != 0)) } + /** Calculates the linear index of the respective matrix entry + * + * @param row + * @param col + * @return + */ + private def locate(row: Int, col: Int): Int = { + require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).") + require(0 <= col && col < numCols, s"Col $col is out of bounds [0, $numCols).") + + row + col * numRows + } + + /** Converts the entry at (row, col) to string + * + * @param row + * @param col + * @return + */ + private def printEntry(row: Int, col: Int): String = { + val index = locate(row, col) + + data(index).toString + } + + /** Copies the matrix instance + * + * @return Copy of itself + */ + override def copy: DenseMatrix = { + new DenseMatrix(numRows, numCols, data.clone) + } } object DenseMatrix { + val LINE_WIDTH = 100 + val MAX_ROWS = 50 + def apply(numRows: Int, numCols: Int, values: Array[Int]): DenseMatrix = { new DenseMatrix(numRows, numCols, values.map(_.toDouble)) } http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala index d407a70..6d41d47 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala @@ -22,16 +22,16 @@ package org.apache.flink.ml.math * Dense vector implementation of [[Vector]]. The data is represented in a continuous array of * doubles. * - * @param values Array of doubles to store the vector elements + * @param data Array of doubles to store the vector elements */ -case class DenseVector(val values: Array[Double]) extends Vector { +case class DenseVector(val data: Array[Double]) extends Vector { /** * Number of elements in a vector * @return */ override def size: Int = { - values.length + data.length } /** @@ -41,23 +41,19 @@ case class DenseVector(val values: Array[Double]) extends Vector { * @return element at the given index */ override def apply(index: Int): Double = { - require(0 <= index && index < values.length, s"Index $index is out of bounds " + - s"[0, ${values.length})") - values(index) + require(0 <= index && index < data.length, s"Index $index is out of bounds " + + s"[0, ${data.length})") + data(index) } override def toString: String = { - s"DenseVector(${values.mkString(", ")})" + s"DenseVector(${data.mkString(", ")})" } override def equals(obj: Any): Boolean = { obj match { - case dense: DenseVector => - values.length == dense.values.length && values.zip(dense.values).forall{ - case (a,b) => a == b - } - - case _ => false + case dense: DenseVector => data.length == dense.data.length && data.sameElements(dense.data) + case _ => super.equals(obj) } } @@ -67,7 +63,25 @@ case class DenseVector(val values: Array[Double]) extends Vector { * @return Copy of the vector instance */ override def copy: Vector = { - DenseVector(values.clone()) + DenseVector(data.clone()) + } + + /** Updates the element at the given index with the provided value + * + * @param index + * @param value + */ + override def update(index: Int, value: Double): Unit = { + require(0 <= index && index < data.length, s"Index $index is out of bounds " + + s"[0, ${data.length})") + + data(index) = value + } + + def toSparseVector: SparseVector = { + val nonZero = (0 until size).zip(data).filter(_._2 != 0) + + SparseVector.fromCOO(size, nonZero) } } http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Matrix.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Matrix.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Matrix.scala index 62ea85a..11b4e55 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Matrix.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Matrix.scala @@ -18,28 +18,52 @@ package org.apache.flink.ml.math -/** - * Base trait for a matrix representation - */ +/** Base trait for a matrix representation + * + */ trait Matrix { - /** - * Number of rows - * @return - */ + /** Number of rows + * + * @return + */ def numRows: Int - /** - * Number of columns - * @return - */ + /** Number of columns + * + * @return + */ def numCols: Int - /** - * Element wise access function - * @param row row index - * @param col column index - * @return matrix entry at (row, col) - */ + /** Element wise access function + * + * @param row row index + * @param col column index + * @return matrix entry at (row, col) + */ def apply(row: Int, col: Int): Double + + /** Element wise update function + * + * @param row row index + * @param col column index + * @param value value to set at (row, col) + */ + def update(row: Int, col: Int, value: Double): Unit + + /** Copies the matrix instance + * + * @return Copy of itself + */ + def copy: Matrix + + override def equals(obj: Any): Boolean = { + obj match { + case matrix: Matrix if numRows == matrix.numRows && numCols == matrix.numCols => + val coordinates = for(row <- 0 until numRows; col <- 0 until numCols) yield (row, col) + coordinates forall { case(row, col) => this.apply(row, col) == matrix(row, col)} + case _ => false + } + } + } http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala new file mode 100644 index 0000000..a46202c --- /dev/null +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.math + +import scala.util.Sorting + +/** Sparse matrix using the compressed sparse column (CSC) representation. + * + * More details concerning the compressed sparse column (CSC) representation can be found + * [http://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_column_.28CSC_or_CCS.29]. + * + * @param numRows Number of rows + * @param numCols Number of columns + * @param rowIndices Array containing the row indices of non-zero entries + * @param colPtrs Array containing the starting offsets in data for each column + * @param data Array containing the non-zero entries in column-major order + */ +class SparseMatrix( + val numRows: Int, + val numCols: Int, + val rowIndices: Array[Int], + val colPtrs: Array[Int], + val data: Array[Double]) + extends Matrix { + + /** Element wise access function + * + * @param row row index + * @param col column index + * @return matrix entry at (row, col) + */ + override def apply(row: Int, col: Int): Double = { + + val index = locate(row, col) + + if(index < 0){ + 0 + } else { + data(index) + } + } + + def toDenseMatrix: DenseMatrix = { + val result = DenseMatrix.zeros(numRows, numCols) + + for(row <- 0 until numRows; col <- 0 until numCols) { + result(row, col) = apply(row, col) + } + + result + } + + /** Element wise update function + * + * @param row row index + * @param col column index + * @param value value to set at (row, col) + */ + override def update(row: Int, col: Int, value: Double): Unit = { + val index = locate(row, col) + + if(index < 0) { + throw new IllegalArgumentException("Cannot update zero value of sparse matrix at index " + + s"($row, $col)") + } else { + data(index) = value + } + } + + override def toString: String = { + val result = StringBuilder.newBuilder + + result.append(s"SparseMatrix($numRows, $numCols)\n") + + var columnIndex = 0 + + val fieldWidth = math.max(numRows, numCols).toString.length + val valueFieldWidth = data.map(_.toString.length).max + 2 + + for(index <- 0 until colPtrs.last) { + while(colPtrs(columnIndex + 1) <= index){ + columnIndex += 1 + } + + val rowStr = rowIndices(index).toString + val columnStr = columnIndex.toString + val valueStr = data(index).toString + + result.append("(" + " " * (fieldWidth - rowStr.length) + rowStr + "," + + " " * (fieldWidth - columnStr.length) + columnStr + ")") + result.append(" " * (valueFieldWidth - valueStr.length) + valueStr) + result.append("\n") + } + + result.toString + } + + private def locate(row: Int, col: Int): Int = { + require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).") + require(0 <= col && col < numCols, s"Col $col is out of bounds [0, $numCols).") + + val startIndex = colPtrs(col) + val endIndex = colPtrs(col+1) + + java.util.Arrays.binarySearch(rowIndices, startIndex, endIndex, row) + } + + /** Copies the matrix instance + * + * @return Copy of itself + */ + override def copy: SparseMatrix = { + new SparseMatrix(numRows, numCols, rowIndices.clone, colPtrs.clone(), data.clone) + } +} + +object SparseMatrix{ + + /** Constructs a sparse matrix from a coordinate list (COO) representation where each entry + * is stored as a tuple of (rowIndex, columnIndex, value). + * @param numRows + * @param numCols + * @param entries + * @return + */ + def fromCOO(numRows: Int, numCols: Int, entries: (Int, Int, Double)*): SparseMatrix = { + fromCOO(numRows, numCols, entries) + } + + /** Constructs a sparse matrix from a coordinate list (COO) representation where each entry + * is stored as a tuple of (rowIndex, columnIndex, value). + * + * @param numRows + * @param numCols + * @param entries + * @return + */ + def fromCOO(numRows: Int, numCols: Int, entries: Iterable[(Int, Int, Double)]): SparseMatrix = { + val entryArray = entries.toArray + + entryArray.foreach{ case (row, col, _) => + require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).") + require(0 <= col && col < numCols, s"Columm $col is out of bounds [0, $numCols).") + } + + val COOOrdering = new Ordering[(Int, Int, Double)] { + override def compare(x: (Int, Int, Double), y: (Int, Int, Double)): Int = { + if(x._2 < y._2) { + -1 + } else if(x._2 > y._2) { + 1 + } else { + x._1 - y._1 + } + } + } + + Sorting.quickSort(entryArray)(COOOrdering) + + val nnz = entryArray.length + + val data = new Array[Double](nnz) + val rowIndices = new Array[Int](nnz) + val colPtrs = new Array[Int](numCols + 1) + + var (lastRow, lastCol, lastValue) = entryArray(0) + + rowIndices(0) = lastRow + data(0) = lastValue + + var i = 1 + var lastDataIndex = 0 + + while(i < nnz) { + val (curRow, curCol, curValue) = entryArray(i) + + if(lastRow == curRow && lastCol == curCol) { + // add values with identical coordinates + data(lastDataIndex) += curValue + } else { + lastDataIndex += 1 + data(lastDataIndex) = curValue + rowIndices(lastDataIndex) = curRow + lastRow = curRow + } + + while(lastCol < curCol) { + lastCol += 1 + colPtrs(lastCol) = lastDataIndex + } + + i += 1 + } + + lastDataIndex += 1 + while(lastCol < numCols) { + colPtrs(lastCol + 1) = lastDataIndex + lastCol += 1 + } + + val prunedRowIndices = if(lastDataIndex < nnz) { + val prunedArray = new Array[Int](lastDataIndex) + rowIndices.copyToArray(prunedArray) + prunedArray + } else { + rowIndices + } + + val prunedData = if(lastDataIndex < nnz) { + val prunedArray = new Array[Double](lastDataIndex) + data.copyToArray(prunedArray) + prunedArray + } else { + data + } + + new SparseMatrix(numRows, numCols, prunedRowIndices, colPtrs, prunedData) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala new file mode 100644 index 0000000..93da362 --- /dev/null +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.math + +import scala.util.Sorting + +/** Sparse vector implementation storing the data in two arrays. One index contains the sorted + * indices of the non-zero vector entries and the other the corresponding vector entries + */ +class SparseVector( + val size: Int, + val indices: Array[Int], + val data: Array[Double]) + extends Vector { + /** Updates the element at the given index with the provided value + * + * @param index + * @param value + */ + override def update(index: Int, value: Double): Unit = { + val resolvedIndex = locate(index) + + if (resolvedIndex < 0) { + throw new IllegalArgumentException("Cannot update zero value of sparse vector at index " + + index) + } else { + data(resolvedIndex) = value + } + } + + /** Copies the vector instance + * + * @return Copy of the vector instance + */ + override def copy: Vector = { + new SparseVector(size, indices.clone, data.clone) + } + + /** Element wise access function + * + * * @param index index of the accessed element + * @return element with index + */ + override def apply(index: Int): Double = { + val resolvedIndex = locate(index) + + if(resolvedIndex < 0) { + 0 + } else { + data(resolvedIndex) + } + } + + def toDenseVector: DenseVector = { + val denseVector = DenseVector.zeros(size) + + for(index <- 0 until size) { + denseVector(index) = this(index) + } + + denseVector + } + + private def locate(index: Int): Int = { + require(0 <= index && index < size, s"Index $index is out of bounds [0, $size).") + + java.util.Arrays.binarySearch(indices, 0, indices.length, index) + } +} + +object SparseVector { + + /** Constructs a sparse vector from a coordinate list (COO) representation where each entry + * is stored as a tuple of (index, value). + * + * @param size + * @param entries + * @return + */ + def fromCOO(size: Int, entries: (Int, Double)*): SparseVector = { + fromCOO(size, entries) + } + + /** Constructs a sparse vector from a coordinate list (COO) representation where each entry + * is stored as a tuple of (index, value). + * + * @param size + * @param entries + * @return + */ + def fromCOO(size: Int, entries: Iterable[(Int, Double)]): SparseVector = { + val entryArray = entries.toArray + + val COOOrdering = new Ordering[(Int, Double)] { + override def compare(x: (Int, Double), y: (Int, Double)): Int = { + x._1 - y._1 + } + } + + Sorting.quickSort(entryArray)(COOOrdering) + + // calculate size of the array + val arraySize = entryArray.foldLeft((-1, 0)){ case ((lastIndex, numRows), (index, _)) => + if(lastIndex == index) { + (lastIndex, numRows) + } else { + (index, numRows + 1) + } + }._2 + + val indices = new Array[Int](arraySize) + val data = new Array[Double](arraySize) + + val (index, value) = entryArray(0) + + indices(0) = index + data(0) = value + + var i = 1 + var lastIndex = indices(0) + var lastDataIndex = 0 + + while(i < entryArray.length) { + val (curIndex, curValue) = entryArray(i) + + if(curIndex == lastIndex) { + data(lastDataIndex) += curValue + } else { + lastDataIndex += 1 + data(lastDataIndex) = curValue + indices(lastDataIndex) = curIndex + lastIndex = curIndex + } + + i += 1 + } + + new SparseVector(size, indices, data) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Vector.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Vector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Vector.scala index 20d820c..7e7c32c 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Vector.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Vector.scala @@ -18,29 +18,45 @@ package org.apache.flink.ml.math -/** - * Base trait for Vectors - */ +/** Base trait for Vectors + * + */ trait Vector { - /** - * Number of elements in a vector - * @return - */ + /** Number of elements in a vector + * + * @return + */ def size: Int - /** - * Element wise access function - * - * @param index index of the accessed element - * @return element with index - */ + /** Element wise access function + * + * * @param index index of the accessed element + * @return element with index + */ def apply(index: Int): Double - /** - * Copies the vector instance - * - * @return Copy of the vector instance - */ + /** Updates the element at the given index with the provided value + * + * @param index + * @param value + */ + def update(index: Int, value: Double): Unit + + /** Copies the vector instance + * + * @return Copy of the vector instance + */ def copy: Vector + + override def equals(obj: Any): Boolean = { + obj match { + case vector: Vector if size == vector.size => + 0 until size forall { idx => + this(idx) == vector(idx) + } + + case _ => false + } + } } http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala index e82e38f..3ab6143 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala @@ -27,7 +27,7 @@ package object math { override def iterator: Iterator[Double] = { matrix match { - case dense: DenseMatrix => dense.values.iterator + case dense: DenseMatrix => dense.data.iterator } } } @@ -35,14 +35,14 @@ package object math { implicit class RichVector(vector: Vector) extends Iterable[Double] { override def iterator: Iterator[Double] = { vector match { - case dense: DenseVector => dense.values.iterator + case dense: DenseVector => dense.data.iterator } } } implicit def vector2Array(vector: Vector): Array[Double] = { vector match { - case dense: DenseVector => dense.values + case dense: DenseVector => dense.data } } } http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixSuite.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixSuite.scala deleted file mode 100644 index be5db08..0000000 --- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixSuite.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.math - -import org.scalatest.FlatSpec - -class DenseMatrixSuite extends FlatSpec { - - behavior of "A DenseMatrix" - - it should "contain the initialization data after intialization" in { - val numRows = 10 - val numCols = 13 - - val data = Array.range(0, numRows*numCols) - - val matrix = DenseMatrix(numRows, numCols, data) - - assertResult(numRows)(matrix.numRows) - assertResult(numCols)(matrix.numCols) - - for(row <- 0 until numRows; col <- 0 until numCols) { - assertResult(data(col*numRows + row))(matrix(row, col)) - } - } - - it should "throw an IllegalArgumentException in case of an invalid index access" in { - val numRows = 10 - val numCols = 13 - - val matrix = DenseMatrix.zeros(numRows, numCols) - - intercept[IllegalArgumentException] { - matrix(-1, 2) - } - - intercept[IllegalArgumentException] { - matrix(0, -1) - } - - intercept[IllegalArgumentException] { - matrix(numRows, 0) - } - - intercept[IllegalArgumentException] { - matrix(0, numCols) - } - - intercept[IllegalArgumentException] { - matrix(numRows, numCols) - } - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixTest.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixTest.scala new file mode 100644 index 0000000..12001fc --- /dev/null +++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixTest.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.math + +import org.junit.Test +import org.scalatest.ShouldMatchers + +class DenseMatrixTest extends ShouldMatchers { + + @Test + def testDataAfterInitialization: Unit = { + val numRows = 10 + val numCols = 13 + + val data = Array.range(0, numRows*numCols) + + val matrix = DenseMatrix(numRows, numCols, data) + + assertResult(numRows)(matrix.numRows) + assertResult(numCols)(matrix.numCols) + + for(row <- 0 until numRows; col <- 0 until numCols) { + assertResult(data(col*numRows + row))(matrix(row, col)) + } + } + + @Test + def testIllegalArgumentExceptionInCaseOfInvalidIndexAccess: Unit = { + val numRows = 10 + val numCols = 13 + + val matrix = DenseMatrix.zeros(numRows, numCols) + + intercept[IllegalArgumentException] { + matrix(-1, 2) + } + + intercept[IllegalArgumentException] { + matrix(0, -1) + } + + intercept[IllegalArgumentException] { + matrix(numRows, 0) + } + + intercept[IllegalArgumentException] { + matrix(0, numCols) + } + + intercept[IllegalArgumentException] { + matrix(numRows, numCols) + } + } + + @Test + def testCopy: Unit = { + val numRows = 4 + val numCols = 5 + + val data = Array.range(0, numRows*numCols) + + val denseMatrix = DenseMatrix.apply(numRows, numCols, data) + + val copy = denseMatrix.copy + + + denseMatrix should equal(copy) + + copy(0, 0) = 1 + + denseMatrix should not equal(copy) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorSuite.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorSuite.scala deleted file mode 100644 index ae1e012..0000000 --- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorSuite.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.math - -import org.scalatest.FlatSpec - -class DenseVectorSuite extends FlatSpec { - - behavior of "A DenseVector" - - it should "contain the initialization data after initialization" in { - val data = Array.range(1,10) - - val vector = DenseVector(data) - - assertResult(data.length)(vector.size) - - data.zip(vector).foreach{case (expected, actual) => assertResult(expected)(actual)} - } - - it should "throw an IllegalArgumentException in case of an illegal index access" in { - val size = 10 - - val vector = DenseVector.zeros(size) - - intercept[IllegalArgumentException] { - vector(-1) - } - - intercept[IllegalArgumentException] { - vector(size) - } - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala new file mode 100644 index 0000000..5da9fe2 --- /dev/null +++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.math + +import org.junit.Test +import org.scalatest.ShouldMatchers + + +class DenseVectorTest extends ShouldMatchers { + + @Test + def testDataAfterInitialization { + val data = Array.range(1,10) + + val vector = DenseVector(data) + + assertResult(data.length)(vector.size) + + data.zip(vector).foreach{case (expected, actual) => assertResult(expected)(actual)} + } + + @Test + def testIllegalArgumentExceptionInCaseOfIllegalIndexAccess { + val size = 10 + + val vector = DenseVector.zeros(size) + + intercept[IllegalArgumentException] { + vector(-1) + } + + intercept[IllegalArgumentException] { + vector(size) + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala new file mode 100644 index 0000000..a0e1d27 --- /dev/null +++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.math + +import org.junit.Test +import org.scalatest.ShouldMatchers + +class SparseMatrixTest extends ShouldMatchers { + + @Test + def testSparseMatrixFromCOO: Unit = { + val sparseMatrix = SparseMatrix.fromCOO(5, 5, (0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17), + (3, 3, 88), (4 , 2, 99), (1, 4, 91), (3, 4, -1)) + + val expectedSparseMatrix = SparseMatrix.fromCOO(5, 5, (3, 4, 42), (2, 1, 17), (3, 3, 88), + (4, 2, 99), (1, 4, 91)) + + val expectedDenseMatrix = DenseMatrix.zeros(5, 5) + expectedDenseMatrix(3, 4) = 42 + expectedDenseMatrix(2, 1) = 17 + expectedDenseMatrix(3, 3) = 88 + expectedDenseMatrix(4, 2) = 99 + expectedDenseMatrix(1, 4) = 91 + + sparseMatrix should equal(expectedSparseMatrix) + sparseMatrix should equal(expectedDenseMatrix) + + sparseMatrix.toDenseMatrix.data.sameElements(expectedDenseMatrix.data) should be(true) + + sparseMatrix(0, 1) = 10 + + intercept[IllegalArgumentException]{ + sparseMatrix(1, 1) = 1 + } + } + + @Test + def testInvalidIndexAccess: Unit = { + val sparseVector = SparseVector.fromCOO(5, (1, 1), (3, 3), (4, 4)) + + intercept[IllegalArgumentException] { + sparseVector(-1) + } + + intercept[IllegalArgumentException] { + sparseVector(5) + } + + sparseVector(0) should equal(0) + sparseVector(3) should equal(3) + } + + @Test + def testSparseMatrixFromCOOWithInvalidIndices: Unit = { + intercept[IllegalArgumentException]{ + val sparseMatrix = SparseMatrix.fromCOO(5 ,5, (5, 0, 10), (0, 0, 0), (0, 1, 0), (3, 4, 43), + (2, 1, 17)) + } + + intercept[IllegalArgumentException]{ + val sparseMatrix = SparseMatrix.fromCOO(5, 5, (0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17), + (-1, 4, 20)) + } + } + + @Test + def testSparseMatrixCopy: Unit = { + val sparseMatrix = SparseMatrix.fromCOO(4, 4, (0, 1, 2), (2, 3, 1), (2, 0, 42), (1, 3, 3)) + + val copy = sparseMatrix.copy + + sparseMatrix should equal(copy) + + copy(2, 3) = 2 + + sparseMatrix should not equal(copy) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala new file mode 100644 index 0000000..5e514c6 --- /dev/null +++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.math + +import org.junit.Test +import org.scalatest.ShouldMatchers + +class SparseVectorTest extends ShouldMatchers{ + + @Test + def testDataAfterInitialization: Unit = { + val sparseVector = SparseVector.fromCOO(5, (0, 1), (2, 0), (4, 42), (0, 3)) + val expectedSparseVector = SparseVector.fromCOO(5, (0, 4), (4, 42)) + val expectedDenseVector = DenseVector.zeros(5) + + expectedDenseVector(0) = 4 + expectedDenseVector(4) = 42 + + sparseVector should equal(expectedSparseVector) + sparseVector should equal(expectedDenseVector) + + val denseVector = sparseVector.toDenseVector + + denseVector should equal(expectedDenseVector) + } + + @Test + def testInvalidIndexAccess: Unit = { + val sparseVector = SparseVector.fromCOO(5, (0, 1), (4, 10), (3, 5)) + + intercept[IllegalArgumentException] { + sparseVector(-1) + } + + intercept[IllegalArgumentException] { + sparseVector(5) + } + } + + @Test + def testSparseVectorFromCOOWithInvalidIndices: Unit = { + intercept[IllegalArgumentException] { + val sparseVector = SparseVector.fromCOO(5, (0, 1), (-1, 34), (3, 2)) + } + + intercept[IllegalArgumentException] { + val sparseVector = SparseVector.fromCOO(5, (0, 1), (4,3), (5, 1)) + } + } + + @Test + def testSparseVectorCopy: Unit = { + val sparseVector = SparseVector.fromCOO(5, (0, 1), (4, 3), (3, 2)) + + val copy = sparseVector.copy + + sparseVector should equal(copy) + + copy(3) = 3 + + sparseVector should not equal(copy) + } +}