Repository: flink Updated Branches: refs/heads/master 7571959a1 -> d163a817f
[FLINK-2102] [ml] Add predict function for labeled data for SVM and MLR. These functions return for each example in the input DataSet[LabeledVector] a pair (truth, prediction). Added documentation for new predict functions. This closes #744. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/d163a817 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/d163a817 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/d163a817 Branch: refs/heads/master Commit: d163a817fa2e330e86384d0bbcd104f051a6fb48 Parents: 7571959 Author: Theodore Vasiloudis <t...@sics.se> Authored: Thu May 28 18:51:17 2015 +0200 Committer: Till Rohrmann <trohrm...@apache.org> Committed: Tue Jun 2 13:24:05 2015 +0200 ---------------------------------------------------------------------- docs/libs/ml/multiple_linear_regression.md | 8 +++ docs/libs/ml/svm.md | 8 +++ .../apache/flink/ml/classification/SVM.scala | 53 +++++++++++++++++- .../regression/MultipleLinearRegression.scala | 58 +++++++++++++++++++- .../flink/ml/classification/SVMITSuite.scala | 31 +++++++++++ .../MultipleLinearRegressionITSuite.scala | 24 ++++++++ 6 files changed, 178 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/d163a817/docs/libs/ml/multiple_linear_regression.md ---------------------------------------------------------------------- diff --git a/docs/libs/ml/multiple_linear_regression.md b/docs/libs/ml/multiple_linear_regression.md index d9bc951..aaf1fbf 100644 --- a/docs/libs/ml/multiple_linear_regression.md +++ b/docs/libs/ml/multiple_linear_regression.md @@ -77,6 +77,14 @@ MultipleLinearRegression predicts for all subtypes of `Vector` the corresponding * `predict[T <: Vector]: DataSet[T] => DataSet[LabeledVector]` +If we call predict with a `DataSet[LabeledVector]`, we make a prediction on the regression value +for each example, and return a 
`DataSet[(Double, Double)]`. In each tuple the first element +is the true value, as was provided from the input `DataSet[LabeledVector]` and the second element +is the predicted value. You can then use these `(truth, prediction)` tuples to evaluate +the algorithm's performance. + +* `predict: DataSet[LabeledVector] => DataSet[(Double, Double)]` + ## Parameters The multiple linear regression implementation can be controlled by the following parameters: http://git-wip-us.apache.org/repos/asf/flink/blob/d163a817/docs/libs/ml/svm.md ---------------------------------------------------------------------- diff --git a/docs/libs/ml/svm.md b/docs/libs/ml/svm.md index a9c94ec..e649949 100644 --- a/docs/libs/ml/svm.md +++ b/docs/libs/ml/svm.md @@ -74,6 +74,14 @@ SVM predicts for all subtypes of `Vector` the corresponding class label: * `predict[T <: Vector]: DataSet[T] => DataSet[LabeledVector]` +If we call predict with a `DataSet[LabeledVector]`, we make a prediction on the class label +for each example, and return a `DataSet[(Double, Double)]`. In each tuple the first element +is the true value, as was provided from the input `DataSet[LabeledVector]` and the second element +is the predicted value. You can then use these `(truth, prediction)` tuples to evaluate +the algorithm's performance. 
+ +* `predict: DataSet[LabeledVector] => DataSet[(Double, Double)]` + ## Parameters The SVM implementation can be controlled by the following parameters: http://git-wip-us.apache.org/repos/asf/flink/blob/d163a817/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala index a186c5d..95f2b23 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala @@ -33,7 +33,7 @@ import org.apache.flink.ml.math.Breeze._ import breeze.linalg.{Vector => BreezeVector, DenseVector => BreezeDenseVector} -/** Implements a soft-maring SVM using the communication-efficient distributed dual coordinate +/** Implements a soft-margin SVM using the communication-efficient distributed dual coordinate * ascent algorithm (CoCoA) with hinge-loss function. * * The algorithm solves the following minimization problem: @@ -276,6 +276,57 @@ object SVM{ } } + /** [[org.apache.flink.ml.pipeline.PredictOperation]] for [[LabeledVector ]]types. The result type + * is a [[(Double, Double)]] tuple, corresponding to (truth, prediction) + * + * @return A DataSet[(Double, Double)] where each tuple is a (truth, prediction) pair. + */ + implicit def predictLabeledValues = { + new PredictOperation[SVM, LabeledVector, (Double, Double)]{ + override def predict( + instance: SVM, + predictParameters: ParameterMap, + input: DataSet[LabeledVector]) + : DataSet[(Double, Double)] = { + + instance.weightsOption match { + case Some(weights) => { + input.map(new LabeledPredictionMapper).withBroadcastSet(weights, WEIGHT_VECTOR) + } + + case None => { + throw new RuntimeException("The SVM model has not been trained. 
Call first fit" + + "before calling the predict operation.") + } + } + } + } + } + + /** Mapper to calculate the value of the prediction function. This is a RichMapFunction, because + * we broadcast the weight vector to all mappers. + */ + class LabeledPredictionMapper extends RichMapFunction[LabeledVector, (Double, Double)] { + + var weights: BreezeDenseVector[Double] = _ + + @throws(classOf[Exception]) + override def open(configuration: Configuration): Unit = { + // get current weights + weights = getRuntimeContext. + getBroadcastVariable[BreezeDenseVector[Double]](WEIGHT_VECTOR).get(0) + } + + override def map(labeledVector: LabeledVector): (Double, Double) = { + // calculate the prediction value (scaled distance from the separating hyperplane) + val prediction = weights dot labeledVector.vector.asBreeze + val truth = labeledVector.label + + (truth, prediction) + } + } + + /** [[FitOperation]] which trains a SVM with soft-margin based on the given training data set. * */ http://git-wip-us.apache.org/repos/asf/flink/blob/d163a817/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala index 64b24dc..32746a1 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala @@ -21,11 +21,9 @@ package org.apache.flink.ml.regression import org.apache.flink.api.common.functions.RichMapFunction import org.apache.flink.api.scala.DataSet import org.apache.flink.configuration.Configuration -import org.apache.flink.ml.math.Vector +import org.apache.flink.ml.math.{DenseVector, BLAS, Vector, 
vector2Array} import org.apache.flink.ml.common._ -import org.apache.flink.ml.math.vector2Array - import org.apache.flink.api.scala._ import com.github.fommil.netlib.BLAS.{ getInstance => blas } @@ -348,6 +346,60 @@ object MultipleLinearRegression { LabeledVector(prediction, value) } } + + /** Calculates the predictions for labeled data with respect to the learned linear model. + * + * @return A DataSet[(Double, Double)] where each tuple is a (truth, prediction) pair. + */ + implicit def predictLabeledVectors = { + new PredictOperation[MultipleLinearRegression, LabeledVector, (Double, Double)] { + override def predict( + instance: MultipleLinearRegression, + predictParameters: ParameterMap, + input: DataSet[LabeledVector]) + : DataSet[(Double, Double)] = { + instance.weightsOption match { + case Some(weights) => { + input.map(new LinearRegressionLabeledPrediction) + .withBroadcastSet(weights, WEIGHTVECTOR_BROADCAST) + } + + case None => { + throw new RuntimeException("The MultipleLinearRegression has not been fitted to the " + + "data. 
This is necessary to learn the weight vector of the linear function.") + } + } + } + } + } + + private class LinearRegressionLabeledPrediction + extends RichMapFunction[LabeledVector, (Double, Double)] { + private var weights: Array[Double] = null + private var weight0: Double = 0 + + + @throws(classOf[Exception]) + override def open(configuration: Configuration): Unit = { + val t = getRuntimeContext + .getBroadcastVariable[(Array[Double], Double)](WEIGHTVECTOR_BROADCAST) + + val weightsPair = t.get(0) + + weights = weightsPair._1 + weight0 = weightsPair._2 + } + + override def map(labeledVector: LabeledVector ): (Double, Double) = { + + val truth = labeledVector.label + val dotProduct = BLAS.dot(DenseVector(weights), labeledVector.vector) + + val prediction = dotProduct + weight0 + + (truth, prediction) + } + } } //-------------------------------------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/d163a817/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala index 55ef056..25c2afb 100644 --- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala +++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala @@ -49,4 +49,35 @@ class SVMITSuite extends FlatSpec with Matchers with FlinkTestBase { weight should be(expectedWeight +- 0.1) } } + + it should "make (mostly) correct predictions" in { + val env = ExecutionEnvironment.getExecutionEnvironment + + val svm = SVM(). + setBlocks(env.getParallelism). + setIterations(100). + setLocalIterations(100). + setRegularization(0.002). + setStepsize(0.1). 
+ setSeed(0) + + val trainingDS = env.fromCollection(Classification.trainingData) + + svm.fit(trainingDS) + + val threshold = 0.0 + + val predictionPairs = svm.predict(trainingDS).map { + truthPrediction => + val truth = truthPrediction._1 + val prediction = truthPrediction._2 + val thresholdedPrediction = if (prediction > threshold) 1.0 else -1.0 + (truth, thresholdedPrediction) + } + + val absoluteErrorSum = predictionPairs.collect().map{ + case (truth, prediction) => Math.abs(truth - prediction)}.sum + + absoluteErrorSum should be < 15.0 + } } http://git-wip-us.apache.org/repos/asf/flink/blob/d163a817/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala index 8be239a..30338e5 100644 --- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala +++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala @@ -106,4 +106,28 @@ class MultipleLinearRegressionITSuite srs should be(RegressionData.expectedPolynomialSquaredResidualSum +- 5) } + + it should "make (mostly) correct predictions" in { + val env = ExecutionEnvironment.getExecutionEnvironment + + val mlr = MultipleLinearRegression() + + import RegressionData._ + + val parameters = ParameterMap() + + parameters.add(MultipleLinearRegression.Stepsize, 1.0) + parameters.add(MultipleLinearRegression.Iterations, 10) + parameters.add(MultipleLinearRegression.ConvergenceThreshold, 0.001) + + val inputDS = env.fromCollection(data) + mlr.fit(inputDS, parameters) + + val predictionPairs = mlr.predict(inputDS) + + val absoluteErrorSum = predictionPairs.collect().map{ + case 
(truth, prediction) => Math.abs(truth - prediction)}.sum + + absoluteErrorSum should be < 50.0 + } }