Repository: ignite
Updated Branches:
  refs/heads/master cd0ead329 -> f566bedbc


IGNITE-9022: [ML] Implement class labels mapping for
SVM binary classifier

this closes #4749


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/f566bedb
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/f566bedb
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/f566bedb

Branch: refs/heads/master
Commit: f566bedbcb400e5102af48470d1aa4395c97a13f
Parents: cd0ead3
Author: zaleslaw <[email protected]>
Authored: Tue Sep 18 15:14:03 2018 +0300
Committer: Yury Babak <[email protected]>
Committed: Tue Sep 18 15:14:03 2018 +0300

----------------------------------------------------------------------
 .../binary/SVMBinaryClassificationExample.java  | 105 +++++++++----------
 .../svm/SVMLinearBinaryClassificationModel.java |   8 +-
 .../SVMLinearBinaryClassificationTrainer.java   |  25 +++--
 ...VMLinearMultiClassClassificationTrainer.java |   2 +-
 .../ignite/ml/svm/SVMBinaryTrainerTest.java     |  12 +--
 .../org/apache/ignite/ml/svm/SVMModelTest.java  |   8 +-
 .../ignite/ml/svm/SVMMultiClassTrainerTest.java |  10 +-
 7 files changed, 88 insertions(+), 82 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/f566bedb/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java
index bd88c20..f71db2d 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java
@@ -91,8 +91,8 @@ public class SVMBinaryClassificationExample {
                         if(groundTruth != prediction)
                             amountOfErrors++;
 
-                        int idx1 = (int)prediction == -1.0 ? 0 : 1;
-                        int idx2 = (int)groundTruth == -1.0 ? 0 : 1;
+                        int idx1 = prediction == 0.0 ? 0 : 1;
+                        int idx2 = groundTruth == 0.0 ? 0 : 1;
 
                         confusionMtx[idx1][idx2]++;
 
@@ -117,56 +117,56 @@ public class SVMBinaryClassificationExample {
 
     /** The 1st and 2nd classes from the Iris dataset. */
     private static final double[][] data = {
-        {-1, 5.1, 3.5, 1.4, 0.2},
-        {-1, 4.9, 3, 1.4, 0.2},
-        {-1, 4.7, 3.2, 1.3, 0.2},
-        {-1, 4.6, 3.1, 1.5, 0.2},
-        {-1, 5, 3.6, 1.4, 0.2},
-        {-1, 5.4, 3.9, 1.7, 0.4},
-        {-1, 4.6, 3.4, 1.4, 0.3},
-        {-1, 5, 3.4, 1.5, 0.2},
-        {-1, 4.4, 2.9, 1.4, 0.2},
-        {-1, 4.9, 3.1, 1.5, 0.1},
-        {-1, 5.4, 3.7, 1.5, 0.2},
-        {-1, 4.8, 3.4, 1.6, 0.2},
-        {-1, 4.8, 3, 1.4, 0.1},
-        {-1, 4.3, 3, 1.1, 0.1},
-        {-1, 5.8, 4, 1.2, 0.2},
-        {-1, 5.7, 4.4, 1.5, 0.4},
-        {-1, 5.4, 3.9, 1.3, 0.4},
-        {-1, 5.1, 3.5, 1.4, 0.3},
-        {-1, 5.7, 3.8, 1.7, 0.3},
-        {-1, 5.1, 3.8, 1.5, 0.3},
-        {-1, 5.4, 3.4, 1.7, 0.2},
-        {-1, 5.1, 3.7, 1.5, 0.4},
-        {-1, 4.6, 3.6, 1, 0.2},
-        {-1, 5.1, 3.3, 1.7, 0.5},
-        {-1, 4.8, 3.4, 1.9, 0.2},
-        {-1, 5, 3, 1.6, 0.2},
-        {-1, 5, 3.4, 1.6, 0.4},
-        {-1, 5.2, 3.5, 1.5, 0.2},
-        {-1, 5.2, 3.4, 1.4, 0.2},
-        {-1, 4.7, 3.2, 1.6, 0.2},
-        {-1, 4.8, 3.1, 1.6, 0.2},
-        {-1, 5.4, 3.4, 1.5, 0.4},
-        {-1, 5.2, 4.1, 1.5, 0.1},
-        {-1, 5.5, 4.2, 1.4, 0.2},
-        {-1, 4.9, 3.1, 1.5, 0.1},
-        {-1, 5, 3.2, 1.2, 0.2},
-        {-1, 5.5, 3.5, 1.3, 0.2},
-        {-1, 4.9, 3.1, 1.5, 0.1},
-        {-1, 4.4, 3, 1.3, 0.2},
-        {-1, 5.1, 3.4, 1.5, 0.2},
-        {-1, 5, 3.5, 1.3, 0.3},
-        {-1, 4.5, 2.3, 1.3, 0.3},
-        {-1, 4.4, 3.2, 1.3, 0.2},
-        {-1, 5, 3.5, 1.6, 0.6},
-        {-1, 5.1, 3.8, 1.9, 0.4},
-        {-1, 4.8, 3, 1.4, 0.3},
-        {-1, 5.1, 3.8, 1.6, 0.2},
-        {-1, 4.6, 3.2, 1.4, 0.2},
-        {-1, 5.3, 3.7, 1.5, 0.2},
-        {-1, 5, 3.3, 1.4, 0.2},
+        {0, 5.1, 3.5, 1.4, 0.2},
+        {0, 4.9, 3, 1.4, 0.2},
+        {0, 4.7, 3.2, 1.3, 0.2},
+        {0, 4.6, 3.1, 1.5, 0.2},
+        {0, 5, 3.6, 1.4, 0.2},
+        {0, 5.4, 3.9, 1.7, 0.4},
+        {0, 4.6, 3.4, 1.4, 0.3},
+        {0, 5, 3.4, 1.5, 0.2},
+        {0, 4.4, 2.9, 1.4, 0.2},
+        {0, 4.9, 3.1, 1.5, 0.1},
+        {0, 5.4, 3.7, 1.5, 0.2},
+        {0, 4.8, 3.4, 1.6, 0.2},
+        {0, 4.8, 3, 1.4, 0.1},
+        {0, 4.3, 3, 1.1, 0.1},
+        {0, 5.8, 4, 1.2, 0.2},
+        {0, 5.7, 4.4, 1.5, 0.4},
+        {0, 5.4, 3.9, 1.3, 0.4},
+        {0, 5.1, 3.5, 1.4, 0.3},
+        {0, 5.7, 3.8, 1.7, 0.3},
+        {0, 5.1, 3.8, 1.5, 0.3},
+        {0, 5.4, 3.4, 1.7, 0.2},
+        {0, 5.1, 3.7, 1.5, 0.4},
+        {0, 4.6, 3.6, 1, 0.2},
+        {0, 5.1, 3.3, 1.7, 0.5},
+        {0, 4.8, 3.4, 1.9, 0.2},
+        {0, 5, 3, 1.6, 0.2},
+        {0, 5, 3.4, 1.6, 0.4},
+        {0, 5.2, 3.5, 1.5, 0.2},
+        {0, 5.2, 3.4, 1.4, 0.2},
+        {0, 4.7, 3.2, 1.6, 0.2},
+        {0, 4.8, 3.1, 1.6, 0.2},
+        {0, 5.4, 3.4, 1.5, 0.4},
+        {0, 5.2, 4.1, 1.5, 0.1},
+        {0, 5.5, 4.2, 1.4, 0.2},
+        {0, 4.9, 3.1, 1.5, 0.1},
+        {0, 5, 3.2, 1.2, 0.2},
+        {0, 5.5, 3.5, 1.3, 0.2},
+        {0, 4.9, 3.1, 1.5, 0.1},
+        {0, 4.4, 3, 1.3, 0.2},
+        {0, 5.1, 3.4, 1.5, 0.2},
+        {0, 5, 3.5, 1.3, 0.3},
+        {0, 4.5, 2.3, 1.3, 0.3},
+        {0, 4.4, 3.2, 1.3, 0.2},
+        {0, 5, 3.5, 1.6, 0.6},
+        {0, 5.1, 3.8, 1.9, 0.4},
+        {0, 4.8, 3, 1.4, 0.3},
+        {0, 5.1, 3.8, 1.6, 0.2},
+        {0, 4.6, 3.2, 1.4, 0.2},
+        {0, 5.3, 3.7, 1.5, 0.2},
+        {0, 5, 3.3, 1.4, 0.2},
         {1, 7, 3.2, 4.7, 1.4},
         {1, 6.4, 3.2, 4.5, 1.5},
         {1, 6.9, 3.1, 4.9, 1.5},
@@ -218,5 +218,4 @@ public class SVMBinaryClassificationExample {
         {1, 5.1, 2.5, 3, 1.1},
         {1, 5.7, 2.8, 4.1, 1.3},
     };
-
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f566bedb/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java
index 4771e4a..f5d2b28 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java
@@ -31,11 +31,11 @@ public class SVMLinearBinaryClassificationModel implements Model<Vector, Double>
     /** */
     private static final long serialVersionUID = -996984622291440226L;
 
-    /** Output label format. -1 and +1 for false value and raw distances from 
the separating hyperplane otherwise. */
+    /** Output label format. '0' and '1' for false value and raw distances 
from the separating hyperplane otherwise. */
     private boolean isKeepingRawLabels = false;
 
-    /** Threshold to assign +1 label to the observation if raw value more than 
this threshold. */
-    private double threshold = 0.0;
+    /** Threshold to assign '1' label to the observation if raw value more 
than this threshold. */
+    private double threshold = 0.5;
 
     /** Multiplier of the objects's vector required to make prediction. */
     private Vector weights;
@@ -99,7 +99,7 @@ public class SVMLinearBinaryClassificationModel implements Model<Vector, Double>
         if (isKeepingRawLabels)
             return res;
         else
-            return res - threshold > 0 ? 1.0 : -1.0;
+            return res - threshold > 0 ? 1.0 : 0;
     }
 
     /**

http://git-wip-us.apache.org/repos/asf/ignite/blob/f566bedb/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
index 8fb98d2..2c621c8 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
@@ -35,8 +35,8 @@ import org.jetbrains.annotations.NotNull;
 
 /**
  * Base class for a soft-margin SVM linear classification trainer based on the 
communication-efficient distributed dual
- * coordinate ascent algorithm (CoCoA) with hinge-loss function. <p> This 
trainer takes input as Labeled Dataset with -1
- * and +1 labels for two classes and makes binary classification. </p> The 
paper about this algorithm could be found
+ * coordinate ascent algorithm (CoCoA) with hinge-loss function. <p> This 
trainer takes input as Labeled Dataset with 0
+ * and 1 labels for two classes and makes binary classification. </p> The 
paper about this algorithm could be found
  * here https://arxiv.org/abs/1409.1458.
  */
 public class SVMLinearBinaryClassificationTrainer extends 
SingleLabelDatasetTrainer<SVMLinearBinaryClassificationModel> {
@@ -73,9 +73,17 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
 
         assert datasetBuilder != null;
 
+        IgniteBiFunction<K, V, Double> patchedLbExtractor = (k, v) ->  {
+            final Double lb = lbExtractor.apply(k, v);
+            if (lb == 0.0)
+                return -1.0;
+            else
+                return lb;
+        };
+
         PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, 
LabeledVector>> partDataBuilder = new 
LabeledDatasetPartitionDataBuilderOnHeap<>(
             featureExtractor,
-            lbExtractor
+            patchedLbExtractor
         );
 
         Vector weights;
@@ -95,9 +103,8 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
 
                 final int weightVectorSizeWithIntercept = cols + 1;
                 weights = 
initializeWeightsWithZeros(weightVectorSizeWithIntercept);
-            } else {
+            } else
                 weights = getStateVector(mdl);
-            }
 
             for (int i = 0; i < this.getAmountOfIterations(); i++) {
                 Vector deltaWeights = calculateUpdates(weights, dataset);
@@ -126,13 +133,13 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
         Vector weights = mdl.weights();
 
         int stateVectorSize = weights.size() + 1;
-        Vector result = weights.isDense() ?
+        Vector res = weights.isDense() ?
             new DenseVector(stateVectorSize) :
             new SparseVector(stateVectorSize, 
StorageConstants.RANDOM_ACCESS_MODE);
 
-        result.set(0, intercept);
-        weights.nonZeroes().forEach(ith -> result.set(ith.index(), ith.get()));
-        return result;
+        res.set(0, intercept);
+        weights.nonZeroes().forEach(ith -> res.set(ith.index(), ith.get()));
+        return res;
     }
 
     /** */

http://git-wip-us.apache.org/repos/asf/ignite/blob/f566bedb/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
index aeee178..7cbb1dc 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
@@ -93,7 +93,7 @@ public class SVMLinearMultiClassClassificationTrainer
                 if (lb.equals(clsLb))
                     return 1.0;
                 else
-                    return -1.0;
+                    return 0.0;
             };
 
             SVMLinearBinaryClassificationModel model;

http://git-wip-us.apache.org/repos/asf/ignite/blob/f566bedb/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
index 263bb6d..d6f77c0 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
@@ -50,7 +50,7 @@ public class SVMBinaryTrainerTest extends TrainerTest {
             (k, v) -> v[0]
         );
 
-        TestUtils.assertEquals(-1, mdl.apply(VectorUtils.of(100, 10)), 
PRECISION);
+        TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), 
PRECISION);
         TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), 
PRECISION);
     }
 
@@ -66,7 +66,7 @@ public class SVMBinaryTrainerTest extends TrainerTest {
             .withAmountOfIterations(1000)
             .withSeed(1234L);
 
-        SVMLinearBinaryClassificationModel originalModel = trainer.fit(
+        SVMLinearBinaryClassificationModel originalMdl = trainer.fit(
             cacheMock,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
@@ -74,7 +74,7 @@ public class SVMBinaryTrainerTest extends TrainerTest {
         );
 
         SVMLinearBinaryClassificationModel updatedOnSameDS = trainer.update(
-            originalModel,
+            originalMdl,
             cacheMock,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
@@ -82,7 +82,7 @@ public class SVMBinaryTrainerTest extends TrainerTest {
         );
 
         SVMLinearBinaryClassificationModel updatedOnEmptyDS = trainer.update(
-            originalModel,
+            originalMdl,
             new HashMap<Integer, double[]>(),
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
@@ -90,7 +90,7 @@ public class SVMBinaryTrainerTest extends TrainerTest {
         );
 
         Vector v = VectorUtils.of(100, 10);
-        TestUtils.assertEquals(originalModel.apply(v), 
updatedOnSameDS.apply(v), PRECISION);
-        TestUtils.assertEquals(originalModel.apply(v), 
updatedOnEmptyDS.apply(v), PRECISION);
+        TestUtils.assertEquals(originalMdl.apply(v), updatedOnSameDS.apply(v), 
PRECISION);
+        TestUtils.assertEquals(originalMdl.apply(v), 
updatedOnEmptyDS.apply(v), PRECISION);
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f566bedb/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java
index 9a222c3..9c452f9 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java
@@ -93,13 +93,13 @@ public class SVMModelTest {
         TestUtils.assertEquals(1.0, mdl.apply(observation), PRECISION);
 
         observation = new DenseVector(new double[]{-1.0, -1.0});
-        TestUtils.assertEquals(-1.0, mdl.apply(observation), PRECISION);
+        TestUtils.assertEquals(0.0, mdl.apply(observation), PRECISION);
 
         observation = new DenseVector(new double[]{-2.0, 1.0});
-        TestUtils.assertEquals(-1.0, mdl.apply(observation), PRECISION);
+        TestUtils.assertEquals(0.0, mdl.apply(observation), PRECISION);
 
         observation = new DenseVector(new double[]{-1.0, -2.0});
-        TestUtils.assertEquals(-1.0, mdl.apply(observation), PRECISION);
+        TestUtils.assertEquals(0.0, mdl.apply(observation), PRECISION);
 
         final SVMLinearBinaryClassificationModel mdlWithNewData = 
mdl.withIntercept(-2.0).withWeights(new DenseVector(new double[] {-2.0, -2.0}));
         System.out.println("The SVM model is " + mdlWithNewData);
@@ -116,7 +116,7 @@ public class SVMModelTest {
         SVMLinearBinaryClassificationModel mdl = new 
SVMLinearBinaryClassificationModel(weights, 1.0).withThreshold(5);
 
         Vector observation = new DenseVector(new double[]{1.0, 1.0});
-        TestUtils.assertEquals(-1.0, mdl.apply(observation), PRECISION);
+        TestUtils.assertEquals(0.0, mdl.apply(observation), PRECISION);
 
         observation = new DenseVector(new double[]{3.0, 4.0});
         TestUtils.assertEquals(1.0, mdl.apply(observation), PRECISION);

http://git-wip-us.apache.org/repos/asf/ignite/blob/f566bedb/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
index e0c62af..7c4809f 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
@@ -70,7 +70,7 @@ public class SVMMultiClassTrainerTest extends TrainerTest {
             .withAmountOfIterations(100)
             .withSeed(1234L);
 
-        SVMLinearMultiClassClassificationModel originalModel = trainer.fit(
+        SVMLinearMultiClassClassificationModel originalMdl = trainer.fit(
             cacheMock,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
@@ -78,7 +78,7 @@ public class SVMMultiClassTrainerTest extends TrainerTest {
         );
 
         SVMLinearMultiClassClassificationModel updatedOnSameDS = 
trainer.update(
-            originalModel,
+            originalMdl,
             cacheMock,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
@@ -86,7 +86,7 @@ public class SVMMultiClassTrainerTest extends TrainerTest {
         );
 
         SVMLinearMultiClassClassificationModel updatedOnEmptyDS = 
trainer.update(
-            originalModel,
+            originalMdl,
             new HashMap<Integer, double[]>(),
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
@@ -94,7 +94,7 @@ public class SVMMultiClassTrainerTest extends TrainerTest {
         );
 
         Vector v = VectorUtils.of(100, 10);
-        TestUtils.assertEquals(originalModel.apply(v), 
updatedOnSameDS.apply(v), PRECISION);
-        TestUtils.assertEquals(originalModel.apply(v), 
updatedOnEmptyDS.apply(v), PRECISION);
+        TestUtils.assertEquals(originalMdl.apply(v), updatedOnSameDS.apply(v), 
PRECISION);
+        TestUtils.assertEquals(originalMdl.apply(v), 
updatedOnEmptyDS.apply(v), PRECISION);
     }
 }

Reply via email to