http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala new file mode 100644 index 0000000..77d2d46 --- /dev/null +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common + +import scala.collection.mutable + +/** + * Map used to store configuration parameters for algorithms. The parameter + * values are stored in a [[Map]] being identified by a [[Parameter]] object. ParameterMaps can + * be fused. This operation is left associative, meaning that latter ParameterMaps can override + * parameter values defined in a preceding ParameterMap. 
+ * + * @param map Map containing parameter settings + */ +class ParameterMap(val map: mutable.Map[Parameter[_], Any]) extends Serializable { + + def this() = { + this(new mutable.HashMap[Parameter[_], Any]()) + } + + /** + * Adds a new parameter value to the ParameterMap. + * + * @param parameter Key + * @param value Value associated with the given key + * @tparam T Type of value + */ + def add[T](parameter: Parameter[T], value: T): ParameterMap = { + map += (parameter -> value) + this + } + + /** + * Retrieves a parameter value associated to a given key. The value is returned as an Option. + * If there is no value associated to the given key, then the default value of the [[Parameter]] + * is returned. + * + * @param parameter Key + * @tparam T Type of the value to retrieve + * @return Some(value) if an value is associated to the given key, otherwise the default value + * defined by parameter + */ + def get[T](parameter: Parameter[T]): Option[T] = { + if(map.isDefinedAt(parameter)) { + map.get(parameter).asInstanceOf[Option[T]] + } else { + parameter.defaultValue + } + } + + /** + * Retrieves a parameter value associated to a given key. If there is no value contained in the + * map, then the default value of the [[Parameter]] is checked. If the default value is defined, + * then it is returned. If the default is undefined, then a [[NoSuchElementException]] is thrown. + * + * @param parameter Key + * @tparam T Type of value + * @return Value associated with the given key or its default value + */ + def apply[T](parameter: Parameter[T]): T = { + if(map.isDefinedAt(parameter)) { + map(parameter).asInstanceOf[T] + } else { + parameter.defaultValue match { + case Some(value) => value + case None => throw new NoSuchElementException(s"Could not retrieve " + + s"parameter value $parameter.") + } + } + } + + /** + * Adds the parameter values contained in parameters to itself. 
+ * + * @param parameters [[ParameterMap]] containing the parameter values to be added + * @return this after inserting the parameter values from parameters + */ + def ++(parameters: ParameterMap): ParameterMap = { + val result = new ParameterMap(map) + result.map ++= parameters.map + + result + } +} + +object ParameterMap { + val Empty = new ParameterMap + + def apply(): ParameterMap = { + new ParameterMap + } +} + +/** + * Base trait for parameter keys + * + * @tparam T Type of parameter value associated to this parameter key + */ +trait Parameter[T] { + + /** + * Default value of parameter. If no such value exists, then returns [[None]] + */ + val defaultValue: Option[T] +}
http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/WeightVector.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/WeightVector.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/WeightVector.scala new file mode 100644 index 0000000..4628c71 --- /dev/null +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/WeightVector.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common + +import org.apache.flink.ml.math.Vector + +// TODO(tvas): This provides an abstraction for the weights +// but at the same time it leads to the creation of many objects as we have to pack and unpack +// the weights and the intercept often during SGD. 
/** Weight vector bundled with an intercept term, as it is required for many supervised
  * learning tasks.
  *
  * @param weights The vector of weights
  * @param intercept The intercept (bias) weight
  */
case class WeightVector(weights: Vector, intercept: Double) extends Serializable
http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/WithParameters.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/WithParameters.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/WithParameters.scala
new file mode 100644
index 0000000..24ac9e3
--- /dev/null
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/WithParameters.scala
@@ -0,0 +1,26 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
+ */ + +package org.apache.flink.ml.common + +/** + * Adds a [[ParameterMap]] which can be used to store configuration values + */ +trait WithParameters { + val parameters = new ParameterMap +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/BLAS.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/BLAS.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/BLAS.scala new file mode 100644 index 0000000..8ea3b65 --- /dev/null +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/BLAS.scala @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.math + +import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS} +import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS} + +/** + * BLAS routines for vectors and matrices. 
+ * + * Original code from the Apache Spark project: + * http://git.io/vfZUe + */ +object BLAS extends Serializable { + + @transient private var _f2jBLAS: NetlibBLAS = _ + @transient private var _nativeBLAS: NetlibBLAS = _ + + // For level-1 routines, we use Java implementation. + private def f2jBLAS: NetlibBLAS = { + if (_f2jBLAS == null) { + _f2jBLAS = new F2jBLAS + } + _f2jBLAS + } + + /** + * y += a * x + */ + def axpy(a: Double, x: Vector, y: Vector): Unit = { + require(x.size == y.size) + y match { + case dy: DenseVector => + x match { + case sx: SparseVector => + axpy(a, sx, dy) + case dx: DenseVector => + axpy(a, dx, dy) + case _ => + throw new UnsupportedOperationException( + s"axpy doesn't support x type ${x.getClass}.") + } + case _ => + throw new IllegalArgumentException( + s"axpy only supports adding to a dense vector but got type ${y.getClass}.") + } + } + + /** + * y += a * x + */ + private def axpy(a: Double, x: DenseVector, y: DenseVector): Unit = { + val n = x.size + f2jBLAS.daxpy(n, a, x.data, 1, y.data, 1) + } + + /** + * y += a * x + */ + private def axpy(a: Double, x: SparseVector, y: DenseVector): Unit = { + val xValues = x.data + val xIndices = x.indices + val yValues = y.data + val nnz = xIndices.size + + if (a == 1.0) { + var k = 0 + while (k < nnz) { + yValues(xIndices(k)) += xValues(k) + k += 1 + } + } else { + var k = 0 + while (k < nnz) { + yValues(xIndices(k)) += a * xValues(k) + k += 1 + } + } + } + + /** + * dot(x, y) + */ + def dot(x: Vector, y: Vector): Double = { + require(x.size == y.size, + "BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes:" + + " x.size = " + x.size + ", y.size = " + y.size) + (x, y) match { + case (dx: DenseVector, dy: DenseVector) => + dot(dx, dy) + case (sx: SparseVector, dy: DenseVector) => + dot(sx, dy) + case (dx: DenseVector, sy: SparseVector) => + dot(sy, dx) + case (sx: SparseVector, sy: SparseVector) => + dot(sx, sy) + case _ => + throw new IllegalArgumentException(s"dot 
doesn't support (${x.getClass}, ${y.getClass}).") + } + } + + /** + * dot(x, y) + */ + private def dot(x: DenseVector, y: DenseVector): Double = { + val n = x.size + f2jBLAS.ddot(n, x.data, 1, y.data, 1) + } + + /** + * dot(x, y) + */ + private def dot(x: SparseVector, y: DenseVector): Double = { + val xValues = x.data + val xIndices = x.indices + val yValues = y.data + val nnz = xIndices.size + + var sum = 0.0 + var k = 0 + while (k < nnz) { + sum += xValues(k) * yValues(xIndices(k)) + k += 1 + } + sum + } + + /** + * dot(x, y) + */ + private def dot(x: SparseVector, y: SparseVector): Double = { + val xValues = x.data + val xIndices = x.indices + val yValues = y.data + val yIndices = y.indices + val nnzx = xIndices.size + val nnzy = yIndices.size + + var kx = 0 + var ky = 0 + var sum = 0.0 + // y catching x + while (kx < nnzx && ky < nnzy) { + val ix = xIndices(kx) + while (ky < nnzy && yIndices(ky) < ix) { + ky += 1 + } + if (ky < nnzy && yIndices(ky) == ix) { + sum += xValues(kx) * yValues(ky) + ky += 1 + } + kx += 1 + } + sum + } + + /** + * y = x + */ + def copy(x: Vector, y: Vector): Unit = { + val n = y.size + require(x.size == n) + y match { + case dy: DenseVector => + x match { + case sx: SparseVector => + val sxIndices = sx.indices + val sxValues = sx.data + val dyValues = dy.data + val nnz = sxIndices.size + + var i = 0 + var k = 0 + while (k < nnz) { + val j = sxIndices(k) + while (i < j) { + dyValues(i) = 0.0 + i += 1 + } + dyValues(i) = sxValues(k) + i += 1 + k += 1 + } + while (i < n) { + dyValues(i) = 0.0 + i += 1 + } + case dx: DenseVector => + Array.copy(dx.data, 0, dy.data, 0, n) + } + case _ => + throw new IllegalArgumentException(s"y must be dense in copy but got ${y.getClass}") + } + } + + /** + * x = a * x + */ + def scal(a: Double, x: Vector): Unit = { + x match { + case sx: SparseVector => + f2jBLAS.dscal(sx.data.size, a, sx.data, 1) + case dx: DenseVector => + f2jBLAS.dscal(dx.data.size, a, dx.data, 1) + case _ => + throw new 
IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.") + } + } + + // For level-3 routines, we use the native BLAS. + private def nativeBLAS: NetlibBLAS = { + if (_nativeBLAS == null) { + _nativeBLAS = NativeBLAS + } + _nativeBLAS + } + + /** + * A := alpha * x * x^T^ + A + * @param alpha a real scalar that will be multiplied to x * x^T^. + * @param x the vector x that contains the n elements. + * @param A the symmetric matrix A. Size of n x n. + */ + def syr(alpha: Double, x: Vector, A: DenseMatrix) { + val mA = A.numRows + val nA = A.numCols + require(mA == nA, s"A is not a square matrix (and hence is not symmetric). A: $mA x $nA") + require(mA == x.size, s"The size of x doesn't match the rank of A. A: $mA x $nA, x: ${x.size}") + + x match { + case dv: DenseVector => syr(alpha, dv, A) + case sv: SparseVector => syr(alpha, sv, A) + case _ => + throw new IllegalArgumentException(s"syr doesn't support vector type ${x.getClass}.") + } + } + + private def syr(alpha: Double, x: DenseVector, A: DenseMatrix) { + val nA = A.numRows + val mA = A.numCols + + nativeBLAS.dsyr("U", x.size, alpha, x.data, 1, A.data, nA) + + // Fill lower triangular part of A + var i = 0 + while (i < mA) { + var j = i + 1 + while (j < nA) { + A(j, i) = A(i, j) + j += 1 + } + i += 1 + } + } + + private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) { + val mA = A.numCols + val xIndices = x.indices + val xValues = x.data + val nnz = xValues.length + val Avalues = A.data + + var i = 0 + while (i < nnz) { + val multiplier = alpha * xValues(i) + val offset = xIndices(i) * mA + var j = 0 + while (j < nnz) { + Avalues(xIndices(j) + offset) += multiplier * xValues(j) + j += 1 + } + i += 1 + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala new file mode 100644 index 0000000..74d4d8f --- /dev/null +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.math + +import breeze.linalg.{ Matrix => BreezeMatrix, DenseMatrix => BreezeDenseMatrix, +CSCMatrix => BreezeCSCMatrix, DenseVector => BreezeDenseVector, SparseVector => BreezeSparseVector, +Vector => BreezeVector} + +/** This class contains convenience function to wrap a matrix/vector into a breeze matrix/vector + * and to unwrap it again. 
+ * + */ +object Breeze { + + implicit class Matrix2BreezeConverter(matrix: Matrix) { + def asBreeze: BreezeMatrix[Double] = { + matrix match { + case dense: DenseMatrix => + new BreezeDenseMatrix[Double]( + dense.numRows, + dense.numCols, + dense.data) + + case sparse: SparseMatrix => + new BreezeCSCMatrix[Double]( + sparse.data, + sparse.numRows, + sparse.numCols, + sparse.colPtrs, + sparse.rowIndices + ) + } + } + } + + implicit class Breeze2MatrixConverter(matrix: BreezeMatrix[Double]) { + def fromBreeze: Matrix = { + matrix match { + case dense: BreezeDenseMatrix[Double] => + new DenseMatrix(dense.rows, dense.cols, dense.data) + + case sparse: BreezeCSCMatrix[Double] => + new SparseMatrix(sparse.rows, sparse.cols, sparse.rowIndices, sparse.colPtrs, sparse.data) + } + } + } + + implicit class BreezeArrayConverter[T](array: Array[T]) { + def asBreeze: BreezeDenseVector[T] = { + new BreezeDenseVector[T](array) + } + } + + implicit class Breeze2VectorConverter(vector: BreezeVector[Double]) { + def fromBreeze[T <: Vector: BreezeVectorConverter]: T = { + val converter = implicitly[BreezeVectorConverter[T]] + converter.convert(vector) + } + } + + implicit class Vector2BreezeConverter(vector: Vector) { + def asBreeze: BreezeVector[Double] = { + vector match { + case dense: DenseVector => + new breeze.linalg.DenseVector(dense.data) + + case sparse: SparseVector => + new BreezeSparseVector(sparse.indices, sparse.data, sparse.size) + } + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/BreezeVectorConverter.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/BreezeVectorConverter.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/BreezeVectorConverter.scala new file mode 100644 index 0000000..0bb24f3 --- /dev/null +++ 
b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/BreezeVectorConverter.scala
@@ -0,0 +1,34 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.ml.math

import breeze.linalg.{Vector => BreezeVector}

/** Type class describing how a Breeze vector is turned into a Flink vector.
  *
  * @tparam T Flink vector type produced by the conversion
  */
trait BreezeVectorConverter[T <: Vector] extends Serializable {

  /** Builds a Flink vector of type T from the given Breeze vector.
    *
    * @param vector Breeze vector to convert
    * @return Flink vector of type T
    */
  def convert(vector: BreezeVector[Double]): T
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
new file mode 100644
index 0000000..4ae565e
--- /dev/null
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
@@ -0,0 +1,193 @@
/*
 * Licensed to the Apache 
Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.ml.math

/**
 * Dense matrix implementation of [[Matrix]]. Stores data in column major order in a continuous
 * double array.
 *
 * @param numRows Number of rows
 * @param numCols Number of columns
 * @param data Array of matrix elements in column major order
 */
case class DenseMatrix(
    numRows: Int,
    numCols: Int,
    data: Array[Double])
  extends Matrix
  with Serializable {

  import DenseMatrix._

  require(numRows * numCols == data.length, s"The number of values ${data.length} does " +
    s"not correspond to its dimensions ($numRows, $numCols).")

  /**
   * Element wise access function
   *
   * @param row row index
   * @param col column index
   * @return matrix entry at (row, col)
   */
  override def apply(row: Int, col: Int): Double = {
    val index = locate(row, col)

    data(index)
  }

  /**
   * Renders at most MAX_ROWS rows and as many columns as fit into LINE_WIDTH characters;
   * omitted rows/columns are indicated by "...".
   */
  override def toString: String = {
    val result = StringBuilder.newBuilder
    result.append(s"DenseMatrix($numRows, $numCols)\n")

    val linewidth = LINE_WIDTH

    // Per printed row: how many columns fit into the line and the widest field seen.
    val columnsFieldWidths = for(row <- 0 until math.min(numRows, MAX_ROWS)) yield {
      var column = 0
      var maxFieldWidth = 0

      while(column * maxFieldWidth < linewidth && column < numCols) {
        val fieldWidth = printEntry(row, column).length + 2

        if(fieldWidth > maxFieldWidth) {
          maxFieldWidth = fieldWidth
        }

        if(column * maxFieldWidth < linewidth) {
          column += 1
        }
      }

      (column, maxFieldWidth)
    }

    val (columns, fieldWidths) = columnsFieldWidths.unzip

    // A matrix without rows yields empty collections; min/max would throw on those.
    val maxColumns = if(columns.isEmpty) 0 else columns.min
    val fieldWidth = if(fieldWidths.isEmpty) 0 else fieldWidths.max

    for(row <- 0 until math.min(numRows, MAX_ROWS)) {
      for(col <- 0 until maxColumns) {
        val str = printEntry(row, col)

        // Right-align every entry within the common field width.
        result.append(" " * (fieldWidth - str.length) + str)
      }

      if(maxColumns < numCols) {
        result.append("...")
      }

      result.append("\n")
    }

    if(numRows > MAX_ROWS) {
      result.append("...\n")
    }

    result.toString()
  }

  /** Two dense matrices are equal if their dimensions and all their elements are equal. */
  override def equals(obj: Any): Boolean = {
    obj match {
      case dense: DenseMatrix =>
        numRows == dense.numRows && numCols == dense.numCols && data.sameElements(dense.data)
      case _ => false
    }
  }

  /** Hash code over the dimensions and the contents of the data array. */
  override def hashCode: Int = {
    val hashCodes = List(numRows.hashCode(), numCols.hashCode(), java.util.Arrays.hashCode(data))

    hashCodes.foldLeft(3){(left, right) => left * 41 + right}
  }

  /** Element wise update function
    *
    * @param row row index
    * @param col column index
    * @param value value to set at (row, col)
    */
  override def update(row: Int, col: Int, value: Double): Unit = {
    val index = locate(row, col)

    data(index) = value
  }

  /** Converts the matrix into a [[SparseMatrix]] containing only the non-zero entries.
    *
    * @return Sparse matrix with the same dimensions and non-zero entries
    */
  def toSparseMatrix: SparseMatrix = {
    val entries = for(row <- 0 until numRows; col <- 0 until numCols) yield {
      (row, col, apply(row, col))
    }

    SparseMatrix.fromCOO(numRows, numCols, entries.filter(_._3 != 0))
  }

  /** Calculates the linear index of the respective matrix entry
    *
    * @param row row index, must be in [0, numRows)
    * @param col column index, must be in [0, numCols)
    * @return index into data (column major order)
    */
  private def locate(row: Int, col: Int): Int = {
    require(0 <= row && row < numRows && 0 <= col && col < numCols,
      (row, col) + " not in [0, " + numRows + ") x [0, " + numCols + ")")

    row + col * numRows
  }

  /** Converts the entry at (row, col) to string
    *
    * @param row row index
    * @param col column index
    * @return string representation of the entry
    */
  private def printEntry(row: Int, col: Int): String = {
    val index = locate(row, col)

    data(index).toString
  }

  /** Copies the matrix instance
    *
    * @return Copy of itself
    */
  override def copy: DenseMatrix = {
    new DenseMatrix(numRows, numCols, data.clone)
  }
}

object DenseMatrix {

  /** Maximum number of characters per line used by toString */
  val LINE_WIDTH = 100
  /** Maximum number of rows printed by toString */
  val MAX_ROWS = 50

  /** Creates a dense matrix from integer values given in column major order. */
  def apply(numRows: Int, numCols: Int, values: Array[Int]): DenseMatrix = {
    new DenseMatrix(numRows, numCols, values.map(_.toDouble))
  }

  /** Creates a dense matrix from the given values in column major order. */
  def apply(numRows: Int, numCols: Int, values: Double*): DenseMatrix = {
    new DenseMatrix(numRows, numCols, values.toArray)
  }

  /** Creates a numRows x numCols matrix filled with zeros. */
  def zeros(numRows: Int, numCols: Int): DenseMatrix = {
    new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(0.0))
  }

  /** Creates a numRows x numCols matrix with ones on the main diagonal and zeros elsewhere
    * (the identity matrix if numRows == numCols).
    */
  def eye(numRows: Int, numCols: Int): DenseMatrix = {
    // new Array[Double] is zero-initialized; only the diagonal has to be set.
    val values = new Array[Double](numRows * numCols)
    val diagonalLength = math.min(numRows, numCols)

    var i = 0
    while(i < diagonalLength) {
      // column major order: entry (i, i) lives at i + i * numRows
      values(i + i * numRows) = 1.0
      i += 1
    }

    new DenseMatrix(numRows, numCols, values)
  }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
new file mode 100644
index 0000000..5e70741
--- /dev/null
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
@@ -0,0 +1,184 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. 
You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.ml.math

import breeze.linalg.{SparseVector => BreezeSparseVector, DenseVector => BreezeDenseVector, Vector => BreezeVector}

/**
 * Dense vector implementation of [[Vector]]. The data is represented in a continuous array of
 * doubles.
 *
 * @param data Array of doubles to store the vector elements
 */
case class DenseVector(
    data: Array[Double])
  extends Vector
  with Serializable {

  /**
   * Number of elements in a vector
   * @return length of the data array
   */
  override def size: Int = {
    data.length
  }

  /**
   * Element wise access function
   *
   * @param index index of the accessed element
   * @return element at the given index
   */
  override def apply(index: Int): Double = {
    require(0 <= index && index < data.length, index + " not in [0, " + data.length + ")")
    data(index)
  }

  override def toString: String = {
    s"DenseVector(${data.mkString(", ")})"
  }

  /** Two dense vectors are equal if they have the same length and the same elements. */
  override def equals(obj: Any): Boolean = {
    obj match {
      case dense: DenseVector => data.length == dense.data.length && data.sameElements(dense.data)
      case _ => false
    }
  }

  /** Hash code over the contents of the data array. */
  override def hashCode: Int = {
    java.util.Arrays.hashCode(data)
  }

  /**
   * Copies the vector instance
   *
   * @return Copy of the vector instance
   */
  override def copy: DenseVector = {
    DenseVector(data.clone())
  }

  /** Updates the element at the given index with the provided value
    *
    * @param index Index whose value is updated.
    * @param value The value used to update the index.
    */
  override def update(index: Int, value: Double): Unit = {
    require(0 <= index && index < data.length, index + " not in [0, " + data.length + ")")

    data(index) = value
  }

  /** Returns the dot product of the recipient and the argument
    *
    * @param other a Vector
    * @return a scalar double of dot product
    */
  override def dot(other: Vector): Double = {
    require(size == other.size, "The size of vector must be equal.")

    other match {
      case SparseVector(_, otherIndices, otherData) =>
        // Only the stored entries of the sparse vector can contribute to the sum.
        otherIndices.zipWithIndex.map {
          case (idx, sparseIdx) => data(idx) * otherData(sparseIdx)
        }.sum
      case _ => (0 until size).map(i => data(i) * other(i)).sum
    }
  }

  /** Returns the outer product (a.k.a. Kronecker product) of `this`
    * with `other`. The result will given in [[org.apache.flink.ml.math.SparseMatrix]]
    * representation if `other` is sparse and as [[org.apache.flink.ml.math.DenseMatrix]] otherwise.
    *
    * @param other a Vector
    * @return the [[org.apache.flink.ml.math.Matrix]] which equals the outer product of `this`
    *         with `other.`, i.e. entry (i, j) equals this(i) * other(j)
    */
  override def outer(other: Vector): Matrix = {
    val numRows = size
    val numCols = other.size

    other match {
      case sv: SparseVector =>
        val entries = for {
          i <- 0 until numRows
          (j, k) <- sv.indices.zipWithIndex
          value = this(i) * sv.data(k)
          if value != 0
        } yield (i, j, value)

        SparseMatrix.fromCOO(numRows, numCols, entries)
      case _ =>
        // DenseMatrix expects its values in column major order, so the column index has to be
        // the outer (slowest varying) generator. Iterating rows first would yield a transposed
        // (or, for non-square shapes, scrambled) result.
        val values = for {
          j <- 0 until numCols
          i <- 0 until numRows
        } yield this(i) * other(j)

        DenseMatrix(numRows, numCols, values.toArray)
    }
  }

  /** Magnitude of a vector
    *
    * @return square root of the sum of the squared elements
    */
  override def magnitude: Double = math.sqrt(data.map(x => x * x).sum)

  /** Converts the vector into a [[SparseVector]] containing only the non-zero elements.
    *
    * @return Sparse vector with the same size and non-zero elements
    */
  def toSparseVector: SparseVector = {
    val nonZero = (0 until size).zip(data).filter(_._2 != 0)

    SparseVector.fromCOO(size, nonZero)
  }
}

object DenseVector {

  /** Creates a dense vector from the given values. */
  def apply(values: Double*): DenseVector = {
    new DenseVector(values.toArray)
  }

  /** Creates a dense vector from integer values. */
  def apply(values: Array[Int]): DenseVector = {
    new DenseVector(values.map(_.toDouble))
  }

  /** Creates a vector of the given size filled with zeros. */
  def zeros(size: Int): DenseVector = {
    init(size, 0.0)
  }

  /** Creates a vector of the given size filled with ones. */
  def eye(size: Int): DenseVector = {
    init(size, 1.0)
  }

  /** Creates a vector of the given size with every element set to value. */
  def init(size: Int, value: Double): DenseVector = {
    new DenseVector(Array.fill(size)(value))
  }

  /** BreezeVectorConverter implementation for [[org.apache.flink.ml.math.DenseVector]]
    *
    * This allows to convert Breeze vectors into [[DenseVector]].
    */
  implicit val denseVectorConverter: BreezeVectorConverter[DenseVector] =
    new BreezeVectorConverter[DenseVector] {
      override def convert(vector: BreezeVector[Double]): DenseVector = {
        vector match {
          // NOTE(review): dense.data is the backing array of the Breeze vector; this presumably
          // assumes a vector without offset/stride -- confirm for sliced Breeze vectors.
          case dense: BreezeDenseVector[Double] => new DenseVector(dense.data)
          case sparse: BreezeSparseVector[Double] => new DenseVector(sparse.toDenseVector.data)
        }
      }
    }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/Matrix.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/Matrix.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/Matrix.scala
new file mode 100644
index 0000000..ba6a781
--- /dev/null
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/math/Matrix.scala
@@ -0,0 +1,69 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. 
/** Base trait for a matrix representation.
 *
 * Implementations expose their dimensions, element-wise read and write access and a way to
 * copy themselves.
 */
trait Matrix {

  /** Number of rows
   *
   * @return number of rows of the matrix
   */
  def numRows: Int

  /** Number of columns
   *
   * @return number of columns of the matrix
   */
  def numCols: Int

  /** Element wise access function
   *
   * @param row row index
   * @param col column index
   * @return matrix entry at (row, col)
   */
  def apply(row: Int, col: Int): Double

  /** Element wise update function
   *
   * @param row row index
   * @param col column index
   * @param value value to set at (row, col)
   */
  def update(row: Int, col: Int, value: Double): Unit

  /** Copies the matrix instance
   *
   * @return Copy of itself
   */
  def copy: Matrix

  /** Compares this matrix entry by entry with another matrix, independent of the concrete
   * representation of either operand.
   *
   * @param matrix matrix to compare with
   * @return true if both matrices have the same dimensions and every pair of entries at the
   *         same position compares equal with `==`, false otherwise
   */
  def equalsMatrix(matrix: Matrix): Boolean = {
    val sameShape = numRows == matrix.numRows && numCols == matrix.numCols

    // Entries are visited row by row; the first differing entry ends the comparison.
    sameShape && (0 until numRows).forall { row =>
      (0 until numCols).forall { col => this(row, col) == matrix(row, col) }
    }
  }

}
import scala.util.Sorting

/** Sparse matrix using the compressed sparse column (CSC) representation.
 *
 * More details concerning the compressed sparse column (CSC) representation can be found
 * [http://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_column_.28CSC_or_CCS.29].
 *
 * @param numRows Number of rows
 * @param numCols Number of columns
 * @param rowIndices Array containing the row indices of non-zero entries
 * @param colPtrs Array containing the starting offsets in data for each column
 * @param data Array containing the non-zero entries in column-major order
 */
class SparseMatrix(
    val numRows: Int,
    val numCols: Int,
    val rowIndices: Array[Int],
    val colPtrs: Array[Int],
    val data: Array[Double])
  extends Matrix
  with Serializable {

  /** Element wise access function
   *
   * @param row row index
   * @param col column index
   * @return matrix entry at (row, col); 0 if no entry is stored at that position
   */
  override def apply(row: Int, col: Int): Double = {
    val index = locate(row, col)

    if (index < 0) 0.0 else data(index)
  }

  /** Converts this matrix into a [[DenseMatrix]] holding the same entries.
   *
   * @return dense matrix with the same dimensions and values
   */
  def toDenseMatrix: DenseMatrix = {
    val result = DenseMatrix.zeros(numRows, numCols)

    for (row <- 0 until numRows; col <- 0 until numCols) {
      result(row, col) = apply(row, col)
    }

    result
  }

  /** Element wise update function
   *
   * Only positions which already have a stored entry can be updated.
   *
   * @param row row index
   * @param col column index
   * @param value value to set at (row, col)
   * @throws IllegalArgumentException if no entry is stored at (row, col)
   */
  override def update(row: Int, col: Int, value: Double): Unit = {
    val index = locate(row, col)

    if (index < 0) {
      throw new IllegalArgumentException("Cannot update zero value of sparse matrix at index " +
        s"($row, $col)")
    } else {
      data(index) = value
    }
  }

  /** Renders the dimensions followed by one line per stored entry in the form
   * `(row,column) value`, with right-aligned fields.
   */
  override def toString: String = {
    val result = new StringBuilder

    result.append(s"SparseMatrix($numRows, $numCols)\n")

    val fieldWidth = math.max(numRows, numCols).toString.length
    // foldLeft instead of max so that a matrix without stored entries does not throw
    val valueFieldWidth =
      data.foldLeft(0)((width, value) => math.max(width, value.toString.length)) + 2

    def rightAlign(text: String, width: Int): String = " " * (width - text.length) + text

    var columnIndex = 0

    for (index <- 0 until colPtrs.last) {
      // advance to the column which contains the entry at position index
      while (colPtrs(columnIndex + 1) <= index) {
        columnIndex += 1
      }

      result.append("(" + rightAlign(rowIndices(index).toString, fieldWidth) + "," +
        rightAlign(columnIndex.toString, fieldWidth) + ")")
      result.append(rightAlign(data(index).toString, valueFieldWidth))
      result.append("\n")
    }

    result.toString
  }

  /** Two sparse matrices are equal if they have the same dimensions and identical CSC arrays. */
  override def equals(obj: Any): Boolean = {
    obj match {
      case sm: SparseMatrix if numRows == sm.numRows && numCols == sm.numCols =>
        rowIndices.sameElements(sm.rowIndices) && colPtrs.sameElements(sm.colPtrs) &&
          data.sameElements(sm.data)
      case _ => false
    }
  }

  override def hashCode: Int = {
    val hashCodes = List(numRows.hashCode(), numCols.hashCode(),
      java.util.Arrays.hashCode(rowIndices), java.util.Arrays.hashCode(colPtrs),
      java.util.Arrays.hashCode(data))

    hashCodes.foldLeft(5){(left, right) => left * 41 + right}
  }

  /** Finds the position of the entry (row, col) in the data array.
   *
   * @return index into data if an entry is stored at (row, col), otherwise a negative value
   *         as returned by `java.util.Arrays.binarySearch`
   */
  private def locate(row: Int, col: Int): Int = {
    require(0 <= row && row < numRows && 0 <= col && col < numCols,
      s"($row,$col) not in [0, $numRows) x [0, $numCols)")

    // the row indices of column col are stored in rowIndices[colPtrs(col), colPtrs(col + 1))
    val startIndex = colPtrs(col)
    val endIndex = colPtrs(col + 1)

    java.util.Arrays.binarySearch(rowIndices, startIndex, endIndex, row)
  }

  /** Copies the matrix instance
   *
   * @return Copy of itself
   */
  override def copy: SparseMatrix = {
    new SparseMatrix(numRows, numCols, rowIndices.clone, colPtrs.clone(), data.clone)
  }
}

object SparseMatrix {

  /** Constructs a sparse matrix from a coordinate list (COO) representation where each entry
   * is stored as a tuple of (rowIndex, columnIndex, value).
   *
   * @param numRows Number of rows
   * @param numCols Number of columns
   * @param entries Entries given as (rowIndex, columnIndex, value) tuples
   * @return Sparse matrix containing the given entries
   */
  def fromCOO(numRows: Int, numCols: Int, entries: (Int, Int, Double)*): SparseMatrix = {
    fromCOO(numRows, numCols, entries)
  }

  /** Constructs a sparse matrix from a coordinate list (COO) representation where each entry
   * is stored as a tuple of (rowIndex, columnIndex, value). Values of entries with identical
   * coordinates are added up. An empty entry list yields a matrix without stored entries.
   *
   * @param numRows Number of rows
   * @param numCols Number of columns
   * @param entries Entries given as (rowIndex, columnIndex, value) tuples
   * @return Sparse matrix containing the given entries
   * @throws IllegalArgumentException if an entry lies outside of
   *                                  [0, numRows) x [0, numCols)
   */
  def fromCOO(numRows: Int, numCols: Int, entries: Iterable[(Int, Int, Double)]): SparseMatrix = {
    val entryArray = entries.toArray

    entryArray.foreach { case (row, col, _) =>
      // the column bound is exclusive: an entry in column numCols would end up behind the
      // last column pointer
      require(0 <= row && row < numRows && 0 <= col && col < numCols,
        s"($row,$col) not in [0, $numRows) x [0, $numCols)")
    }

    if (entryArray.isEmpty) {
      // no stored entries: all column pointers stay 0
      new SparseMatrix(numRows, numCols, Array.empty[Int], new Array[Int](numCols + 1),
        Array.empty[Double])
    } else {
      // column-major order: sort by column first and by row second
      Sorting.quickSort(entryArray)(
        Ordering.by((entry: (Int, Int, Double)) => (entry._2, entry._1)))

      val nnz = entryArray.length

      val data = new Array[Double](nnz)
      val rowIndices = new Array[Int](nnz)
      val colPtrs = new Array[Int](numCols + 1)

      val (firstRow, firstCol, firstValue) = entryArray(0)

      rowIndices(0) = firstRow
      data(0) = firstValue

      var lastRow = firstRow
      var lastCol = firstCol
      var lastDataIndex = 0
      var i = 1

      while (i < nnz) {
        val (curRow, curCol, curValue) = entryArray(i)

        if (lastRow == curRow && lastCol == curCol) {
          // add values with identical coordinates
          data(lastDataIndex) += curValue
        } else {
          lastDataIndex += 1
          data(lastDataIndex) = curValue
          rowIndices(lastDataIndex) = curRow
          lastRow = curRow
        }

        // every column up to the current one starts at the current data position
        while (lastCol < curCol) {
          lastCol += 1
          colPtrs(lastCol) = lastDataIndex
        }

        i += 1
      }

      val storedEntries = lastDataIndex + 1

      // the remaining column pointers mark the end of the data
      while (lastCol < numCols) {
        lastCol += 1
        colPtrs(lastCol) = storedEntries
      }

      // shrink the arrays if duplicate coordinates were merged
      val prunedRowIndices =
        if (storedEntries < nnz) java.util.Arrays.copyOf(rowIndices, storedEntries)
        else rowIndices

      val prunedData =
        if (storedEntries < nnz) java.util.Arrays.copyOf(data, storedEntries)
        else data

      new SparseMatrix(numRows, numCols, prunedRowIndices, colPtrs, prunedData)
    }
  }

  /** Convenience method to convert a single tuple with an integer value into a SparseMatrix.
   * The problem is that providing a single tuple to the fromCOO method, the Scala type inference
   * cannot infer that the tuple has to be of type (Int, Int, Double) because of the overloading
   * with the Iterable type.
   *
   * @param numRows Number of rows
   * @param numCols Number of columns
   * @param entry Single entry given as (rowIndex, columnIndex, value)
   * @return Sparse matrix containing the given entry
   */
  def fromCOO(numRows: Int, numCols: Int, entry: (Int, Int, Int)): SparseMatrix = {
    fromCOO(numRows, numCols, (entry._1, entry._2, entry._3.toDouble))
  }
}
import breeze.linalg.{SparseVector => BreezeSparseVector, DenseVector => BreezeDenseVector, Vector => BreezeVector}

import scala.util.Sorting

/** Sparse vector implementation storing the data in two arrays. One array contains the sorted
 * indices of the non-zero vector entries and the other the corresponding vector entries.
 *
 * @param size Number of elements of the vector
 * @param indices Sorted indices of the stored entries
 * @param data Values of the stored entries, in the order given by indices
 */
case class SparseVector(
    size: Int,
    indices: Array[Int],
    data: Array[Double])
  extends Vector
  with Serializable {

  /** Updates the element at the given index with the provided value
   *
   * Only indices which already have a stored entry can be updated.
   *
   * @param index Index whose value is updated.
   * @param value The value used to update the index.
   * @throws IllegalArgumentException if no entry is stored at index
   */
  override def update(index: Int, value: Double): Unit = {
    val resolvedIndex = locate(index)

    if (resolvedIndex < 0) {
      throw new IllegalArgumentException("Cannot update zero value of sparse vector at index " +
        index)
    } else {
      data(resolvedIndex) = value
    }
  }

  /** Copies the vector instance
   *
   * @return Copy of the vector instance
   */
  override def copy: SparseVector = {
    new SparseVector(size, indices.clone, data.clone)
  }

  /** Returns the dot product of the recipient and the argument
   *
   * @param other a Vector
   * @return a scalar double of dot product
   */
  override def dot(other: Vector): Double = {
    require(size == other.size, "The size of vector must be equal.")

    other match {
      case DenseVector(otherData) =>
        var result = 0.0
        var k = 0

        while (k < indices.length) {
          result += data(k) * otherData(indices(k))
          k += 1
        }

        result

      case SparseVector(_, otherIndices, otherData) =>
        // merge the two sorted index arrays; only shared indices contribute
        var left = 0
        var right = 0
        var result = 0.0

        while (left < indices.length && right < otherIndices.length) {
          if (indices(left) < otherIndices(right)) {
            left += 1
          } else if (otherIndices(right) < indices(left)) {
            right += 1
          } else {
            result += data(left) * otherData(right)
            left += 1
            right += 1
          }
        }

        result

      case _ =>
        // any other Vector implementation: use its element wise access function
        var result = 0.0
        var k = 0

        while (k < indices.length) {
          result += data(k) * other(indices(k))
          k += 1
        }

        result
    }
  }

  /** Returns the outer product (a.k.a. Kronecker product) of `this`
   * with `other`. The result is given in [[org.apache.flink.ml.math.SparseMatrix]]
   * representation.
   *
   * @param other a Vector
   * @return the [[org.apache.flink.ml.math.SparseMatrix]] which equals the outer product of `this`
   *         with `other.`
   */
  override def outer(other: Vector): SparseMatrix = {
    val numRows = size
    val numCols = other.size

    val entries = other match {
      case sv: SparseVector =>
        for {
          (i, k) <- indices.zipWithIndex
          (j, l) <- sv.indices.zipWithIndex
          value = data(k) * sv.data(l)
          if value != 0
        } yield (i, j, value)
      case _ =>
        for {
          (i, k) <- indices.zipWithIndex
          j <- 0 until numCols
          value = data(k) * other(j)
          if value != 0
        } yield (i, j, value)
    }

    SparseMatrix.fromCOO(numRows, numCols, entries)
  }

  /** Magnitude of a vector
   *
   * @return square root of the sum of the squared stored values
   */
  override def magnitude: Double = math.sqrt(data.map(x => x * x).sum)

  /** Element wise access function
   *
   * @param index index of the accessed element
   * @return element with index; 0 if no entry is stored at index
   */
  override def apply(index: Int): Double = {
    val resolvedIndex = locate(index)

    if (resolvedIndex < 0) 0.0 else data(resolvedIndex)
  }

  /** Converts this vector into a [[DenseVector]] holding the same values.
   *
   * @return dense vector of the same size
   */
  def toDenseVector: DenseVector = {
    val denseVector = DenseVector.zeros(size)

    // only the stored entries have to be written, everything else is already zero
    var k = 0

    while (k < indices.length) {
      denseVector(indices(k)) = data(k)
      k += 1
    }

    denseVector
  }

  /** Two sparse vectors are equal if they have the same size, indices and values. */
  override def equals(obj: Any): Boolean = {
    obj match {
      case sv: SparseVector if size == sv.size =>
        indices.sameElements(sv.indices) && data.sameElements(sv.data)
      case _ => false
    }
  }

  override def hashCode: Int = {
    val hashCodes = List(size.hashCode, java.util.Arrays.hashCode(indices),
      java.util.Arrays.hashCode(data))

    hashCodes.foldLeft(3){ (left, right) => left * 41 + right}
  }

  override def toString: String = {
    val entries = indices.zip(data).mkString(", ")
    "SparseVector(" + entries + ")"
  }

  /** Finds the position of index in the indices array.
   *
   * @return position of the stored entry, or a negative value as returned by
   *         `java.util.Arrays.binarySearch` if no entry is stored at index
   */
  private def locate(index: Int): Int = {
    require(0 <= index && index < size, s"$index not in [0, $size)")

    java.util.Arrays.binarySearch(indices, 0, indices.length, index)
  }
}

object SparseVector {

  /** Constructs a sparse vector from a coordinate list (COO) representation where each entry
   * is stored as a tuple of (index, value).
   *
   * @param size Number of elements of the vector
   * @param entries Entries given as (index, value) tuples
   * @return Sparse vector containing the given entries
   */
  def fromCOO(size: Int, entries: (Int, Double)*): SparseVector = {
    fromCOO(size, entries)
  }

  /** Constructs a sparse vector from a coordinate list (COO) representation where each entry
   * is stored as a tuple of (index, value). Values of entries with identical indices are added
   * up. An empty entry list yields a vector without stored entries.
   *
   * @param size Number of elements of the vector
   * @param entries Entries given as (index, value) tuples
   * @return Sparse vector containing the given entries
   * @throws IllegalArgumentException if an index lies outside of [0, size)
   */
  def fromCOO(size: Int, entries: Iterable[(Int, Double)]): SparseVector = {
    val entryArray = entries.toArray

    entryArray.foreach { case (index, _) =>
      require(0 <= index && index < size, s"$index not in [0, $size)")
    }

    Sorting.quickSort(entryArray)(Ordering.by((entry: (Int, Double)) => entry._1))

    // upper bound for the number of stored entries; trimmed after merging duplicates
    val indices = new Array[Int](entryArray.length)
    val data = new Array[Double](entryArray.length)

    var storedEntries = 0
    var i = 0

    while (i < entryArray.length) {
      val (curIndex, curValue) = entryArray(i)

      if (storedEntries > 0 && indices(storedEntries - 1) == curIndex) {
        // add values with identical indices
        data(storedEntries - 1) += curValue
      } else {
        indices(storedEntries) = curIndex
        data(storedEntries) = curValue
        storedEntries += 1
      }

      i += 1
    }

    new SparseVector(
      size,
      java.util.Arrays.copyOf(indices, storedEntries),
      java.util.Arrays.copyOf(data, storedEntries))
  }

  /** Convenience method to be able to instantiate a SparseVector with a single element. The Scala
   * type inference mechanism cannot infer that the second tuple value has to be of type Double
   * if only a single tuple is provided.
   *
   * @param size Number of elements of the vector
   * @param entry Single entry given as (index, value)
   * @return Sparse vector containing the given entry
   */
  def fromCOO(size: Int, entry: (Int, Int)): SparseVector = {
    fromCOO(size, (entry._1, entry._2.toDouble))
  }

  /** BreezeVectorConverter implementation for [[org.apache.flink.ml.math.SparseVector]]
   *
   * This allows to convert Breeze vectors into [[SparseVector]]
   */
  implicit val sparseVectorConverter: BreezeVectorConverter[SparseVector] =
    new BreezeVectorConverter[SparseVector] {
      override def convert(vector: BreezeVector[Double]): SparseVector = {
        vector match {
          case dense: BreezeDenseVector[Double] =>
            SparseVector.fromCOO(
              dense.length,
              dense.iterator.toIterable)
          case sparse: BreezeSparseVector[Double] =>
            // only the first `used` slots of the Breeze arrays hold entries
            new SparseVector(
              sparse.length,
              sparse.index.take(sparse.used),
              sparse.data.take(sparse.used))
        }
      }
    }
}
import breeze.linalg.{SparseVector => BreezeSparseVector, DenseVector => BreezeDenseVector, Vector => BreezeVector}

/** Base trait for Vectors
 *
 */
trait Vector extends Serializable {

  /** Number of elements in a vector
   *
   * @return number of elements
   */
  def size: Int

  /** Element wise access function
   *
   * @param index index of the accessed element
   * @return element with index
   */
  def apply(index: Int): Double

  /** Updates the element at the given index with the provided value
   *
   * @param index index of the element to update
   * @param value new value of the element
   */
  def update(index: Int, value: Double): Unit

  /** Copies the vector instance
   *
   * @return Copy of the vector instance
   */
  def copy: Vector

  /** Returns the dot product of the recipient and the argument
   *
   * @param other a Vector
   * @return a scalar double of dot product
   */
  def dot(other: Vector): Double

  /** Returns the outer product of the recipient and the argument
   *
   * @param other a Vector
   * @return a matrix
   */
  def outer(other: Vector): Matrix

  /** Magnitude of a vector
   *
   * @return magnitude of the vector
   */
  def magnitude: Double

  /** Compares this vector element by element with another vector, independent of the concrete
   * representation of either operand.
   *
   * @param vector vector to compare with
   * @return true if both vectors have the same size and every pair of elements at the same
   *         index compares equal with `==`, false otherwise
   */
  def equalsVector(vector: Vector): Boolean = {
    size == vector.size && (0 until size).forall { idx => this(idx) == vector(idx) }
  }
}

object Vector {

  /** BreezeVectorConverter implementation for [[Vector]]
   *
   * This allows to convert Breeze vectors into [[Vector]].
   */
  implicit val vectorConverter: BreezeVectorConverter[Vector] =
    new BreezeVectorConverter[Vector] {
      override def convert(vector: BreezeVector[Double]): Vector = {
        vector match {
          case dense: BreezeDenseVector[Double] =>
            // A Breeze dense vector can be a view (slice) onto a larger backing array. The
            // backing array is only reused if it represents exactly this vector, otherwise the
            // elements are copied out.
            val coversBackingArray =
              dense.offset == 0 && dense.stride == 1 && dense.length == dense.data.length

            new DenseVector(if (coversBackingArray) dense.data else dense.toArray)

          case sparse: BreezeSparseVector[Double] =>
            // only the first `used` slots of the Breeze arrays hold entries
            new SparseVector(
              sparse.length,
              sparse.index.take(sparse.used),
              sparse.data.take(sparse.used))
        }
      }
    }
}
/** Type class to allow the vector construction from different data types
 *
 * @tparam T Subtype of [[Vector]]
 */
trait VectorBuilder[T <: Vector] extends Serializable {

  /** Builds a [[Vector]] of type T from a List[Double]
   *
   * @param data Input data where the index denotes the resulting index of the vector
   * @return A vector of type T
   */
  def build(data: List[Double]): T
}

object VectorBuilder {

  /** Type class implementation for [[org.apache.flink.ml.math.DenseVector]] */
  implicit val denseVectorBuilder: VectorBuilder[DenseVector] =
    new VectorBuilder[DenseVector] {
      override def build(data: List[Double]): DenseVector = new DenseVector(data.toArray)
    }

  /** Type class implementation for [[org.apache.flink.ml.math.SparseVector]] */
  implicit val sparseVectorBuilder: VectorBuilder[SparseVector] =
    new VectorBuilder[SparseVector] {
      override def build(data: List[Double]): SparseVector = {
        // Pair every value with its position and keep only the non-zero ones
        val nonZeroEntries = data.zipWithIndex.collect {
          case (value, index) if value != 0.0 => (index, value)
        }

        SparseVector.fromCOO(data.length, nonZeroEntries)
      }
    }

  /** Type class implementation for [[Vector]]; builds the dense representation */
  implicit val vectorBuilder: VectorBuilder[Vector] =
    new VectorBuilder[Vector] {
      override def build(data: List[Double]): Vector = new DenseVector(data.toArray)
    }
}
/**
 * Convenience methods to handle Flink's [[org.apache.flink.ml.math.Matrix]] and [[Vector]]
 * abstraction.
 */
package object math {

  /** Makes a [[Matrix]] iterable as (row, column, value) triples in column-major order.
   *
   * @param matrix matrix to iterate over
   */
  implicit class RichMatrix(matrix: Matrix) extends Iterable[(Int, Int, Double)] {

    override def iterator: Iterator[(Int, Int, Double)] = {
      new Iterator[(Int, Int, Double)] {
        private var index = 0

        override def hasNext: Boolean = {
          index < matrix.numRows * matrix.numCols
        }

        override def next(): (Int, Int, Double) = {
          if (!hasNext) {
            throw new NoSuchElementException("next on exhausted matrix iterator")
          }

          // the linear index runs down the rows of a column before moving to the next column
          val row = index % matrix.numRows
          val column = index / matrix.numRows

          index += 1

          (row, column, matrix(row, column))
        }
      }
    }

    /** Iterator over the matrix values only, in the same order as [[iterator]]. */
    def valueIterator: Iterator[Double] = iterator.map(_._3)

  }

  /** Makes a [[Vector]] iterable as (index, value) pairs in ascending index order.
   *
   * @param vector vector to iterate over
   */
  implicit class RichVector(vector: Vector) extends Iterable[(Int, Double)] {

    override def iterator: Iterator[(Int, Double)] = {
      new Iterator[(Int, Double)] {
        private var index = 0

        override def hasNext: Boolean = {
          index < vector.size
        }

        override def next(): (Int, Double) = {
          if (!hasNext) {
            throw new NoSuchElementException("next on exhausted vector iterator")
          }

          val resultIndex = index

          index += 1

          (resultIndex, vector(resultIndex))
        }
      }
    }

    /** Iterator over the vector values only, in the same order as [[iterator]]. */
    def valueIterator: Iterator[Double] = iterator.map(_._2)
  }

  /** Stores the vector values in a dense array
   *
   * @param vector vector whose values are copied
   * @return Newly allocated array containing the vector values
   */
  def vector2Array(vector: Vector): Array[Double] = {
    vector match {
      case dense: DenseVector => dense.data.clone

      case sparse: SparseVector =>
        val result = new Array[Double](sparse.size)

        // only the stored entries have to be written, the array is zero-initialized
        var k = 0

        while (k < sparse.indices.length) {
          result(sparse.indices(k)) = sparse.data(k)
          k += 1
        }

        result

      case other =>
        // any other Vector implementation: read it through its element wise access function
        Array.tabulate(other.size)(index => other(index))
    }
  }
}
import org.apache.flink.ml.math.Vector

/** This class implements a Chebyshev distance metric. The class calculates the distance between
 * the given vectors by finding the maximum difference between each coordinate.
 *
 * @see http://en.wikipedia.org/wiki/Chebyshev_distance
 */
class ChebyshevDistanceMetric extends DistanceMetric {

  /** Returns the largest absolute coordinate difference between the arguments.
   *
   * @param a a Vector defining a multi-dimensional point in some space
   * @param b a Vector defining a multi-dimensional point in some space
   * @return the Chebyshev distance; 0 for two vectors of size 0
   */
  override def distance(a: Vector, b: Vector): Double = {
    checkValidArguments(a, b)

    // Absolute differences are never negative, so folding from 0.0 yields the maximum and
    // also covers vectors without any coordinates.
    (0 until a.size).foldLeft(0.0) { (maxDifference, i) =>
      math.max(maxDifference, math.abs(a(i) - b(i)))
    }
  }
}

object ChebyshevDistanceMetric {
  def apply(): ChebyshevDistanceMetric = new ChebyshevDistanceMetric()
}
import org.apache.flink.ml.math.Vector

/** This class implements a cosine distance metric. The class calculates the distance between
 * the given vectors by dividing the dot product of two vectors by the product of their lengths.
 * We convert the result of division to a usable distance. So, 1 - cos(angle) is actually returned.
 *
 * @see http://en.wikipedia.org/wiki/Cosine_similarity
 */
class CosineDistanceMetric extends DistanceMetric {

  /** Returns 1 - cos(angle) for the angle between the arguments.
   *
   * @param a a Vector defining a multi-dimensional point in some space
   * @param b a Vector defining a multi-dimensional point in some space
   * @return the cosine distance; 0 if both the dot product and the product of the magnitudes
   *         are 0
   */
  override def distance(a: Vector, b: Vector): Double = {
    checkValidArguments(a, b)

    val dotProd: Double = a.dot(b)
    val denominator: Double = a.magnitude * b.magnitude

    // the quotient is undefined when both terms vanish
    val undefinedAngle = dotProd == 0 && denominator == 0

    if (undefinedAngle) 0.0 else 1 - dotProd / denominator
  }
}

object CosineDistanceMetric {
  def apply(): CosineDistanceMetric = new CosineDistanceMetric()
}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.metrics.distances + +import org.apache.flink.ml.math.Vector + +/** Interface for objects which determine the distance between two points. + */ +trait DistanceMetric extends Serializable { + /** Computes the distance between the two given points. + * + * @param a a Vector defining a multi-dimensional point in some space + * @param b a Vector defining a multi-dimensional point in some space + * @return the distance as a scalar double + */ + def distance(a: Vector, b: Vector): Double + + /** Requires both vectors to have the same size; `require` fails with an + * IllegalArgumentException otherwise. + */ + protected def checkValidArguments(a: Vector, b: Vector): Unit = + require(a.size == b.size, "The each size of vectors must be same to calculate distance.") +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/EuclideanDistanceMetric.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/EuclideanDistanceMetric.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/EuclideanDistanceMetric.scala new file mode 100644 index 0000000..153fb93 --- /dev/null +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/EuclideanDistanceMetric.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.metrics.distances + +import org.apache.flink.ml.math.Vector + +/** This class implements a Euclidean distance metric. The metric calculates the distance between + * the given two vectors by taking the square root of the sum of the squared differences between + * each coordinate. + * + * http://en.wikipedia.org/wiki/Euclidean_distance + * + * If you don't care about the true distance and only need the value for comparison, + * [[SquaredEuclideanDistanceMetric]] will be faster because it doesn't calculate the actual + * square root of the distances.
+ * + * @see http://en.wikipedia.org/wiki/Euclidean_distance + */ +class EuclideanDistanceMetric extends SquaredEuclideanDistanceMetric { + /** Returns the square root of the value computed by the parent class for the same arguments. */ + override def distance(a: Vector, b: Vector): Double = { + val squaredDistance = super.distance(a, b) + math.sqrt(squaredDistance) + } +} + +object EuclideanDistanceMetric { + /** Creates a new [[EuclideanDistanceMetric]]. */ + def apply() = new EuclideanDistanceMetric() +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/ManhattanDistanceMetric.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/ManhattanDistanceMetric.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/ManhattanDistanceMetric.scala new file mode 100644 index 0000000..5983f79 --- /dev/null +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/ManhattanDistanceMetric.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.metrics.distances + +import org.apache.flink.ml.math.Vector + +/** This class implements a Manhattan distance metric.
The class calculates the distance between + * the given vectors by summing the absolute differences between each coordinate. + * + * @see http://en.wikipedia.org/wiki/Taxicab_geometry + */ +class ManhattanDistanceMetric extends DistanceMetric { + /** Returns the sum of the absolute coordinate-wise differences of the arguments. */ + override def distance(a: Vector, b: Vector): Double = { + checkValidArguments(a, b) + (0 until a.size).foldLeft(0.0)((total, i) => total + math.abs(a(i) - b(i))) + } +} + +object ManhattanDistanceMetric { + /** Creates a new [[ManhattanDistanceMetric]]. */ + def apply() = new ManhattanDistanceMetric() +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/MinkowskiDistanceMetric.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/MinkowskiDistanceMetric.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/MinkowskiDistanceMetric.scala new file mode 100644 index 0000000..50161d4 --- /dev/null +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/MinkowskiDistanceMetric.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.ml.metrics.distances + +import org.apache.flink.ml.math.Vector + +/** Minkowski distance metric, a generalization of the L(p) distances such as the Euclidean + * and the Manhattan distance. For the special cases p = 1 and p = 2, use + * [[ManhattanDistanceMetric]] and [[EuclideanDistanceMetric]] respectively. This class is + * useful for high exponents. + * + * @param p the norm exponent of space + * + * @see http://en.wikipedia.org/wiki/Minkowski_distance + */ +class MinkowskiDistanceMetric(val p: Double) extends DistanceMetric { + /** Returns the p-th root of the sum of the absolute coordinate-wise differences raised to p. */ + override def distance(a: Vector, b: Vector): Double = { + checkValidArguments(a, b) + val poweredDiffs = (0 until a.size).map(i => math.pow(math.abs(a(i) - b(i)), p)) + math.pow(poweredDiffs.sum, 1 / p) + } +} + +object MinkowskiDistanceMetric { + /** Creates a new [[MinkowskiDistanceMetric]] with the given exponent. */ + def apply(p: Double) = new MinkowskiDistanceMetric(p) +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/SquaredEuclideanDistanceMetric.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/SquaredEuclideanDistanceMetric.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/SquaredEuclideanDistanceMetric.scala new file mode 100644 index 0000000..fe546e9 --- /dev/null +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/SquaredEuclideanDistanceMetric.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.metrics.distances + +import org.apache.flink.ml.math.Vector + +/** This class is like [[EuclideanDistanceMetric]] but it does not take the square root. + * + * The value calculated by this class is not the exact Euclidean distance, but it saves on + * computation when you only need the value for comparison. + */ +class SquaredEuclideanDistanceMetric extends DistanceMetric { + /** Returns the sum of the squared coordinate-wise differences of the arguments. */ + override def distance(a: Vector, b: Vector): Double = { + checkValidArguments(a, b) + (0 until a.size).foldLeft(0.0)((acc, i) => acc + math.pow(a(i) - b(i), 2)) + } +} + +object SquaredEuclideanDistanceMetric { + /** Creates a new [[SquaredEuclideanDistanceMetric]]. */ + def apply() = new SquaredEuclideanDistanceMetric() +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/TanimotoDistanceMetric.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/TanimotoDistanceMetric.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/TanimotoDistanceMetric.scala new file mode 100644 index 0000000..5141c98 --- /dev/null +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/metrics/distances/TanimotoDistanceMetric.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.metrics.distances + +import org.apache.flink.ml.math.Vector + +/** This class implements a Tanimoto distance metric. The similarity of the given vectors is + * their dot product divided by the sum of their squared magnitudes minus the dot product. The + * vectors are assumed to be bit-wise vectors. The value returned as the distance is + * 1 - similarity. + * + * @see http://en.wikipedia.org/wiki/Jaccard_index + */ +class TanimotoDistanceMetric extends DistanceMetric { + /** Returns 1 - similarity for the arguments. If the denominator of the similarity is zero + * (e.g. both vectors contain only zeros), 0.0 is returned instead of the NaN that 0 / 0 + * would produce. + * + * @param a a Vector defining a multi-dimensional point in some space + * @param b a Vector of the same size as a + * @return the Tanimoto distance between the arguments + */ + override def distance(a: Vector, b: Vector): Double = { + checkValidArguments(a, b) + + val dotProd: Double = a.dot(b) + val aMagnitude: Double = a.magnitude + val bMagnitude: Double = b.magnitude + val denominator: Double = aMagnitude * aMagnitude + bMagnitude * bMagnitude - dotProd + // Guard the division: a zero denominator would otherwise yield NaN. + if (denominator == 0) { + 0.0 + } else { + 1 - dotProd / denominator + } + } +} + +object TanimotoDistanceMetric { + /** Creates a new [[TanimotoDistanceMetric]]. */ + def apply(): TanimotoDistanceMetric = new TanimotoDistanceMetric()
+}