[FLINK-2297] [ml] Adds threshold setting for SVM binary predictions.

Added a Threshold option for SVM to determine which predictions are
classified as positive and which as negative.

Added a parameter, OutputDecisionFunction, that determines the output of the
prediction function: whether we output raw distances to the decision boundary
or binary class labels.
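
A minimal usage sketch of the two new settings (setter names as added by this
commit; the surrounding environment and DataSets are placeholders):

    val svm = SVM()
      .setBlocks(env.getParallelism)
      .setThreshold(0.0)                 // decision boundary between -1.0 and +1.0
      .setOutputDecisionFunction(false)  // false: binary labels, true: raw distances

    svm.fit(trainingDS)

    // each prediction is a (features, label) tuple
    val predictions: DataSet[(Vector, Double)] = svm.predict(testingDS)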

Documentation fixes

This closes #874.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/eb23f807
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/eb23f807
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/eb23f807

Branch: refs/heads/master
Commit: eb23f80742e743a75888f5d4fbed879383b6741e
Parents: 4cc7cf3
Author: Theodore Vasiloudis <t...@sics.se>
Authored: Tue Jun 30 14:49:03 2015 +0200
Committer: Till Rohrmann <trohrm...@apache.org>
Committed: Thu Jul 2 14:35:27 2015 +0200

----------------------------------------------------------------------
 docs/libs/ml/svm.md                             | 112 ++++++++++++-------
 .../apache/flink/ml/classification/SVM.scala    |  83 ++++++++++++--
 .../apache/flink/ml/common/ParameterMap.scala   |   2 +-
 .../apache/flink/ml/pipeline/Predictor.scala    |  48 ++++----
 .../flink/ml/classification/SVMITSuite.scala    |  41 +++++--
 5 files changed, 203 insertions(+), 83 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/eb23f807/docs/libs/ml/svm.md
----------------------------------------------------------------------
diff --git a/docs/libs/ml/svm.md b/docs/libs/ml/svm.md
index e649949..44e4d90 100644
--- a/docs/libs/ml/svm.md
+++ b/docs/libs/ml/svm.md
@@ -27,34 +27,34 @@ under the License.
 
 ## Description
 
-Implements an SVM with soft-margin using the communication-efficient distributed dual coordinate 
-ascent algorithm with hinge-loss function. 
+Implements an SVM with soft-margin using the communication-efficient distributed dual coordinate
+ascent algorithm with hinge-loss function.
 The algorithm solves the following minimization problem:
-  
+
 $$\min_{\mathbf{w} \in \mathbb{R}^d} \frac{\lambda}{2} \left\lVert \mathbf{w} \right\rVert^2 + \frac{1}{n} \sum_{i=1}^n l_{i}\left(\mathbf{w}^T\mathbf{x}_i\right)$$
- 
-with $\mathbf{w}$ being the weight vector, $\lambda$ being the regularization constant, 
-$$\mathbf{x}_i \in \mathbb{R}^d$$ being the data points and $$l_{i}$$ being the convex loss 
+
+with $\mathbf{w}$ being the weight vector, $\lambda$ being the regularization constant,
+$$\mathbf{x}_i \in \mathbb{R}^d$$ being the data points and $$l_{i}$$ being the convex loss
 functions, which can also depend on the labels $$y_{i} \in \mathbb{R}$$.
 In the current implementation the regularizer is the $\ell_2$-norm and the loss functions are the hinge-loss functions:
-  
+
   $$l_{i} = \max\left(0, 1 - y_{i} \mathbf{w}^T\mathbf{x}_i \right)$$
 
 With these choices, the problem definition is equivalent to a SVM with soft-margin.
 Thus, the algorithm allows us to train a SVM with soft-margin.
 
 The minimization problem is solved by applying stochastic dual coordinate ascent (SDCA).
-In order to make the algorithm efficient in a distributed setting, the CoCoA algorithm calculates 
+In order to make the algorithm efficient in a distributed setting, the CoCoA algorithm calculates
 several iterations of SDCA locally on a data block before merging the local updates into a
 valid global state.
-This state is redistributed to the different data partitions where the next round of local SDCA 
+This state is redistributed to the different data partitions where the next round of local SDCA
 iterations is then executed.
-The number of outer iterations and local SDCA iterations control the overall network costs, because 
+The number of outer iterations and local SDCA iterations control the overall network costs, because
 there is only network communication required for each outer iteration.
-The local SDCA iterations are embarrassingly parallel once the individual data partitions have been 
+The local SDCA iterations are embarrassingly parallel once the individual data partitions have been
 distributed across the cluster.
 
-The implementation of this algorithm is based on the work of 
+The implementation of this algorithm is based on the work of
 [Jaggi et al.](http://arxiv.org/abs/1409.1458)
 
 ## Operations
@@ -64,23 +64,24 @@ As such, it supports the `fit` and `predict` operation.
 
 ### Fit
 
-SVM is trained given a set of `LabeledVector`: 
+SVM is trained given a set of `LabeledVector`:
 
 * `fit: DataSet[LabeledVector] => Unit`
 
 ### Predict
 
-SVM predicts for all subtypes of `Vector` the corresponding class label: 
+SVM predicts for all subtypes of FlinkML's `Vector` the corresponding class label:
 
-* `predict[T <: Vector]: DataSet[T] => DataSet[LabeledVector]`
+* `predict[T <: Vector]: DataSet[T] => DataSet[(T, Double)]`, where the `(T, Double)` tuple
+  corresponds to (original_features, label)
 
-If we call predict with a `DataSet[LabeledVector]`, we make a prediction on the class label
+If we call evaluate with a `DataSet[(Vector, Double)]`, we make a prediction on the class label
 for each example, and return a `DataSet[(Double, Double)]`. In each tuple the first element
-is the true value, as was provided from the input `DataSet[LabeledVector]` and the second element
+is the true value, as was provided from the input `DataSet[(Vector, Double)]` and the second element
 is the predicted value. You can then use these `(truth, prediction)` tuples to evaluate
 the algorithm's performance.
 
-* `predict: DataSet[LabeledVector] => DataSet[(Double, Double)]`
+* `predict: DataSet[(Vector, Double)] => DataSet[(Double, Double)]`
 
 ## Parameters
 
@@ -99,10 +100,10 @@ The SVM implementation can be controlled by the following parameters:
         <td><strong>Blocks</strong></td>
         <td>
           <p>
-            Sets the number of blocks into which the input data will be split. 
-            On each block the local stochastic dual coordinate ascent method is executed. 
-            This number should be set at least to the degree of parallelism. 
-            If no value is specified, then the parallelism of the input DataSet is used as the number of blocks. 
+            Sets the number of blocks into which the input data will be split.
+            On each block the local stochastic dual coordinate ascent method is executed.
+            This number should be set at least to the degree of parallelism.
+            If no value is specified, then the parallelism of the input DataSet is used as the number of blocks.
             (Default value: <strong>None</strong>)
           </p>
         </td>
@@ -111,8 +112,8 @@ The SVM implementation can be controlled by the following parameters:
         <td><strong>Iterations</strong></td>
         <td>
           <p>
-            Defines the maximum number of iterations of the outer loop method. 
-            In other words, it defines how often the SDCA method is applied to the blocked data. 
+            Defines the maximum number of iterations of the outer loop method.
+            In other words, it defines how often the SDCA method is applied to the blocked data.
             After each iteration, the locally computed weight vector updates have to be reduced to update the global weight vector value.
             The new weight vector is broadcast to all SDCA tasks at the beginning of each iteration.
             (Default value: <strong>10</strong>)
@@ -123,7 +124,7 @@ The SVM implementation can be controlled by the following parameters:
         <td><strong>LocalIterations</strong></td>
         <td>
           <p>
-            Defines the maximum number of SDCA iterations. 
+            Defines the maximum number of SDCA iterations.
             In other words, it defines how many data points are drawn from each local data block to calculate the stochastic dual coordinate ascent.
             (Default value: <strong>10</strong>)
           </p>
@@ -133,8 +134,8 @@ The SVM implementation can be controlled by the following parameters:
         <td><strong>Regularization</strong></td>
         <td>
           <p>
-            Defines the regularization constant of the SVM algorithm. 
-            The higher the value, the smaller will the 2-norm of the weight vector be. 
+            Defines the regularization constant of the SVM algorithm.
+            The higher the value, the smaller will the 2-norm of the weight vector be.
             In case of a SVM with hinge loss this means that the SVM margin will be wider even though it might contain some false classifications.
             (Default value: <strong>1.0</strong>)
           </p>
@@ -144,47 +145,76 @@ The SVM implementation can be controlled by the following parameters:
         <td><strong>Stepsize</strong></td>
         <td>
           <p>
-            Defines the initial step size for the updates of the weight vector. 
-            The larger the step size is, the larger will be the contribution of the weight vector updates to the next weight vector value. 
+            Defines the initial step size for the updates of the weight vector.
+            The larger the step size is, the larger will be the contribution of the weight vector updates to the next weight vector value.
             The effective scaling of the updates is $\frac{stepsize}{blocks}$.
-            This value has to be tuned in case that the algorithm becomes unstable. 
+            This value has to be tuned in case that the algorithm becomes unstable.
             (Default value: <strong>1.0</strong>)
           </p>
         </td>
       </tr>
       <tr>
-        <td><strong>Seed</strong></td>
+        <td><strong>ThresholdValue</strong></td>
         <td>
           <p>
-            Defines the seed to initialize the random number generator. 
-            The seed directly controls which data points are chosen for the SDCA method. 
-            (Default value: <strong>0</strong>)
+            Defines the limiting value for the decision function above which examples are labeled as
+            positive (+1.0). Examples with a decision function value below this value are classified
+            as negative (-1.0). To obtain the raw decision function values instead, set the
+            OutputDecisionFunction parameter to true. (Default value: <strong>0.0</strong>)
           </p>
         </td>
       </tr>
+      <tr>
+        <td><strong>OutputDecisionFunction</strong></td>
+        <td>
+          <p>
+            Determines whether the predict and evaluate functions of the SVM should return the
+            distance to the separating hyperplane or binary class labels. Setting this to true
+            will return the raw distance to the hyperplane for each example. Setting it to false
+            will return the binary class label (+1.0, -1.0). (Default value: <strong>false</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+      <td><strong>Seed</strong></td>
+      <td>
+        <p>
+          Defines the seed to initialize the random number generator.
+          The seed directly controls which data points are chosen for the SDCA method.
+          (Default value: <strong>0</strong>)
+        </p>
+      </td>
+    </tr>
     </tbody>
   </table>
 
 ## Examples
 
 {% highlight scala %}
+import org.apache.flink.api.scala._
+import org.apache.flink.ml.math.Vector
+import org.apache.flink.ml.common.LabeledVector
+import org.apache.flink.ml.classification.SVM
+import org.apache.flink.ml.RichExecutionEnvironment
+
+val pathToTrainingFile: String = ???
+val pathToTestingFile: String = ???
+val env = ExecutionEnvironment.getExecutionEnvironment
+
 // Read the training data set, from a LibSVM formatted file
 val trainingDS: DataSet[LabeledVector] = env.readLibSVM(pathToTrainingFile)
 
 // Create the SVM learner
 val svm = SVM()
-.setBlocks(10)
-.setIterations(10)
-.setLocalIterations(10)
-.setRegularization(0.5)
-.setStepsize(0.5)
+  .setBlocks(10)
 
 // Learn the SVM model
 svm.fit(trainingDS)
 
 // Read the testing data set
-val testingDS: DataSet[Vector] = env.readVectorFile(pathToTestingFile)
+val testingDS: DataSet[Vector] = env.readLibSVM(pathToTestingFile).map(lv => lv.vector)
 
 // Calculate the predictions for the testing data set
-val predictionDS: DataSet[LabeledVector] = svm.predict(testingDS)
+val predictionDS: DataSet[(Vector, Double)] = svm.predict(testingDS)
+
 {% endhighlight %}

http://git-wip-us.apache.org/repos/asf/flink/blob/eb23f807/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
index bd46204..4c539d9 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
@@ -39,6 +39,9 @@ import breeze.linalg.{Vector => BreezeVector, DenseVector => BreezeDenseVector}
 /** Implements a soft-margin SVM using the communication-efficient distributed dual coordinate
   * ascent algorithm (CoCoA) with hinge-loss function.
   *
+  * It can be used for binary classification problems, with the labels set as +1.0 to indicate a
+  * positive example and -1.0 to indicate a negative example.
+  *
   * The algorithm solves the following minimization problem:
   *
   * `min_{w in bbb"R"^d} lambda/2 ||w||^2 + 1/n sum_(i=1)^n l_{i}(w^Tx_i)`
@@ -69,20 +72,17 @@ import breeze.linalg.{Vector => BreezeVector, DenseVector => BreezeDenseVector}
   *
   * @example
   *          {{{
-  *             val trainingDS: DataSet[LabeledVector] = env.readSVMFile(pathToTrainingFile)
+  *             val trainingDS: DataSet[LabeledVector] = env.readLibSVM(pathToTrainingFile)
   *
   *             val svm = SVM()
   *               .setBlocks(10)
-  *               .setIterations(10)
-  *               .setLocalIterations(10)
-  *               .setRegularization(0.5)
-  *               .setStepsize(0.5)
   *
   *             svm.fit(trainingDS)
   *
-  *             val testingDS: DataSet[Vector] = env.readVectorFile(pathToTestingFile)
+  *             val testingDS: DataSet[Vector] = env.readLibSVM(pathToTestingFile)
+  *               .map(lv => lv.vector)
   *
-  *             val predictionDS: DataSet[LabeledVector] = svm.predict(testingDS)
+  *             val predictionDS: DataSet[(Vector, Double)] = svm.predict(testingDS)
   *          }}}
   *
   * =Parameters=
@@ -120,6 +120,19 @@ import breeze.linalg.{Vector => BreezeVector, DenseVector => BreezeDenseVector}
   *  - [[org.apache.flink.ml.classification.SVM.Seed]]:
   *  Defines the seed to initialize the random number generator. The seed directly controls which
   *  data points are chosen for the SDCA method. (Default value: '''0''')
+  *
+  *  - [[org.apache.flink.ml.classification.SVM.ThresholdValue]]:
+  *  Defines the limiting value for the decision function above which examples are labeled as
+  *  positive (+1.0). Examples with a decision function value below this value are classified as
+  *  negative (-1.0). To obtain the raw decision function values instead, set
+  *  [[org.apache.flink.ml.classification.SVM.OutputDecisionFunction]] to true.
+  *  (Default value: '''0.0''')
+  *
+  *  - [[org.apache.flink.ml.classification.SVM.OutputDecisionFunction]]:
+  *  Determines whether the predict and evaluate functions of the SVM should return the distance
+  *  to the separating hyperplane or binary class labels. Setting this to true will return the raw
+  *  distance to the hyperplane for each example. Setting it to false will return the binary
+  *  class label (+1.0, -1.0). (Default value: '''false''')
   */
 class SVM extends Predictor[SVM] {
 
@@ -187,6 +200,34 @@ class SVM extends Predictor[SVM] {
     parameters.add(Seed, seed)
     this
   }
+
+  /** Sets the threshold above which elements are classified as positive.
+    *
+    * The [[predict]] and [[evaluate]] functions will return +1.0 for items with a decision
+    * function value above this threshold, and -1.0 for items below it.
+    *
+    * @param threshold The decision function value above which an example is labeled positive
+    * @return this SVM instance, to allow call chaining
+    */
+  def setThreshold(threshold: Double): SVM = {
+    parameters.add(ThresholdValue, threshold)
+    this
+  }
+
+  /** Sets whether the predictions should return the raw decision function value or the
+    * thresholded binary value.
+    *
+    * When setting this to true, predict and evaluate return the raw decision value, which is
+    * the distance from the separating hyperplane.
+    * When setting this to false, they return thresholded (+1.0, -1.0) values.
+    *
+    * @param outputDecisionFunction When set to true, [[predict]] and [[evaluate]] return the raw
+    *                               decision function values. When set to false, they return the
+    *                               thresholded binary values (+1.0, -1.0).
+    */
+  def setOutputDecisionFunction(outputDecisionFunction: Boolean): SVM = {
+    parameters.add(OutputDecisionFunction, outputDecisionFunction)
+    this
+  }
 }
 
 /** Companion object of SVM. Contains convenience functions and the parameter type definitions
@@ -222,6 +263,14 @@ object SVM{
     val defaultValue = Some(0L)
   }
 
+  case object ThresholdValue extends Parameter[Double] {
+    val defaultValue = Some(0.0)
+  }
+
+  case object OutputDecisionFunction extends Parameter[Boolean] {
+    val defaultValue = Some(false)
+  }
+
   // ========================================== Factory methods ====================================
 
   def apply(): SVM = {
@@ -230,9 +279,21 @@ object SVM{
 
   // ========================================== Operations =========================================
 
+  /** Provides the operation that makes the predictions for individual examples.
+    *
+    * @tparam T The subtype of [[Vector]] used as input to the prediction
+    * @return A PredictOperation, through which it is possible to predict a value, given a
+    *         feature vector
+    */
   implicit def predictVectors[T <: Vector] = {
     new PredictOperation[SVM, DenseVector, T, Double](){
+
+      var thresholdValue: Double = _
+      var outputDecisionFunction: Boolean = _
+
       override def getModel(self: SVM, predictParameters: ParameterMap): DataSet[DenseVector] = {
+        thresholdValue = predictParameters(ThresholdValue)
+        outputDecisionFunction = predictParameters(OutputDecisionFunction)
         self.weightsOption match {
           case Some(model) => model
           case None => {
@@ -243,7 +304,13 @@ object SVM{
       }
 
       override def predict(value: T, model: DenseVector): Double = {
-        value.asBreeze dot model.asBreeze
+        val rawValue = value.asBreeze dot model.asBreeze
+
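+        // Either return the raw margin, or threshold it into a binary class label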
+        if (outputDecisionFunction) {
+          rawValue
+        } else {
+          if (rawValue > thresholdValue) 1.0 else -1.0
+        }
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/eb23f807/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala
index a5efe8a..77d2d46 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala
@@ -21,7 +21,7 @@ package org.apache.flink.ml.common
 import scala.collection.mutable
 
 /**
- * Map used to store configuration parameters for [[Learner]] and [[Transformer]]. The parameter
+ * Map used to store configuration parameters for algorithms. The parameter
  * values are stored in a [[Map]] being identified by a [[Parameter]] object. ParameterMaps can
  * be fused. This operation is left associative, meaning that latter ParameterMaps can override
  * parameter values defined in a preceding ParameterMap.

http://git-wip-us.apache.org/repos/asf/flink/blob/eb23f807/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala
index cd9cc51..9d11cff 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala
@@ -191,10 +191,10 @@ object Predictor {
 trait PredictDataSetOperation[Self, Testing, Prediction] extends Serializable{
 
   /** Calculates the predictions for all elements in the [[DataSet]] input
-    * 
-    * @param instance
-    * @param predictParameters
-    * @param input
+    *
+    * @param instance The Predictor instance that we will use to make the predictions
+    * @param predictParameters The parameters for the prediction
+    * @param input The DataSet containing the unlabeled examples
     * @return
     */
   def predictDataSet(
@@ -210,43 +210,49 @@ trait PredictDataSetOperation[Self, Testing, Prediction] extends Serializable{
  * It is sufficient for a [[Predictor]] to only implement this trait to support the evaluate and
   * predict method.
   *
-  * @tparam Instance
-  * @tparam Model
-  * @tparam Testing
-  * @tparam Prediction
+  * @tparam Instance The concrete type of the [[Predictor]] that we will use for predictions
+  * @tparam Model The representation of the predictive model for the algorithm, for example a
+  *               Vector of weights
+  * @tparam Testing The type of the example that we will use to make the predictions (input)
+  * @tparam Prediction The type of the label that the prediction operation will produce (output)
+  *
   */
 trait PredictOperation[Instance, Model, Testing, Prediction] extends Serializable{
 
   /** Defines how to retrieve the model of the type for which this operation was defined
-    * 
-    * @param instance
-    * @return
+    *
+    * @param instance The Predictor instance that we will use to make the predictions
+    * @param predictParameters The parameters for the prediction
+    * @return A DataSet with the model representation as its only element
     */
   def getModel(instance: Instance, predictParameters: ParameterMap): DataSet[Model]
 
   /** Calculates the prediction for a single element given the model of the [[Predictor]].
     *
-    * @param value
-    * @param model
-    * @return
+    * @param value The unlabeled example on which we make the prediction
+    * @param model The model representation of the prediction algorithm
+    * @return A label for the provided example of type [[Prediction]]
     */
-  def predict(value: Testing, model: Model): Prediction
+  def predict(value: Testing, model: Model):
+    Prediction
 }
 
-/** Type class for the evalute operation of [[Predictor]]. This evaluate operation works on
+/** Type class for the evaluate operation of [[Predictor]]. This evaluate operation works on
   * DataSets.
   *
  * It takes a [[DataSet]] of some type. For each element of this [[DataSet]] the evaluate method
  * computes the prediction value and returns a tuple of true label value and prediction value.
   *
-  * @tparam Instance
-  * @tparam Testing
-  * @tparam PredictionValue
+  * @tparam Instance The concrete type of the Predictor instance that we will use to make the
+  *                  predictions
+  * @tparam Testing The type of the example that we will use to make the predictions (input)
+  * @tparam Prediction The type of the label that the prediction operation will produce (output)
+  *
   */
-trait EvaluateDataSetOperation[Instance, Testing, PredictionValue] extends Serializable{
+trait EvaluateDataSetOperation[Instance, Testing, Prediction] extends Serializable{
   def evaluateDataSet(
       instance: Instance,
       evaluateParameters: ParameterMap,
       testing: DataSet[Testing])
-    : DataSet[(PredictionValue, PredictionValue)]
+    : DataSet[(Prediction, Prediction)]
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/eb23f807/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala
index b1a91a2..57a7783 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala
@@ -19,6 +19,7 @@
 package org.apache.flink.ml.classification
 
 import org.scalatest.{FlatSpec, Matchers}
+import org.apache.flink.ml.math.{Vector => FlinkVector, DenseVector}
 
 import org.apache.flink.api.scala._
 import org.apache.flink.test.util.FlinkTestBase
@@ -40,11 +41,9 @@ class SVMITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
     val trainingDS = env.fromCollection(Classification.trainingData)
 
-    val testingDS = trainingDS.map(_.vector)
-
     svm.fit(trainingDS)
 
-    val weightVector = svm.weightsOption.get.collect().apply(0)
+    val weightVector = svm.weightsOption.get.collect().head
 
     weightVector.valueIterator.zip(Classification.expectedWeightVector.valueIterator).foreach {
       case (weight, expectedWeight) =>
@@ -69,19 +68,37 @@ class SVMITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
     svm.fit(trainingDS)
 
-    val threshold = 0.0
-
-    val predictionPairs = svm.evaluate(test).map {
-      truthPrediction =>
-        val truth = truthPrediction._1
-        val prediction = truthPrediction._2
-        val thresholdedPrediction = if (prediction > threshold) 1.0 else -1.0
-        (truth, thresholdedPrediction)
-    }
+    val predictionPairs = svm.evaluate(test)
 
     val absoluteErrorSum = predictionPairs.collect().map{
       case (truth, prediction) => Math.abs(truth - prediction)}.sum
 
     absoluteErrorSum should be < 15.0
   }
+
+  it should "be possible to get the raw decision function values" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    val svm = SVM()
+      .setBlocks(env.getParallelism)
+      .setOutputDecisionFunction(false)
+
+    val customWeights = env.fromElements(DenseVector(1.0, 1.0, 1.0))
+
+    svm.weightsOption = Option(customWeights)
+
+    val test = env.fromElements(DenseVector(5.0, 5.0, 5.0))
+
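+    // With weights (1.0, 1.0, 1.0) and features (5.0, 5.0, 5.0) the decision value is
+    // 1*5 + 1*5 + 1*5 = 15.0, above the default threshold of 0.0, so the label is +1.0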
+    val thresholdedPrediction = svm.predict(test).map(vectorLabel => vectorLabel._2).collect().head
+
+    thresholdedPrediction should be (1.0 +- 1e-9)
+
+    svm.setOutputDecisionFunction(true)
+
+    val rawPrediction = svm.predict(test).map(vectorLabel => vectorLabel._2).collect().head
+
+    rawPrediction should be (15.0 +- 1e-9)
+  }
 }
