http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/preprocessing/StandardScaler.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/preprocessing/StandardScaler.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/preprocessing/StandardScaler.scala new file mode 100644 index 0000000..82e8abf --- /dev/null +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/preprocessing/StandardScaler.scala @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.preprocessing + +import breeze.linalg +import breeze.numerics.sqrt +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.scala._ +import org.apache.flink.ml.common.{LabeledVector, Parameter, ParameterMap} +import org.apache.flink.ml.math.Breeze._ +import org.apache.flink.ml.math.{BreezeVectorConverter, Vector} +import org.apache.flink.ml.pipeline.{TransformOperation, FitOperation, +Transformer} +import org.apache.flink.ml.preprocessing.StandardScaler.{Mean, Std} + +import scala.reflect.ClassTag + +/** Scales observations so that all features have a user-specified mean and standard deviation. + * By default, the [[StandardScaler]] transformer uses mean = 0.0 and std = 1.0. + * + * This transformer takes a subtype of [[Vector]] of values and maps it to a + * scaled subtype of [[Vector]] such that each feature has a user-specified mean and standard + * deviation. + * + * This transformer can be prepended to all [[Transformer]] and + * [[org.apache.flink.ml.pipeline.Predictor]] implementations which expect as input a subtype + * of [[Vector]]. + * + * @example + * {{{ + * val trainingDS: DataSet[Vector] = env.fromCollection(data) + * val transformer = StandardScaler().setMean(10.0).setStd(2.0) + * + * transformer.fit(trainingDS) + * val transformedDS = transformer.transform(trainingDS) + * }}} + * + * =Parameters= + * + * - [[Mean]]: The mean value of the transformed data set; by default equal to 0 + * - [[Std]]: The standard deviation of the transformed data set; by default + * equal to 1 + */ +class StandardScaler extends Transformer[StandardScaler] { + + private[preprocessing] var metricsOption: Option[ + DataSet[(linalg.Vector[Double], linalg.Vector[Double])] + ] = None + + /** Sets the target mean of the transformed data + * + * @param mu the user-specified mean value.
+ * @return the StandardScaler instance with its mean value set to the user-specified value + */ + def setMean(mu: Double): StandardScaler = { + parameters.add(Mean, mu) + this + } + + /** Sets the target standard deviation of the transformed data + * + * @param std the user-specified std value. In case the user gives 0.0 value as input, + * the std is set to the default value: 1.0. + * @return the StandardScaler instance with its std value set to the user-specified value + */ + def setStd(std: Double): StandardScaler = { + if (std == 0.0) { + return this + } + parameters.add(Std, std) + this + } +} + +object StandardScaler { + + // ====================================== Parameters ============================================= + + case object Mean extends Parameter[Double] { + override val defaultValue: Option[Double] = Some(0.0) + } + + case object Std extends Parameter[Double] { + override val defaultValue: Option[Double] = Some(1.0) + } + + // ==================================== Factory methods ========================================== + + def apply(): StandardScaler = { + new StandardScaler() + } + + // ====================================== Operations ============================================= + + /** Trains the [[org.apache.flink.ml.preprocessing.StandardScaler]] by learning the mean and + * standard deviation of the training data. These values are used in the transform step + * to transform the given input data. + * + * @tparam T Input data type which is a subtype of [[Vector]] + * @return [[FitOperation]] training the [[StandardScaler]] on subtypes of [[Vector]] + */ + implicit def fitVectorStandardScaler[T <: Vector] = new FitOperation[StandardScaler, T] { + override def fit(instance: StandardScaler, fitParameters: ParameterMap, input: DataSet[T]) + : Unit = { + val metrics = extractFeatureMetrics(input) + + instance.metricsOption = Some(metrics) + } + } + + /** Trains the [[StandardScaler]] by learning the mean and standard deviation of the training + * data which is of type [[LabeledVector]]. The mean and standard deviation are used to + * transform the given input data. + * + */ + implicit val fitLabeledVectorStandardScaler = { + new FitOperation[StandardScaler, LabeledVector] { + override def fit( + instance: StandardScaler, + fitParameters: ParameterMap, + input: DataSet[LabeledVector]) + : Unit = { + val vectorDS = input.map(_.vector) + val metrics = extractFeatureMetrics(vectorDS) + + instance.metricsOption = Some(metrics) + } + } + } + + /** Trains the [[StandardScaler]] by learning the mean and standard deviation of the training + * data which is of type ([[Vector]], Double). The mean and standard deviation are used to + * transform the given input data. + * + */ + implicit def fitLabelVectorTupleStandardScaler + [T <: Vector: BreezeVectorConverter: TypeInformation: ClassTag] = { + new FitOperation[StandardScaler, (T, Double)] { + override def fit( + instance: StandardScaler, + fitParameters: ParameterMap, + input: DataSet[(T, Double)]) + : Unit = { + val vectorDS = input.map(_._1) + val metrics = extractFeatureMetrics(vectorDS) + + instance.metricsOption = Some(metrics) + } + } + } + + /** Calculates the features' mean and standard deviation in one pass over the data. + * The standard deviation is computed in the same single pass using the Youngs and Cramer + * algorithm: + * [[http://www.cs.yale.edu/publications/techreports/tr222.pdf]] + * + * @param dataSet The data set for which we want to calculate mean and standard deviation + * @return DataSet containing a single tuple of two vectors (meanVector, stdVector).
+ * The first vector represents the mean vector and the second is the standard + * deviation vector. + */ + private def extractFeatureMetrics[T <: Vector](dataSet: DataSet[T]) + : DataSet[(linalg.Vector[Double], linalg.Vector[Double])] = { + val metrics = dataSet.map{ + v => (1.0, v.asBreeze, linalg.Vector.zeros[Double](v.size)) + }.reduce{ + (metrics1, metrics2) => { + /* We use formula 1.5b of the cited technical report to combine the partial + * sums of squares. According to 1.5b: + * val temp1 : m/(n*(m+n)) + * val temp2 : n/m + */ + val temp1 = metrics1._1 / (metrics2._1 * (metrics1._1 + metrics2._1)) + val temp2 = metrics2._1 / metrics1._1 + val tempVector = (metrics1._2 * temp2) - metrics2._2 + val tempS = (metrics1._3 + metrics2._3) + (tempVector :* tempVector) * temp1 + + (metrics1._1 + metrics2._1, metrics1._2 + metrics2._2, tempS) + } + }.map{ + metric => { + val stdVector = sqrt(metric._3 / metric._1) + + // replace a standard deviation of 0 by 1 to avoid division by zero when scaling + for (i <- 0 until stdVector.size) { + if (stdVector(i) == 0.0) { + stdVector.update(i, 1.0) + } + } + (metric._2 / metric._1, stdVector) + } + } + metrics + } + + /** Base class for StandardScaler's [[TransformOperation]]. This class has to be extended for + * all types which are supported by [[StandardScaler]]'s transform operation. + * + * @tparam T Type of the element which is transformed + */ + abstract class StandardScalerTransformOperation[T: TypeInformation: ClassTag] + extends TransformOperation[ + StandardScaler, + (linalg.Vector[Double], linalg.Vector[Double]), + T, + T] { + + var mean: Double = _ + var std: Double = _ + + override def getModel( + instance: StandardScaler, + transformParameters: ParameterMap) + : DataSet[(linalg.Vector[Double], linalg.Vector[Double])] = { + mean = transformParameters(Mean) + std = transformParameters(Std) + + instance.metricsOption match { + case Some(metrics) => metrics + case None => + throw new RuntimeException("The StandardScaler has not been fitted to the data. " + + "This is necessary to estimate the mean and standard deviation of the data.") + } + } + + def scale[V <: Vector: BreezeVectorConverter]( + vector: V, + model: (linalg.Vector[Double], linalg.Vector[Double])) + : V = { + val (broadcastMean, broadcastStd) = model + var myVector = vector.asBreeze + myVector -= broadcastMean + myVector :/= broadcastStd + myVector = (myVector :* std) + mean + myVector.fromBreeze + } + } + + /** [[TransformOperation]] to transform [[Vector]] types + * + * @tparam T Subtype of [[Vector]] to be transformed + * @return [[StandardScalerTransformOperation]] for subtypes of [[Vector]] + */ + implicit def transformVectors[T <: Vector: BreezeVectorConverter: TypeInformation: ClassTag] = { + new StandardScalerTransformOperation[T]() { + override def transform( + vector: T, + model: (linalg.Vector[Double], linalg.Vector[Double])) + : T = { + scale(vector, model) + } + } + } + + /** [[TransformOperation]] to transform tuples of type ([[Vector]], [[Double]]). + * + * @tparam T Subtype of [[Vector]] contained in the tuple + * @return [[StandardScalerTransformOperation]] for tuples of ([[Vector]], [[Double]]) + */ + implicit def transformTupleVectorDouble[ + T <: Vector: BreezeVectorConverter: TypeInformation: ClassTag] = { + new StandardScalerTransformOperation[(T, Double)] { + override def transform( + element: (T, Double), + model: (linalg.Vector[Double], linalg.Vector[Double])) + : (T, Double) = { + (scale(element._1, model), element._2) + } + } + } + + /** [[TransformOperation]] to transform [[LabeledVector]].
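+ * The label is passed through unchanged; only the feature vector is scaled.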
+ * + */ + implicit val transformLabeledVector = new StandardScalerTransformOperation[LabeledVector] { + override def transform( + element: LabeledVector, + model: (linalg.Vector[Double], linalg.Vector[Double])) + : LabeledVector = { + val LabeledVector(label, vector) = element + + LabeledVector(label, scale(vector, model)) + } + } +}
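As a minimal standalone sketch (not taken from the patch; the helper name `merge` and the scalar setting are illustrative), the combination step that `extractFeatureMetrics` performs on breeze vectors can be written for a single feature as follows. Each partial result is a triple of (count, sum, sum of squared deviations), and formula 1.5b of the cited report merges two partials without a second pass over the data:

def merge(a: (Double, Double, Double), b: (Double, Double, Double)): (Double, Double, Double) = {
  val (cnt1, sum1, ssd1) = a
  val (cnt2, sum2, ssd2) = b
  // formula 1.5b: weighted difference of the partial sums
  val diff = sum1 * (cnt2 / cnt1) - sum2
  // combined sum of squared deviations
  val ssd = ssd1 + ssd2 + diff * diff * cnt1 / (cnt2 * (cnt1 + cnt2))
  (cnt1 + cnt2, sum1 + sum2, ssd)
}

// For the observations 2.0, 4.0 and 6.0 this reduces to (3.0, 12.0, 8.0), i.e.
// mean = 12.0 / 3 = 4.0 and std = math.sqrt(8.0 / 3), the same values the
// single-pass DataSet reduction above computes per feature.
val (cnt, sum, ssd) = List((1.0, 2.0, 0.0), (1.0, 4.0, 0.0), (1.0, 6.0, 0.0)).reduce(merge)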
http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala new file mode 100644 index 0000000..d8af42f --- /dev/null +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala @@ -0,0 +1,1009 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation + +import java.{util, lang} + +import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint +import org.apache.flink.api.scala._ +import org.apache.flink.api.common.operators.Order +import org.apache.flink.core.memory.{DataOutputView, DataInputView} +import org.apache.flink.ml.common._ +import org.apache.flink.ml.pipeline.{FitOperation, PredictDataSetOperation, Predictor} +import org.apache.flink.types.Value +import org.apache.flink.util.Collector +import org.apache.flink.api.common.functions.{Partitioner => FlinkPartitioner, GroupReduceFunction, CoGroupFunction} + +import com.github.fommil.netlib.BLAS.{ getInstance => blas } +import com.github.fommil.netlib.LAPACK.{ getInstance => lapack } +import org.netlib.util.intW + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +/** Alternating least squares algorithm to calculate a matrix factorization. + * + * Given a matrix `R`, ALS calculates two matrices `U` and `V` such that `R ~~ U^TV`. The + * unknown row dimension is given by the number of latent factors. Since matrix factorization + * is often used in the context of recommendation, we'll call the first matrix the user and the + * second matrix the item matrix. The `i`th column of the user matrix is `u_i` and the `i`th + * column of the item matrix is `v_i`. The matrix `R` is called the ratings matrix and + * `(R)_{i,j} = r_{i,j}`. + * + * In order to find the user and item matrix, the following problem is solved: + * + * `argmin_{U,V} sum_(i,j\ with\ r_{i,j} != 0) (r_{i,j} - u_{i}^Tv_{j})^2 + + * lambda (sum_(i) n_{u_i} ||u_i||^2 + sum_(j) n_{v_j} ||v_j||^2)` + * + * with `\lambda` being the regularization factor, `n_{u_i}` being the number of items the user `i` + * has rated and `n_{v_j}` being the number of times the item `j` has been rated. This + * regularization scheme to avoid overfitting is called weighted-lambda-regularization. Details + * can be found in the work of [[http://dx.doi.org/10.1007/978-3-540-68880-8_32 Zhou et al.]].
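+ * + * Concretely, for a fixed item matrix `V` the problem is quadratic in each user vector `u_i`, + * whose minimizer solves the regularized normal equations (a sketch of one alternating step; + * the implementation below assembles exactly this system per user and solves it with LAPACK's + * `dposv`): + * + * `(sum_(j\ with\ r_{i,j} != 0) v_jv_j^T + lambda n_{u_i} I) u_i = + * sum_(j\ with\ r_{i,j} != 0) r_{i,j}v_j`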
+ * + * By fixing one of the matrices `U` or `V`, one obtains a quadratic form which can be solved + * exactly. The solution of the modified problem is guaranteed to decrease the overall cost + * function. By applying this step alternately to the matrices `U` and `V`, we can iteratively + * improve the matrix factorization. + * + * The matrix `R` is given in its sparse representation as a tuple of `(i, j, r)` where `i` is the + * row index, `j` is the column index and `r` is the matrix value at position `(i,j)`. + * + * @example + * {{{ + * val inputDS: DataSet[(Int, Int, Double)] = env.readCsvFile[(Int, Int, Double)]( + * pathToTrainingFile) + * + * val als = ALS() + * .setIterations(10) + * .setNumFactors(10) + * + * als.fit(inputDS) + * + * val data2Predict: DataSet[(Int, Int)] = env.readCsvFile[(Int, Int)](pathToData) + * + * als.predict(data2Predict) + * }}} + * + * =Parameters= + * + * - [[org.apache.flink.ml.recommendation.ALS.NumFactors]]: + * The number of latent factors. It is the dimension of the calculated user and item vectors. + * (Default value: '''10''') + * + * - [[org.apache.flink.ml.recommendation.ALS.Lambda]]: + * Regularization factor. Tune this value to control overfitting and improve generalization. + * (Default value: '''1''') + * + * - [[org.apache.flink.ml.recommendation.ALS.Iterations]]: + * The number of iterations to perform. (Default value: '''10''') + * + * - [[org.apache.flink.ml.recommendation.ALS.Blocks]]: + * The number of blocks into which the user and item matrix are grouped. The fewer + * blocks one uses, the less data is sent redundantly. However, bigger blocks entail bigger + * update messages which have to be stored on the heap. If the algorithm fails because of + * an OutOfMemoryError, then try to increase the number of blocks. (Default value: '''None''') + * + * - [[org.apache.flink.ml.recommendation.ALS.Seed]]: + * Random seed used to generate the initial item matrix for the algorithm. + * (Default value: '''0''') + * + * - [[org.apache.flink.ml.recommendation.ALS.TemporaryPath]]: + * Path to a temporary directory into which intermediate results are stored. If + * this value is set, then the algorithm is split into two preprocessing steps, the ALS iteration, + * and a post-processing step which calculates a last ALS half-step. The preprocessing steps + * calculate the [[org.apache.flink.ml.recommendation.ALS.OutBlockInformation]] and + * [[org.apache.flink.ml.recommendation.ALS.InBlockInformation]] for the given rating matrix. + * The results of the individual steps are stored in the specified directory. By splitting the + * algorithm into multiple smaller steps, Flink does not have to split the available memory + * amongst too many operators. This allows the system to process bigger individual messages and + * improves the overall performance. (Default value: '''None''') + * + * The ALS implementation is based on Spark's MLLib implementation of ALS which you can find + * [[https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/mllib/ + * recommendation/ALS.scala here]].
+ */ +class ALS extends Predictor[ALS] { + + import ALS._ + + // Stores the matrix factorization after the fitting phase + var factorsOption: Option[(DataSet[Factors], DataSet[Factors])] = None + + /** Sets the number of latent factors/row dimension of the latent model + * + * @param numFactors + * @return + */ + def setNumFactors(numFactors: Int): ALS = { + parameters.add(NumFactors, numFactors) + this + } + + /** Sets the regularization coefficient lambda + * + * @param lambda + * @return + */ + def setLambda(lambda: Double): ALS = { + parameters.add(Lambda, lambda) + this + } + + /** Sets the number of iterations of the ALS algorithm + * + * @param iterations + * @return + */ + def setIterations(iterations: Int): ALS = { + parameters.add(Iterations, iterations) + this + } + + /** Sets the number of blocks into which the user and item matrix shall be partitioned + * + * @param blocks + * @return + */ + def setBlocks(blocks: Int): ALS = { + parameters.add(Blocks, blocks) + this + } + + /** Sets the random seed for the initial item matrix initialization + * + * @param seed + * @return + */ + def setSeed(seed: Long): ALS = { + parameters.add(Seed, seed) + this + } + + /** Sets the temporary path into which intermediate results are written in order to increase + * performance. + * + * @param temporaryPath + * @return + */ + def setTemporaryPath(temporaryPath: String): ALS = { + parameters.add(TemporaryPath, temporaryPath) + this + } + + /** Empirical risk of the trained model (matrix factorization). + * + * @param labeledData Reference data + * @param riskParameters Additional parameters for the empirical risk calculation + * @return + */ + def empiricalRisk( + labeledData: DataSet[(Int, Int, Double)], + riskParameters: ParameterMap = ParameterMap.Empty) + : DataSet[Double] = { + val resultingParameters = parameters ++ riskParameters + + val lambda = resultingParameters(Lambda) + + val data = labeledData map { + x => (x._1, x._2) + } + + factorsOption match { + case Some((userFactors, itemFactors)) => { + val predictions = data.join(userFactors, JoinHint.REPARTITION_HASH_SECOND).where(0) + .equalTo(0).join(itemFactors, JoinHint.REPARTITION_HASH_SECOND).where("_1._2") + .equalTo(0).map { + triple => { + val (((uID, iID), uFactors), iFactors) = triple + + val uFactorsVector = uFactors.factors + val iFactorsVector = iFactors.factors + + val squaredUNorm2 = blas.ddot( + uFactorsVector.length, + uFactorsVector, + 1, + uFactorsVector, + 1) + val squaredINorm2 = blas.ddot( + iFactorsVector.length, + iFactorsVector, + 1, + iFactorsVector, + 1) + + val prediction = blas.ddot(uFactorsVector.length, uFactorsVector, 1, iFactorsVector, 1) + + (uID, iID, prediction, squaredUNorm2, squaredINorm2) + } + } + + labeledData.join(predictions).where(0,1).equalTo(0,1) { + (left, right) => { + val (_, _, expected) = left + val (_, _, predicted, squaredUNorm2, squaredINorm2) = right + + val residual = expected - predicted + + residual * residual + lambda * (squaredUNorm2 + squaredINorm2) + } + } reduce { + _ + _ + } + } + + case None => throw new RuntimeException("The ALS model has not been fitted to data. 
" + + "Prior to predicting values, it has to be trained on data.") + } + } +} + +object ALS { + val USER_FACTORS_FILE = "userFactorsFile" + val ITEM_FACTORS_FILE = "itemFactorsFile" + + // ========================================= Parameters ========================================== + + case object NumFactors extends Parameter[Int] { + val defaultValue: Option[Int] = Some(10) + } + + case object Lambda extends Parameter[Double] { + val defaultValue: Option[Double] = Some(1.0) + } + + case object Iterations extends Parameter[Int] { + val defaultValue: Option[Int] = Some(10) + } + + case object Blocks extends Parameter[Int] { + val defaultValue: Option[Int] = None + } + + case object Seed extends Parameter[Long] { + val defaultValue: Option[Long] = Some(0L) + } + + case object TemporaryPath extends Parameter[String] { + val defaultValue: Option[String] = None + } + + // ==================================== ALS type definitions ===================================== + + /** Representation of a user-item rating + * + * @param user User ID of the rating user + * @param item Item iD of the rated item + * @param rating Rating value + */ + case class Rating(user: Int, item: Int, rating: Double) + + /** Latent factor model vector + * + * @param id + * @param factors + */ + case class Factors(id: Int, factors: Array[Double]) { + override def toString = s"($id, ${factors.mkString(",")})" + } + + case class Factorization(userFactors: DataSet[Factors], itemFactors: DataSet[Factors]) + + case class OutBlockInformation(elementIDs: Array[Int], outLinks: OutLinks) { + override def toString: String = { + s"OutBlockInformation:((${elementIDs.mkString(",")}), ($outLinks))" + } + } + + class OutLinks(var links: Array[scala.collection.mutable.BitSet]) extends Value { + def this() = this(null) + + override def toString: String = { + s"${links.mkString("\n")}" + } + + override def write(out: DataOutputView): Unit = { + out.writeInt(links.length) + links foreach { + link => { + val bitMask = link.toBitMask + out.writeInt(bitMask.length) + for (element <- bitMask) { + out.writeLong(element) + } + } + } + } + + override def read(in: DataInputView): Unit = { + val length = in.readInt() + links = new Array[scala.collection.mutable.BitSet](length) + + for (i <- 0 until length) { + val bitMaskLength = in.readInt() + val bitMask = new Array[Long](bitMaskLength) + for (j <- 0 until bitMaskLength) { + bitMask(j) = in.readLong() + } + links(i) = mutable.BitSet.fromBitMask(bitMask) + } + } + + def apply(idx: Int) = links(idx) + } + + case class InBlockInformation(elementIDs: Array[Int], ratingsForBlock: Array[BlockRating]) { + + override def toString: String = { + s"InBlockInformation:((${elementIDs.mkString(",")}), (${ratingsForBlock.mkString("\n")}))" + } + } + + case class BlockRating(var ratings: Array[(Array[Int], Array[Double])]) { + def apply(idx: Int) = ratings(idx) + + override def toString: String = { + ratings.map { + case (left, right) => s"((${left.mkString(",")}),(${right.mkString(",")}))" + }.mkString(",") + } + } + + case class BlockedFactorization(userFactors: DataSet[(Int, Array[Array[Double]])], + itemFactors: DataSet[(Int, Array[Array[Double]])]) + + class BlockIDPartitioner extends FlinkPartitioner[Int] { + override def partition(blockID: Int, numberOfPartitions: Int): Int = { + blockID % numberOfPartitions + } + } + + class BlockIDGenerator(blocks: Int) extends Serializable { + def apply(id: Int): Int = { + id % blocks + } + } + + // ================================= Factory methods 
============================================= + + def apply(): ALS = { + new ALS() + } + + // ===================================== Operations ============================================== + + /** Predict operation which calculates the matrix entry for the given indices */ + implicit val predictRating = new PredictDataSetOperation[ALS, (Int, Int), (Int ,Int, Double)] { + override def predictDataSet( + instance: ALS, + predictParameters: ParameterMap, + input: DataSet[(Int, Int)]) + : DataSet[(Int, Int, Double)] = { + + instance.factorsOption match { + case Some((userFactors, itemFactors)) => { + input.join(userFactors, JoinHint.REPARTITION_HASH_SECOND).where(0).equalTo(0) + .join(itemFactors, JoinHint.REPARTITION_HASH_SECOND).where("_1._2").equalTo(0).map { + triple => { + val (((uID, iID), uFactors), iFactors) = triple + + val uFactorsVector = uFactors.factors + val iFactorsVector = iFactors.factors + + val prediction = blas.ddot( + uFactorsVector.length, + uFactorsVector, + 1, + iFactorsVector, + 1) + + (uID, iID, prediction) + } + } + } + + case None => throw new RuntimeException("The ALS model has not been fitted to data. " + + "Prior to predicting values, it has to be trained on data.") + } + } + } + + /** Calculates the matrix factorization for the given ratings. A rating is defined as + * a tuple of user ID, item ID and the corresponding rating. + * + * @return Factorization containing the user and item matrix + */ + implicit val fitALS = new FitOperation[ALS, (Int, Int, Double)] { + override def fit( + instance: ALS, + fitParameters: ParameterMap, + input: DataSet[(Int, Int, Double)]) + : Unit = { + val resultParameters = instance.parameters ++ fitParameters + + val userBlocks = resultParameters.get(Blocks).getOrElse(input.count.toInt) + val itemBlocks = userBlocks + val persistencePath = resultParameters.get(TemporaryPath) + val seed = resultParameters(Seed) + val factors = resultParameters(NumFactors) + val iterations = resultParameters(Iterations) + val lambda = resultParameters(Lambda) + + val ratings = input.map { + entry => { + val (userID, itemID, rating) = entry + Rating(userID, itemID, rating) + } + } + + val blockIDPartitioner = new BlockIDPartitioner() + + val ratingsByUserBlock = ratings.map{ + rating => + val blockID = rating.user % userBlocks + (blockID, rating) + } partitionCustom(blockIDPartitioner, 0) + + val ratingsByItemBlock = ratings map { + rating => + val blockID = rating.item % itemBlocks + (blockID, new Rating(rating.item, rating.user, rating.rating)) + } partitionCustom(blockIDPartitioner, 0) + + val (uIn, uOut) = createBlockInformation(userBlocks, itemBlocks, ratingsByUserBlock, + blockIDPartitioner) + val (iIn, iOut) = createBlockInformation(itemBlocks, userBlocks, ratingsByItemBlock, + blockIDPartitioner) + + val (userIn, userOut) = persistencePath match { + case Some(path) => FlinkMLTools.persist(uIn, uOut, path + "userIn", path + "userOut") + case None => (uIn, uOut) + } + + val (itemIn, itemOut) = persistencePath match { + case Some(path) => FlinkMLTools.persist(iIn, iOut, path + "itemIn", path + "itemOut") + case None => (iIn, iOut) + } + + val initialItems = itemOut.partitionCustom(blockIDPartitioner, 0).map{ + outInfos => + val blockID = outInfos._1 + val infos = outInfos._2 + + (blockID, infos.elementIDs.map{ + id => + val random = new Random(id ^ seed) + randomFactors(factors, random) + }) + }.withForwardedFields("0") + + // iteration to calculate the item matrix + val items = initialItems.iterate(iterations) { + items => { + val users = 
updateFactors(userBlocks, items, itemOut, userIn, factors, lambda, + blockIDPartitioner) + updateFactors(itemBlocks, users, userOut, itemIn, factors, lambda, blockIDPartitioner) + } + } + + val pItems = persistencePath match { + case Some(path) => FlinkMLTools.persist(items, path + "items") + case None => items + } + + // perform last half-step to calculate the user matrix + val users = updateFactors(userBlocks, pItems, itemOut, userIn, factors, lambda, + blockIDPartitioner) + + instance.factorsOption = Some(( + unblock(users, userOut, blockIDPartitioner), + unblock(pItems, itemOut, blockIDPartitioner))) + } + } + + /** Calculates a single half step of the ALS optimization. The result is the new value for + * either the user or item matrix, depending with which matrix the method was called. + * + * @param numUserBlocks Number of blocks in the respective dimension + * @param items Fixed matrix value for the half step + * @param itemOut Out information to know where to send the vectors + * @param userIn In information for the cogroup step + * @param factors Number of latent factors + * @param lambda Regularization constant + * @param blockIDPartitioner Custom Flink partitioner + * @return New value for the optimized matrix (either user or item) + */ + def updateFactors(numUserBlocks: Int, + items: DataSet[(Int, Array[Array[Double]])], + itemOut: DataSet[(Int, OutBlockInformation)], + userIn: DataSet[(Int, InBlockInformation)], + factors: Int, + lambda: Double, blockIDPartitioner: FlinkPartitioner[Int]): + DataSet[(Int, Array[Array[Double]])] = { + // send the item vectors to the blocks whose users have rated the items + val partialBlockMsgs = itemOut.join(items).where(0).equalTo(0). + withPartitioner(blockIDPartitioner).apply { + (left, right, col: Collector[(Int, Int, Array[Array[Double]])]) => { + val blockID = left._1 + val outInfo = left._2 + val factors = right._2 + var userBlock = 0 + var itemIdx = 0 + + while(userBlock < numUserBlocks){ + itemIdx = 0 + val buffer = new ArrayBuffer[Array[Double]] + while(itemIdx < outInfo.elementIDs.length){ + if(outInfo.outLinks(userBlock)(itemIdx)){ + buffer += factors(itemIdx) + } + itemIdx += 1 + } + + if(buffer.nonEmpty){ + // send update message to userBlock + col.collect(userBlock, blockID, buffer.toArray) + } + + userBlock += 1 + } + } + } + + // collect the partial update messages and calculate for each user block the new user vectors + partialBlockMsgs.coGroup(userIn).where(0).equalTo(0).sortFirstGroup(1, Order.ASCENDING). 
+ withPartitioner(blockIDPartitioner).apply{ + new CoGroupFunction[(Int, Int, Array[Array[Double]]), (Int, + InBlockInformation), (Int, Array[Array[Double]])](){ + + // in order to save space, store only the upper triangle of the XtX matrix + val triangleSize = (factors*factors - factors)/2 + factors + val matrix = Array.fill(triangleSize)(0.0) + val fullMatrix = Array.fill(factors * factors)(0.0) + val userXtX = new ArrayBuffer[Array[Double]]() + val userXy = new ArrayBuffer[Array[Double]]() + val numRatings = new ArrayBuffer[Int]() + + override def coGroup(left: lang.Iterable[(Int, Int, Array[Array[Double]])], + right: lang.Iterable[(Int, InBlockInformation)], + collector: Collector[(Int, Array[Array[Double]])]): Unit = { + // there is only one InBlockInformation per user block + val inInfo = right.iterator().next()._2 + val updates = left.iterator() + + val numUsers = inInfo.elementIDs.length + var blockID = -1 + + var i = 0 + + // clear old matrices and allocate new ones + val matricesToClear = if (numUsers > userXtX.length) { + val oldLength = userXtX.length + + while(i < (numUsers - oldLength)) { + userXtX += Array.fill(triangleSize)(0.0) + userXy += Array.fill(factors)(0.0) + numRatings.+=(0) + + i += 1 + } + + oldLength + } else { + numUsers + } + + i = 0 + while(i < matricesToClear){ + numRatings(i) = 0 + + util.Arrays.fill(userXtX(i), 0.0) + util.Arrays.fill(userXy(i), 0.0) + + i += 1 + } + + var itemBlock = 0 + + // build XtX matrices and Xy vector + while(updates.hasNext){ + val update = updates.next() + val blockFactors = update._3 + blockID = update._1 + + var p = 0 + while(p < blockFactors.length){ + val vector = blockFactors(p) + + outerProduct(vector, matrix, factors) + + val (users, ratings) = inInfo.ratingsForBlock(itemBlock)(p) + + var i = 0 + while (i < users.length) { + numRatings(users(i)) += 1 + blas.daxpy(matrix.length, 1, matrix, 1, userXtX(users(i)), 1) + blas.daxpy(vector.length, ratings(i), vector, 1, userXy(users(i)), 1) + + i += 1 + } + p += 1 + } + + itemBlock += 1 + } + + val array = new Array[Array[Double]](numUsers) + + i = 0 + while(i < numUsers){ + generateFullMatrix(userXtX(i), fullMatrix, factors) + + var f = 0 + + // add regularization constant + while(f < factors){ + fullMatrix(f*factors + f) += lambda * numRatings(i) + f += 1 + } + + // calculate new user vector + val result = new intW(0) + lapack.dposv("U", factors, 1, fullMatrix, factors , userXy(i), factors, result) + array(i) = userXy(i) + + i += 1 + } + + collector.collect((blockID, array)) + } + } + }.withForwardedFieldsFirst("0").withForwardedFieldsSecond("0") + } + + /** Creates the meta information needed to route the item and user vectors to the respective user + * and item blocks. 
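+ * The out block information records for every target block which vectors have to be sent + * to it, while the in block information stores for every incoming set of vectors the + * rating users and rating values needed to assemble the update step.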
+ * + * @param userBlocks + * @param itemBlocks + * @param ratings + * @param blockIDPartitioner + * @return + */ + def createBlockInformation(userBlocks: Int, itemBlocks: Int, ratings: DataSet[(Int, Rating)], + blockIDPartitioner: BlockIDPartitioner): + (DataSet[(Int, InBlockInformation)], DataSet[(Int, OutBlockInformation)]) = { + val blockIDGenerator = new BlockIDGenerator(itemBlocks) + + val usersPerBlock = createUsersPerBlock(ratings) + + val outBlockInfos = createOutBlockInformation(ratings, usersPerBlock, itemBlocks, + blockIDGenerator) + + val inBlockInfos = createInBlockInformation(ratings, usersPerBlock, blockIDGenerator) + + (inBlockInfos, outBlockInfos) + } + + /** Calculates the user IDs of each user block in ascending order + * + * @param ratings + * @return + */ + def createUsersPerBlock(ratings: DataSet[(Int, Rating)]): DataSet[(Int, Array[Int])] = { + ratings.map{ x => (x._1, x._2.user)}.withForwardedFields("0").groupBy(0). + sortGroup(1, Order.ASCENDING).reduceGroup { + users => { + val result = ArrayBuffer[Int]() + var id = -1 + var oldUser = -1 + + while(users.hasNext) { + val user = users.next() + + id = user._1 + + if (user._2 != oldUser) { + result.+=(user._2) + oldUser = user._2 + } + } + + val userIDs = result.toArray + (id, userIDs) + } + }.withForwardedFields("0") + } + + /** Creates the outgoing block information + * + * Creates for every user block the outgoing block information. The out block information + * contains for every item block a [[scala.collection.mutable.BitSet]] which indicates which + * user vector has to be sent to this block. If a vector v has to be sent to a block b, then + * bitsets(b)'s bit v is set to 1, otherwise 0. Additionally, the user IDs are replaced by + * the user vector's index value. + * + * @param ratings + * @param usersPerBlock + * @param itemBlocks + * @param blockIDGenerator + * @return + */ + def createOutBlockInformation(ratings: DataSet[(Int, Rating)], + usersPerBlock: DataSet[(Int, Array[Int])], + itemBlocks: Int, blockIDGenerator: BlockIDGenerator): + DataSet[(Int, OutBlockInformation)] = { + ratings.coGroup(usersPerBlock).where(0).equalTo(0).apply { + (ratings, users) => + val userIDs = users.next()._2 + val numUsers = userIDs.length + + val userIDToPos = userIDs.zipWithIndex.toMap + + val shouldDataSend = Array.fill(itemBlocks)(new scala.collection.mutable.BitSet(numUsers)) + var blockID = -1 + while (ratings.hasNext) { + val r = ratings.next() + + val pos = + try { + userIDToPos(r._2.user) + } catch { + case e: NoSuchElementException => + throw new RuntimeException(s"Key ${r._2.user} not found. BlockID $blockID. " + + s"Elements in block ${userIDs.take(5).mkString(", ")}. " + + s"UserIDList contains ${userIDs.contains(r._2.user)}.", e) + } + + blockID = r._1 + shouldDataSend(blockIDGenerator(r._2.item))(pos) = true + } + + (blockID, OutBlockInformation(userIDs, new OutLinks(shouldDataSend))) + }.withForwardedFieldsFirst("0").withForwardedFieldsSecond("0") + } + + /** Creates the incoming block information + * + * Creates for every user block the incoming block information. The incoming block information + * contains the userIDs of the users in the respective block and for every item block a + * BlockRating instance. The BlockRating instance describes for every incoming set of item + * vectors of an item block, which user rated these items and what the rating was. For that + * purpose it contains for every incoming item vector a tuple of an id array us and a rating + * array rs.
The array us contains the indices of the users having rated the respective + * item vector with the ratings in rs. + * + * @param ratings + * @param usersPerBlock + * @param blockIDGenerator + * @return + */ + def createInBlockInformation(ratings: DataSet[(Int, Rating)], + usersPerBlock: DataSet[(Int, Array[Int])], + blockIDGenerator: BlockIDGenerator): + DataSet[(Int, InBlockInformation)] = { + // Group for every user block the users which have rated the same item and collect their ratings + val partialInInfos = ratings.map { x => (x._1, x._2.item, x._2.user, x._2.rating)} + .withForwardedFields("0").groupBy(0, 1).reduceGroup { + x => + var userBlockID = -1 + var itemID = -1 + val userIDs = ArrayBuffer[Int]() + val ratings = ArrayBuffer[Double]() + + while (x.hasNext) { + val (uBlockID, item, user, rating) = x.next + userBlockID = uBlockID + itemID = item + + userIDs += user + ratings += rating + } + + (userBlockID, blockIDGenerator(itemID), itemID, (userIDs.toArray, ratings.toArray)) + }.withForwardedFields("0") + + // Aggregate all ratings for items belonging to the same item block. Sort ascending with + // respect to the itemID, because later the item vectors of the update message are sorted + // accordingly. + val collectedPartialInfos = partialInInfos.groupBy(0, 1).sortGroup(2, Order.ASCENDING). + reduceGroup { + new GroupReduceFunction[(Int, Int, Int, (Array[Int], Array[Double])), (Int, + Int, Array[(Array[Int], Array[Double])])](){ + val buffer = new ArrayBuffer[(Array[Int], Array[Double])] + + override def reduce(iterable: lang.Iterable[(Int, Int, Int, (Array[Int], + Array[Double]))], collector: Collector[(Int, Int, Array[(Array[Int], + Array[Double])])]): Unit = { + + val infos = iterable.iterator() + var counter = 0 + + var blockID = -1 + var itemBlockID = -1 + + while (infos.hasNext && counter < buffer.length) { + val info = infos.next() + blockID = info._1 + itemBlockID = info._2 + + buffer(counter) = info._4 + + counter += 1 + } + + while (infos.hasNext) { + val info = infos.next() + blockID = info._1 + itemBlockID = info._2 + + buffer += info._4 + + counter += 1 + } + + val array = new Array[(Array[Int], Array[Double])](counter) + + buffer.copyToArray(array) + + collector.collect((blockID, itemBlockID, array)) + } + } + }.withForwardedFields("0", "1") + + // Aggregate all item block ratings with respect to their user block ID. Sort the blocks with + // respect to their itemBlockID, because the block update messages are sorted the same way + collectedPartialInfos.coGroup(usersPerBlock).where(0).equalTo(0). 
+ sortFirstGroup(1, Order.ASCENDING).apply{ + new CoGroupFunction[(Int, Int, Array[(Array[Int], Array[Double])]), + (Int, Array[Int]), (Int, InBlockInformation)] { + val buffer = ArrayBuffer[BlockRating]() + + override def coGroup(partialInfosIterable: + lang.Iterable[(Int, Int, Array[(Array[Int], Array[Double])])], + userIterable: lang.Iterable[(Int, Array[Int])], + collector: Collector[(Int, InBlockInformation)]): Unit = { + + val users = userIterable.iterator() + val partialInfos = partialInfosIterable.iterator() + + val userWrapper = users.next() + val id = userWrapper._1 + val userIDs = userWrapper._2 + val userIDToPos = userIDs.zipWithIndex.toMap + + var counter = 0 + + while (partialInfos.hasNext && counter < buffer.length) { + val partialInfo = partialInfos.next() + // entry contains the ratings and userIDs of a complete item block + val entry = partialInfo._3 + + // transform userIDs to positional indices + for (row <- 0 until entry.length; col <- 0 until entry(row)._1.length) { + entry(row)._1(col) = userIDToPos(entry(row)._1(col)) + } + + buffer(counter).ratings = entry + + counter += 1 + } + + while (partialInfos.hasNext) { + val partialInfo = partialInfos.next() + // entry contains the ratings and userIDs of a complete item block + val entry = partialInfo._3 + + // transform userIDs to positional indices + for (row <- 0 until entry.length; col <- 0 until entry(row)._1.length) { + entry(row)._1(col) = userIDToPos(entry(row)._1(col)) + } + + buffer += new BlockRating(entry) + + counter += 1 + } + + val array = new Array[BlockRating](counter) + + buffer.copyToArray(array) + + collector.collect((id, InBlockInformation(userIDs, array))) + } + } + }.withForwardedFieldsFirst("0").withForwardedFieldsSecond("0") + } + + /** Unblocks the blocked user and item matrix representation so that it is a DataSet of + * column vectors.
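+ * Each factor block is joined with its corresponding out block information; the stored + * element IDs are zipped with the block's factor vectors to emit one [[Factors]] element + * per ID.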
+ * + * @param users + * @param outInfo + * @param blockIDPartitioner + * @return + */ + def unblock(users: DataSet[(Int, Array[Array[Double]])], + outInfo: DataSet[(Int, OutBlockInformation)], + blockIDPartitioner: BlockIDPartitioner): DataSet[Factors] = { + users.join(outInfo).where(0).equalTo(0).withPartitioner(blockIDPartitioner).apply { + (left, right, col: Collector[Factors]) => { + val outInfo = right._2 + val factors = left._2 + + for(i <- 0 until outInfo.elementIDs.length){ + val id = outInfo.elementIDs(i) + val factorVector = factors(i) + col.collect(Factors(id, factorVector)) + } + } + } + } + + // ================================ Math helper functions ======================================== + + def outerProduct(vector: Array[Double], matrix: Array[Double], factors: Int): Unit = { + var row = 0 + var pos = 0 + while(row < factors){ + var col = 0 + while(col <= row){ + matrix(pos) = vector(row) * vector(col) + col += 1 + pos += 1 + } + + row += 1 + } + } + + def generateFullMatrix(triangularMatrix: Array[Double], fullMatrix: Array[Double], + factors: Int): Unit = { + var row = 0 + var pos = 0 + + while(row < factors){ + var col = 0 + while(col < row){ + fullMatrix(row*factors + col) = triangularMatrix(pos) + fullMatrix(col*factors + row) = triangularMatrix(pos) + + pos += 1 + col += 1 + } + + fullMatrix(row*factors + row) = triangularMatrix(pos) + + pos += 1 + row += 1 + } + } + + def generateRandomMatrix(users: DataSet[Int], factors: Int, seed: Long): DataSet[Factors] = { + users map { + id =>{ + val random = new Random(id ^ seed) + Factors(id, randomFactors(factors, random)) + } + } + } + + def randomFactors(factors: Int, random: Random): Array[Double] = { + Array.fill(factors)(random.nextDouble()) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala new file mode 100644 index 0000000..c3b3182 --- /dev/null +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.regression + +import org.apache.flink.api.scala.DataSet +import org.apache.flink.ml.math.{Breeze, Vector} +import org.apache.flink.ml.common._ + +import org.apache.flink.api.scala._ + +import org.apache.flink.ml.optimization.{LinearPrediction, SquaredLoss, GenericLossFunction, SimpleGradientDescent} +import org.apache.flink.ml.pipeline.{PredictOperation, FitOperation, Predictor} + + +/** Multiple linear regression using the ordinary least squares (OLS) estimator. + * + * The linear regression finds a solution to the problem + * + * `y = w_0 + w_1*x_1 + w_2*x_2 ... + w_n*x_n = w_0 + w^T*x` + * + * such that the sum of squared residuals is minimized + * + * `min_{w, w_0} \sum (y - w^T*x - w_0)^2` + * + * The minimization problem is solved by (stochastic) gradient descent. For each labeled vector + * `(x,y)`, the gradient is calculated. The average of all gradients, scaled by the current + * learning rate, is subtracted from the current weight vector `w`, which gives the new weight + * vector `w_new`. The learning rate is defined as `stepsize/math.sqrt(iteration)`. + * + * The optimization runs at most a maximum number of iterations or, if a convergence threshold has + * been set, until the convergence criterion has been met. As convergence criterion the relative + * change of the sum of squared residuals is used: + * + * `(S_{k-1} - S_k)/S_{k-1} < \rho` + * + * with `S_k` being the sum of squared residuals in iteration `k` and `\rho` being the convergence + * threshold. + * + * At the moment, the whole partition is used for SGD, making it effectively a batch gradient + * descent. Once a sampling operator has been introduced, the algorithm can be optimized. + * + * @example + * {{{ + * val mlr = MultipleLinearRegression() + * .setIterations(10) + * .setStepsize(0.5) + * .setConvergenceThreshold(0.001) + * + * val trainingDS: DataSet[LabeledVector] = ... + * val testingDS: DataSet[Vector] = ... + * + * mlr.fit(trainingDS) + * + * val predictions = mlr.predict(testingDS) + * }}} + * + * =Parameters= + * + * - [[org.apache.flink.ml.regression.MultipleLinearRegression.Iterations]]: + * Maximum number of iterations. + * + * - [[org.apache.flink.ml.regression.MultipleLinearRegression.Stepsize]]: + * Initial step size for the gradient descent method. + * This value controls how far the gradient descent method moves in the opposite direction of the + * gradient. Tuning this parameter may be crucial for the stability of the method and for + * obtaining good performance. + * + * - [[org.apache.flink.ml.regression.MultipleLinearRegression.ConvergenceThreshold]]: + * Threshold on the relative change of the sum of squared residuals used to detect convergence.
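+ * + * Schematically, one iteration of the optimization therefore updates the weight vector as + * + * `w_{k+1} = w_k - stepsize/sqrt(k) * 1/n * sum_(i) grad_{w_k}(x_i, y_i)` + * + * with `n` being the number of training points and `k` the iteration counter.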
+ * + */ +class MultipleLinearRegression extends Predictor[MultipleLinearRegression] { + import org.apache.flink.ml._ + import MultipleLinearRegression._ + + // Stores the weights of the linear model after the fitting phase + var weightsOption: Option[DataSet[WeightVector]] = None + + def setIterations(iterations: Int): MultipleLinearRegression = { + parameters.add(Iterations, iterations) + this + } + + def setStepsize(stepsize: Double): MultipleLinearRegression = { + parameters.add(Stepsize, stepsize) + this + } + + def setConvergenceThreshold(convergenceThreshold: Double): MultipleLinearRegression = { + parameters.add(ConvergenceThreshold, convergenceThreshold) + this + } + + def squaredResidualSum(input: DataSet[LabeledVector]): DataSet[Double] = { + weightsOption match { + case Some(weights) => { + input.mapWithBcVariable(weights){ + (dataPoint, weights) => lossFunction.loss(dataPoint, weights) + }.reduce { + _ + _ + } + } + + case None => { + throw new RuntimeException("The MultipleLinearRegression has not been fitted to the " + + "data. This is necessary to learn the weight vector of the linear function.") + } + } + + } +} + +object MultipleLinearRegression { + + val WEIGHTVECTOR_BROADCAST = "weights_broadcast" + + val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction) + + // ====================================== Parameters ============================================= + + case object Stepsize extends Parameter[Double] { + val defaultValue = Some(0.1) + } + + case object Iterations extends Parameter[Int] { + val defaultValue = Some(10) + } + + case object ConvergenceThreshold extends Parameter[Double] { + val defaultValue = None + } + + // ======================================== Factory methods ====================================== + + def apply(): MultipleLinearRegression = { + new MultipleLinearRegression() + } + + // ====================================== Operations ============================================= + + /** Trains the linear model to fit the training data. The resulting weight vector is stored in + * the [[MultipleLinearRegression]] instance. + * + */ + implicit val fitMLR = new FitOperation[MultipleLinearRegression, LabeledVector] { + override def fit( + instance: MultipleLinearRegression, + fitParameters: ParameterMap, + input: DataSet[LabeledVector]) + : Unit = { + val map = instance.parameters ++ fitParameters + + // retrieve parameters of the algorithm + val numberOfIterations = map(Iterations) + val stepsize = map(Stepsize) + val convergenceThreshold = map.get(ConvergenceThreshold) + + val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction) + + val optimizer = SimpleGradientDescent() + .setIterations(numberOfIterations) + .setStepsize(stepsize) + .setLossFunction(lossFunction) + + convergenceThreshold match { + case Some(threshold) => optimizer.setConvergenceThreshold(threshold) + case None => + } + + instance.weightsOption = Some(optimizer.optimize(input, None)) + } + } + + implicit def predictVectors[T <: Vector] = { + new PredictOperation[MultipleLinearRegression, WeightVector, T, Double]() { + override def getModel(self: MultipleLinearRegression, predictParameters: ParameterMap) + : DataSet[WeightVector] = { + self.weightsOption match { + case Some(weights) => weights + + + case None => { + throw new RuntimeException("The MultipleLinearRegression has not been fitted to the " + + "data. 
This is necessary to learn the weight vector of the linear function.") + } + } + } + override def predict(value: T, model: WeightVector): Double = { + import Breeze._ + val WeightVector(weights, weight0) = model + val dotProduct = value.asBreeze.dot(weights.asBreeze) + dotProduct + weight0 + } + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/test/resources/log4j-test.properties ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/test/resources/log4j-test.properties b/flink-libraries/flink-ml/src/test/resources/log4j-test.properties new file mode 100644 index 0000000..023b23a --- /dev/null +++ b/flink-libraries/flink-ml/src/test/resources/log4j-test.properties @@ -0,0 +1,38 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +log4j.rootLogger=OFF, console + +# ----------------------------------------------------------------------------- +# Console (use 'console') +# ----------------------------------------------------------------------------- +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.layout=org.apache.flink.util.MavenForkNumberPrefixLayout +log4j.appender.console.layout.ConversionPattern=%d{HH:mm:ss,SSS} %-5p %-60c %x - %m%n + +# ----------------------------------------------------------------------------- +# File (use 'file') +# ----------------------------------------------------------------------------- +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.file=${log.dir}/flinkML.log +log4j.appender.file.append=false +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{HH:mm:ss,SSS} %-5p %-60c %x - %m%n + +# suppress the irrelevant (wrong) warnings from the netty channel handler +log4j.logger.org.jboss.netty.channel.DefaultChannelPipeline=ERROR, console http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/test/resources/logback-test.xml ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/test/resources/logback-test.xml b/flink-libraries/flink-ml/src/test/resources/logback-test.xml new file mode 100644 index 0000000..1d64d46 --- /dev/null +++ b/flink-libraries/flink-ml/src/test/resources/logback-test.xml @@ -0,0 +1,42 @@ +<!-- + ~ Licensed to the Apache Software Foundation (ASF) under one + ~ or more contributor license agreements. See the NOTICE file + ~ distributed with this work for additional information + ~ regarding copyright ownership. 
The ASF licenses this file + ~ to you under the Apache License, Version 2.0 (the + ~ "License"); you may not use this file except in compliance + ~ with the License. You may obtain a copy of the License at + ~ + ~ http://www.apache.org/licenses/LICENSE-2.0 + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, + ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + ~ See the License for the specific language governing permissions and + ~ limitations under the License. + --> + +<configuration> + <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender"> + <encoder> + <pattern>%d{HH:mm:ss.SSS} [%thread] [%X{sourceThread} - %X{akkaSource}] %-5level %logger{60} - %msg%n</pattern> + </encoder> + </appender> + + <root level="WARN"> + <appender-ref ref="STDOUT"/> + </root> + + <!-- The following loggers are disabled during tests, because many tests deliberately + throw error to test failing scenarios. Logging those would overflow the log. --> + <!----> + <logger name="org.apache.flink.runtime.operators.DataSinkTask" level="OFF"/> + <logger name="org.apache.flink.runtime.operators.BatchTask" level="OFF"/> + <logger name="org.apache.flink.runtime.client.JobClient" level="OFF"/> + <logger name="org.apache.flink.runtime.taskmanager.Task" level="OFF"/> + <logger name="org.apache.flink.runtime.jobmanager.JobManager" level="OFF"/> + <logger name="org.apache.flink.runtime.testingUtils" level="OFF"/> + <logger name="org.apache.flink.runtime.executiongraph.ExecutionGraph" level="OFF"/> + <logger name="org.apache.flink.runtime.jobmanager.EventCollector" level="OFF"/> + <logger name="org.apache.flink.runtime.instance.InstanceManager" level="OFF"/> +</configuration> \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/MLUtilsSuite.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/MLUtilsSuite.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/MLUtilsSuite.scala new file mode 100644 index 0000000..d896937 --- /dev/null +++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/MLUtilsSuite.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml + +import java.io.File + +import scala.io.Source + +import org.scalatest.{FlatSpec, Matchers} + +import org.apache.flink.api.scala._ +import org.apache.flink.api.scala.ExecutionEnvironment +import org.apache.flink.ml.common.LabeledVector +import org.apache.flink.ml.math.SparseVector +import org.apache.flink.test.util.FlinkTestBase +import org.apache.flink.testutils.TestFileUtils + +class MLUtilsSuite extends FlatSpec with Matchers with FlinkTestBase { + + behavior of "The RichExecutionEnvironment" + + it should "read a libSVM/SVMLight input file" in { + val env = ExecutionEnvironment.getExecutionEnvironment + + val content = + """ + |1 2:10.0 4:4.5 8:4.2 # foo + |-1 1:9.0 4:-4.5 7:2.4 # bar + |0.4 3:1.0 8:-5.6 10:1.0 + |-42.1 2:2.0 4:-6.1 3:5.1 # svm + """.stripMargin + + val expectedLabeledVectors = Set( + LabeledVector(1, SparseVector.fromCOO(10, (1, 10), (3, 4.5), (7, 4.2))), + LabeledVector(-1, SparseVector.fromCOO(10, (0, 9), (3, -4.5), (6, 2.4))), + LabeledVector(0.4, SparseVector.fromCOO(10, (2, 1), (7, -5.6), (9, 1))), + LabeledVector(-42.1, SparseVector.fromCOO(10, (1, 2), (3, -6.1), (2, 5.1))) + ) + + val inputFilePath = TestFileUtils.createTempFile(content) + + val svmInput = env.readLibSVM(inputFilePath) + + val labeledVectors = svmInput.collect() + + labeledVectors.size should be(expectedLabeledVectors.size) + + for(lVector <- labeledVectors) { + expectedLabeledVectors.contains(lVector) should be(true) + } + + } + + it should "write a libSVM/SVMLight output file" in { + val env = ExecutionEnvironment.getExecutionEnvironment + + val labeledVectors = Seq( + LabeledVector(1.0, SparseVector.fromCOO(10, (1, 10), (3, 4.5), (7, 4.2))), + LabeledVector(-1.0, SparseVector.fromCOO(10, (0, 9), (3, -4.5), (6, 2.4))), + LabeledVector(0.4, SparseVector.fromCOO(10, (2, 1), (7, -5.6), (9, 1))), + LabeledVector(-42.1, SparseVector.fromCOO(10, (1, 2), (3, -6.1), (2, 5.1))) + ) + + val expectedLines = List( + "1.0 2:10.0 4:4.5 8:4.2", + "-1.0 1:9.0 4:-4.5 7:2.4", + "0.4 3:1.0 8:-5.6 10:1.0", + "-42.1 2:2.0 3:5.1 4:-6.1" + ) + + val labeledVectorsDS = env.fromCollection(labeledVectors) + + val tempDir = new File(System.getProperty("java.io.tmpdir")) + + val tempFile = new File(tempDir, TestFileUtils.randomFileName()) + + val outputFilePath = tempFile.getAbsolutePath + + labeledVectorsDS.writeAsLibSVM(outputFilePath) + + env.execute() + + val src = Source.fromFile(tempFile) + + var counter = 0 + + for(l <- src.getLines()) { + expectedLines.exists(_.equals(l)) should be(true) + counter += 1 + } + + counter should be(expectedLines.size) + + tempFile.delete() + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/classification/Classification.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/classification/Classification.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/classification/Classification.scala new file mode 100644 index 0000000..c9dd00f --- /dev/null +++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/classification/Classification.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification + +import org.apache.flink.ml.common.LabeledVector +import org.apache.flink.ml.math.DenseVector + +object Classification { + + /** Centered data of fisheriris data set + * + */ + val trainingData = Seq[LabeledVector]( + LabeledVector(1.0000, DenseVector(-0.2060, -0.2760)), + LabeledVector(1.0000, DenseVector(-0.4060, -0.1760)), + LabeledVector(1.0000, DenseVector(-0.0060, -0.1760)), + LabeledVector(1.0000, DenseVector(-0.9060, -0.3760)), + LabeledVector(1.0000, DenseVector(-0.3060, -0.1760)), + LabeledVector(1.0000, DenseVector(-0.4060, -0.3760)), + LabeledVector(1.0000, DenseVector(-0.2060, -0.0760)), + LabeledVector(1.0000, DenseVector(-1.6060, -0.6760)), + LabeledVector(1.0000, DenseVector(-0.3060, -0.3760)), + LabeledVector(1.0000, DenseVector(-1.0060, -0.2760)), + LabeledVector(1.0000, DenseVector(-1.4060, -0.6760)), + LabeledVector(1.0000, DenseVector(-0.7060, -0.1760)), + LabeledVector(1.0000, DenseVector(-0.9060, -0.6760)), + LabeledVector(1.0000, DenseVector(-0.2060, -0.2760)), + LabeledVector(1.0000, DenseVector(-1.3060, -0.3760)), + LabeledVector(1.0000, DenseVector(-0.5060, -0.2760)), + LabeledVector(1.0000, DenseVector(-0.4060, -0.1760)), + LabeledVector(1.0000, DenseVector(-0.8060, -0.6760)), + LabeledVector(1.0000, DenseVector(-0.4060, -0.1760)), + LabeledVector(1.0000, DenseVector(-1.0060, -0.5760)), + LabeledVector(1.0000, DenseVector(-0.1060, 0.1240)), + LabeledVector(1.0000, DenseVector(-0.9060, -0.3760)), + LabeledVector(1.0000, DenseVector(-0.0060, -0.1760)), + LabeledVector(1.0000, DenseVector(-0.2060, -0.4760)), + LabeledVector(1.0000, DenseVector(-0.6060, -0.3760)), + LabeledVector(1.0000, DenseVector(-0.5060, -0.2760)), + LabeledVector(1.0000, DenseVector(-0.1060, -0.2760)), + LabeledVector(1.0000, DenseVector(0.0940, 0.0240)), + LabeledVector(1.0000, DenseVector(-0.4060, -0.1760)), + LabeledVector(1.0000, DenseVector(-1.4060, -0.6760)), + LabeledVector(1.0000, DenseVector(-1.1060, -0.5760)), + LabeledVector(1.0000, DenseVector(-1.2060, -0.6760)), + LabeledVector(1.0000, DenseVector(-1.0060, -0.4760)), + LabeledVector(1.0000, DenseVector(0.1940, -0.0760)), + LabeledVector(1.0000, DenseVector(-0.4060, -0.1760)), + LabeledVector(1.0000, DenseVector(-0.4060, -0.0760)), + LabeledVector(1.0000, DenseVector(-0.2060, -0.1760)), + LabeledVector(1.0000, DenseVector(-0.5060, -0.3760)), + LabeledVector(1.0000, DenseVector(-0.8060, -0.3760)), + LabeledVector(1.0000, DenseVector(-0.9060, -0.3760)), + LabeledVector(1.0000, DenseVector(-0.5060, -0.4760)), + LabeledVector(1.0000, DenseVector(-0.3060, -0.2760)), + LabeledVector(1.0000, DenseVector(-0.9060, -0.4760)), + LabeledVector(1.0000, DenseVector(-1.6060, -0.6760)), + LabeledVector(1.0000, DenseVector(-0.7060, -0.3760)), + LabeledVector(1.0000, DenseVector(-0.7060, -0.4760)), + LabeledVector(1.0000, DenseVector(-0.7060, -0.3760)), + LabeledVector(1.0000, DenseVector(-0.6060, -0.3760)), + 
LabeledVector(1.0000, DenseVector(-1.9060, -0.5760)), + LabeledVector(1.0000, DenseVector(-0.8060, -0.3760)), + LabeledVector(-1.0000, DenseVector(1.0940, 0.8240)), + LabeledVector(-1.0000, DenseVector(0.1940, 0.2240)), + LabeledVector(-1.0000, DenseVector(0.9940, 0.4240)), + LabeledVector(-1.0000, DenseVector(0.6940, 0.1240)), + LabeledVector(-1.0000, DenseVector(0.8940, 0.5240)), + LabeledVector(-1.0000, DenseVector(1.6940, 0.4240)), + LabeledVector(-1.0000, DenseVector(-0.4060, 0.0240)), + LabeledVector(-1.0000, DenseVector(1.3940, 0.1240)), + LabeledVector(-1.0000, DenseVector(0.8940, 0.1240)), + LabeledVector(-1.0000, DenseVector(1.1940, 0.8240)), + LabeledVector(-1.0000, DenseVector(0.1940, 0.3240)), + LabeledVector(-1.0000, DenseVector(0.3940, 0.2240)), + LabeledVector(-1.0000, DenseVector(0.5940, 0.4240)), + LabeledVector(-1.0000, DenseVector(0.0940, 0.3240)), + LabeledVector(-1.0000, DenseVector(0.1940, 0.7240)), + LabeledVector(-1.0000, DenseVector(0.3940, 0.6240)), + LabeledVector(-1.0000, DenseVector(0.5940, 0.1240)), + LabeledVector(-1.0000, DenseVector(1.7940, 0.5240)), + LabeledVector(-1.0000, DenseVector(1.9940, 0.6240)), + LabeledVector(-1.0000, DenseVector(0.0940, -0.1760)), + LabeledVector(-1.0000, DenseVector(0.7940, 0.6240)), + LabeledVector(-1.0000, DenseVector(-0.0060, 0.3240)), + LabeledVector(-1.0000, DenseVector(1.7940, 0.3240)), + LabeledVector(-1.0000, DenseVector(-0.0060, 0.1240)), + LabeledVector(-1.0000, DenseVector(0.7940, 0.4240)), + LabeledVector(-1.0000, DenseVector(1.0940, 0.1240)), + LabeledVector(-1.0000, DenseVector(-0.1060, 0.1240)), + LabeledVector(-1.0000, DenseVector(-0.0060, 0.1240)), + LabeledVector(-1.0000, DenseVector(0.6940, 0.4240)), + LabeledVector(-1.0000, DenseVector(0.8940, -0.0760)), + LabeledVector(-1.0000, DenseVector(1.1940, 0.2240)), + LabeledVector(-1.0000, DenseVector(1.4940, 0.3240)), + LabeledVector(-1.0000, DenseVector(0.6940, 0.5240)), + LabeledVector(-1.0000, DenseVector(0.1940, -0.1760)), + LabeledVector(-1.0000, DenseVector(0.6940, -0.2760)), + LabeledVector(-1.0000, DenseVector(1.1940, 0.6240)), + LabeledVector(-1.0000, DenseVector(0.6940, 0.7240)), + LabeledVector(-1.0000, DenseVector(0.5940, 0.1240)), + LabeledVector(-1.0000, DenseVector(-0.1060, 0.1240)), + LabeledVector(-1.0000, DenseVector(0.4940, 0.4240)), + LabeledVector(-1.0000, DenseVector(0.6940, 0.7240)), + LabeledVector(-1.0000, DenseVector(0.1940, 0.6240)), + LabeledVector(-1.0000, DenseVector(0.1940, 0.2240)), + LabeledVector(-1.0000, DenseVector(0.9940, 0.6240)), + LabeledVector(-1.0000, DenseVector(0.7940, 0.8240)), + LabeledVector(-1.0000, DenseVector(0.2940, 0.6240)), + LabeledVector(-1.0000, DenseVector(0.0940, 0.2240)), + LabeledVector(-1.0000, DenseVector(0.2940, 0.3240)), + LabeledVector(-1.0000, DenseVector(0.4940, 0.6240)), + LabeledVector(-1.0000, DenseVector(0.1940, 0.1240)) + ) + + val expectedWeightVector = DenseVector(-1.95, -3.45) +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala new file mode 100644 index 0000000..e6eb873 --- /dev/null +++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification + +import org.scalatest.{FlatSpec, Matchers} +import org.apache.flink.ml.math.DenseVector + +import org.apache.flink.api.scala._ +import org.apache.flink.test.util.FlinkTestBase + +class SVMITSuite extends FlatSpec with Matchers with FlinkTestBase { + + behavior of "The SVM implementation using CoCoA" + + it should "train an SVM" in { + val env = ExecutionEnvironment.getExecutionEnvironment + + val svm = SVM(). + setBlocks(env.getParallelism). + setIterations(100). + setLocalIterations(100). + setRegularization(0.002). + setStepsize(0.1). + setSeed(0) + + val trainingDS = env.fromCollection(Classification.trainingData) + + svm.fit(trainingDS) + + val weightVector = svm.weightsOption.get.collect().head + + weightVector.valueIterator.zip(Classification.expectedWeightVector.valueIterator).foreach { + case (weight, expectedWeight) => + weight should be(expectedWeight +- 0.1) + } + } + + it should "make (mostly) correct predictions" in { + val env = ExecutionEnvironment.getExecutionEnvironment + + val svm = SVM(). + setBlocks(env.getParallelism). + setIterations(100). + setLocalIterations(100). + setRegularization(0.002). + setStepsize(0.1). + setSeed(0) + + val trainingDS = env.fromCollection(Classification.trainingData) + + val test = trainingDS.map(x => (x.vector, x.label)) + + svm.fit(trainingDS) + + val predictionPairs = svm.evaluate(test) + + val absoluteErrorSum = predictionPairs.collect().map{ + case (truth, prediction) => Math.abs(truth - prediction)}.sum + + absoluteErrorSum should be < 15.0 + } + + it should "be possible to get the raw decision function values" in { + val env = ExecutionEnvironment.getExecutionEnvironment + + val svm = SVM().
+ setBlocks(env.getParallelism). + setOutputDecisionFunction(false) + + val customWeights = env.fromElements(DenseVector(1.0, 1.0, 1.0)) + + svm.weightsOption = Option(customWeights) + + val test = env.fromElements(DenseVector(5.0, 5.0, 5.0)) + + val thresholdedPrediction = svm.predict(test).map(vectorLabel => vectorLabel._2).collect().head + + thresholdedPrediction should be (1.0 +- 1e-9) + + svm.setOutputDecisionFunction(true) + + val rawPrediction = svm.predict(test).map(vectorLabel => vectorLabel._2).collect().head + + rawPrediction should be (15.0 +- 1e-9) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/common/FlinkMLToolsSuite.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/common/FlinkMLToolsSuite.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/common/FlinkMLToolsSuite.scala new file mode 100644 index 0000000..525ba4d --- /dev/null +++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/common/FlinkMLToolsSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.ml.common + +import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer +import org.apache.flink.api.scala.ExecutionEnvironment +import org.apache.flink.test.util.FlinkTestBase +import org.scalatest.{FlatSpec, Matchers} + +class FlinkMLToolsSuite extends FlatSpec with Matchers with FlinkTestBase { + behavior of "FlinkMLTools" + + it should "register the required types" in { + val env = ExecutionEnvironment.getExecutionEnvironment + + FlinkMLTools.registerFlinkMLTypes(env) + + val executionConfig = env.getConfig + + val serializer = new KryoSerializer[Nothing](classOf[Nothing], executionConfig) + + val kryo = serializer.getKryo() + + kryo.getRegistration(classOf[org.apache.flink.ml.math.DenseVector]).getId > 0 should be(true) + kryo.getRegistration(classOf[org.apache.flink.ml.math.SparseVector]).getId > 0 should be(true) + kryo.getRegistration(classOf[org.apache.flink.ml.math.DenseMatrix]).getId > 0 should be(true) + kryo.getRegistration(classOf[org.apache.flink.ml.math.SparseMatrix]).getId > 0 should be(true) + + kryo.getRegistration(classOf[breeze.linalg.DenseMatrix[_]]).getId > 0 should be(true) + kryo.getRegistration(classOf[breeze.linalg.CSCMatrix[_]]).getId > 0 should be(true) + kryo.getRegistration(classOf[breeze.linalg.DenseVector[_]]).getId > 0 should be(true) + kryo.getRegistration(classOf[breeze.linalg.SparseVector[_]]).getId > 0 should be(true) + + kryo.getRegistration(breeze.linalg.DenseVector.zeros[Double](0).getClass).getId > 0 should + be(true) + kryo.getRegistration(breeze.linalg.SparseVector.zeros[Double](0).getClass).getId > 0 should + be(true) + kryo.getRegistration(breeze.linalg.DenseMatrix.zeros[Double](0, 0).getClass).getId > 0 should + be(true) + kryo.getRegistration(breeze.linalg.CSCMatrix.zeros[Double](0, 0).getClass).getId > 0 should + be(true) + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathSuite.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathSuite.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathSuite.scala new file mode 100644 index 0000000..0d230c5 --- /dev/null +++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.math + +import Breeze._ +import breeze.linalg + +import org.scalatest.{Matchers, FlatSpec} + +class BreezeMathSuite extends FlatSpec with Matchers { + + behavior of "Breeze vector conversion" + + it should "convert a DenseMatrix into breeze.linalg.DenseMatrix and vice versa" in { + val numRows = 5 + val numCols = 4 + + val data = Array.range(0, numRows * numCols) + val expectedData = Array.range(0, numRows * numCols).map(_ * 2) + + val denseMatrix = DenseMatrix(numRows, numCols, data) + val expectedMatrix = DenseMatrix(numRows, numCols, expectedData) + + val m = denseMatrix.asBreeze + + val result = (m * 2.0).fromBreeze + + result should equal(expectedMatrix) + } + + it should "convert a SparseMatrix into breeze.linalg.CSCMatrix" in { + val numRows = 5 + val numCols = 4 + + val sparseMatrix = SparseMatrix.fromCOO(numRows, numCols, + (0, 1, 1), + (4, 3, 13), + (3, 2, 45), + (4, 0, 12)) + + val expectedMatrix = SparseMatrix.fromCOO(numRows, numCols, + (0, 1, 2), + (4, 3, 26), + (3, 2, 90), + (4, 0, 24)) + + val sm = sparseMatrix.asBreeze + + val result = (sm * 2.0).fromBreeze + + result should equal(expectedMatrix) + } + + it should "convert a dense Flink vector into a dense Breeze vector and vice versa" in { + val vector = DenseVector(1, 2, 3) + + val breezeVector = vector.asBreeze + + val flinkVector = breezeVector.fromBreeze + + breezeVector.getClass should be(new linalg.DenseVector[Double](0).getClass()) + flinkVector.getClass should be (classOf[DenseVector]) + + flinkVector should equal(vector) + } + + it should "convert a sparse Flink vector into a sparse Breeze vector and given the right " + + "converter back into a dense Flink vector" in { + implicit val converter = implicitly[BreezeVectorConverter[DenseVector]] + + val vector = SparseVector.fromCOO(3, (1, 1.0), (2, 2.0)) + + val breezeVector = vector.asBreeze + + val flinkVector = breezeVector.fromBreeze + + breezeVector.getClass should be(new linalg.SparseVector[Double](null).getClass()) + flinkVector.getClass should be (classOf[DenseVector]) + + flinkVector.equalsVector(vector) should be(true) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixSuite.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixSuite.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixSuite.scala new file mode 100644 index 0000000..88bde3b --- /dev/null +++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixSuite.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.math + +import org.scalatest.{Matchers, FlatSpec} + +class DenseMatrixSuite extends FlatSpec with Matchers { + + behavior of "Flink's DenseMatrix" + + it should "contain the initialization data" in { + val numRows = 10 + val numCols = 13 + + val data = Array.range(0, numRows * numCols) + + val matrix = DenseMatrix(numRows, numCols, data) + + assertResult(numRows)(matrix.numRows) + assertResult(numCols)(matrix.numCols) + + for(row <- 0 until numRows; col <- 0 until numCols) { + assertResult(data(col * numRows + row))(matrix(row, col)) + } + } + + it should "fail in case of invalid element access" in { + val numRows = 10 + val numCols = 13 + + val matrix = DenseMatrix.zeros(numRows, numCols) + + intercept[IllegalArgumentException] { + matrix(-1, 2) + } + + intercept[IllegalArgumentException] { + matrix(0, -1) + } + + intercept[IllegalArgumentException] { + matrix(numRows, 0) + } + + intercept[IllegalArgumentException] { + matrix(0, numCols) + } + + intercept[IllegalArgumentException] { + matrix(numRows, numCols) + } + } + + it should "be copyable" in { + val numRows = 4 + val numCols = 5 + + val data = Array.range(0, numRows * numCols) + + val denseMatrix = DenseMatrix(numRows, numCols, data) + + val copy = denseMatrix.copy + + denseMatrix should equal(copy) + + copy(0, 0) = 1 + + denseMatrix should not equal copy + } +}
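
The MLUtilsSuite diff above exercises the libSVM/SVMLight reader and writer this commit wires onto the execution environment. A minimal round-trip sketch of that API, outside the test harness; the file paths are placeholders, and it assumes the implicit enrichments live in the org.apache.flink.ml package object so a wildcard import brings them into scope:

import org.apache.flink.api.scala._
import org.apache.flink.ml._ // assumed home of readLibSVM / writeAsLibSVM enrichments
import org.apache.flink.ml.common.LabeledVector

object LibSVMRoundTrip {
  def main(args: Array[String]): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment

    // libSVM lines carry the label first, then 1-based index:value pairs;
    // '#' starts a comment. readLibSVM shifts indices to 0-based SparseVectors.
    val data: DataSet[LabeledVector] = env.readLibSVM("/tmp/input.libsvm")

    // writeAsLibSVM shifts the indices back to 1-based form, matching the
    // expectedLines checked in the suite above.
    data.writeAsLibSVM("/tmp/output.libsvm")

    env.execute()
  }
}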
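SVMITSuite pins the CoCoA hyper-parameters so the learned weights stay comparable to expectedWeightVector. The same knobs in a plain usage sketch; the two-point training set is an illustrative stand-in for real data, and setBlocks(1) is chosen only because the toy set is tiny:

import org.apache.flink.api.scala._
import org.apache.flink.ml.classification.SVM
import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.ml.math.DenseVector

object SVMUsageSketch {
  def main(args: Array[String]): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment

    // Toy +1/-1 training set; real input would come from e.g. readLibSVM.
    val training = env.fromCollection(Seq(
      LabeledVector(1.0, DenseVector(-0.5, -0.3)),
      LabeledVector(-1.0, DenseVector(0.8, 0.4))))

    val svm = SVM().
      setBlocks(1).            // data partitions handed to CoCoA
      setIterations(100).      // outer communication rounds
      setLocalIterations(100). // local SDCA steps per block and round
      setRegularization(0.002).
      setStepsize(0.1).
      setSeed(0)

    svm.fit(training)

    // predict emits (vector, label) pairs; the label is the thresholded sign
    // unless setOutputDecisionFunction(true) requests the raw margin instead.
    val test = env.fromElements(DenseVector(-0.5, -0.3), DenseVector(0.8, 0.4))
    svm.predict(test).print()
  }
}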
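BreezeMathSuite documents the asBreeze/fromBreeze bridge from org.apache.flink.ml.math.Breeze. A short sketch of the pattern; the doubling is arbitrary, and the result type of fromBreeze is steered by whichever implicit BreezeVectorConverter the ascription selects:

import org.apache.flink.ml.math.Breeze._
import org.apache.flink.ml.math.DenseVector

object BreezeInteropSketch {
  def main(args: Array[String]): Unit = {
    val flinkVector = DenseVector(1.0, 2.0, 3.0)

    // asBreeze exposes the Flink vector as a breeze.linalg.Vector, making
    // Breeze's arithmetic available without reimplementing it.
    val doubled = flinkVector.asBreeze * 2.0

    // fromBreeze converts back; ascribing DenseVector picks the matching
    // implicit BreezeVectorConverter[DenseVector] from the companion object.
    val roundTripped: DenseVector = doubled.fromBreeze

    println(roundTripped) // expected: DenseVector(2.0, 4.0, 6.0)
  }
}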
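DenseMatrixSuite's indexing assertion, data(col * numRows + row), pins down that DenseMatrix stores its entries column-major. A small sketch making that layout explicit; the concrete numbers are arbitrary:

import org.apache.flink.ml.math.DenseMatrix

object ColumnMajorSketch {
  def main(args: Array[String]): Unit = {
    // A 2x3 matrix. Element (row, col) lives at data(col * numRows + row),
    // so the backing array lists the matrix column by column:
    //   | 1.0  3.0  5.0 |
    //   | 2.0  4.0  6.0 |
    val m = DenseMatrix(2, 3, Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))

    require(m(1, 0) == 2.0) // second entry of the first column
    require(m(0, 1) == 3.0) // first entry of the second column

    // copy is deep, as the suite asserts: mutating the copy leaves the
    // original untouched.
    val c = m.copy
    c(0, 0) = 42.0
    require(m(0, 0) == 1.0)
  }
}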