Repository: ignite
Updated Branches:
  refs/heads/master cd0ead329 -> f566bedbc
IGNITE-9022: [ML] Implement class labels mapping for SVM binary classifier this closes #4749 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/f566bedb Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/f566bedb Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/f566bedb Branch: refs/heads/master Commit: f566bedbcb400e5102af48470d1aa4395c97a13f Parents: cd0ead3 Author: zaleslaw <[email protected]> Authored: Tue Sep 18 15:14:03 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Tue Sep 18 15:14:03 2018 +0300 ---------------------------------------------------------------------- .../binary/SVMBinaryClassificationExample.java | 105 +++++++++---------- .../svm/SVMLinearBinaryClassificationModel.java | 8 +- .../SVMLinearBinaryClassificationTrainer.java | 25 +++-- ...VMLinearMultiClassClassificationTrainer.java | 2 +- .../ignite/ml/svm/SVMBinaryTrainerTest.java | 12 +-- .../org/apache/ignite/ml/svm/SVMModelTest.java | 8 +- .../ignite/ml/svm/SVMMultiClassTrainerTest.java | 10 +- 7 files changed, 88 insertions(+), 82 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/f566bedb/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java index bd88c20..f71db2d 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java @@ -91,8 +91,8 @@ public class SVMBinaryClassificationExample { if(groundTruth != prediction) 
amountOfErrors++; - int idx1 = (int)prediction == -1.0 ? 0 : 1; - int idx2 = (int)groundTruth == -1.0 ? 0 : 1; + int idx1 = prediction == 0.0 ? 0 : 1; + int idx2 = groundTruth == 0.0 ? 0 : 1; confusionMtx[idx1][idx2]++; @@ -117,56 +117,56 @@ public class SVMBinaryClassificationExample { /** The 1st and 2nd classes from the Iris dataset. */ private static final double[][] data = { - {-1, 5.1, 3.5, 1.4, 0.2}, - {-1, 4.9, 3, 1.4, 0.2}, - {-1, 4.7, 3.2, 1.3, 0.2}, - {-1, 4.6, 3.1, 1.5, 0.2}, - {-1, 5, 3.6, 1.4, 0.2}, - {-1, 5.4, 3.9, 1.7, 0.4}, - {-1, 4.6, 3.4, 1.4, 0.3}, - {-1, 5, 3.4, 1.5, 0.2}, - {-1, 4.4, 2.9, 1.4, 0.2}, - {-1, 4.9, 3.1, 1.5, 0.1}, - {-1, 5.4, 3.7, 1.5, 0.2}, - {-1, 4.8, 3.4, 1.6, 0.2}, - {-1, 4.8, 3, 1.4, 0.1}, - {-1, 4.3, 3, 1.1, 0.1}, - {-1, 5.8, 4, 1.2, 0.2}, - {-1, 5.7, 4.4, 1.5, 0.4}, - {-1, 5.4, 3.9, 1.3, 0.4}, - {-1, 5.1, 3.5, 1.4, 0.3}, - {-1, 5.7, 3.8, 1.7, 0.3}, - {-1, 5.1, 3.8, 1.5, 0.3}, - {-1, 5.4, 3.4, 1.7, 0.2}, - {-1, 5.1, 3.7, 1.5, 0.4}, - {-1, 4.6, 3.6, 1, 0.2}, - {-1, 5.1, 3.3, 1.7, 0.5}, - {-1, 4.8, 3.4, 1.9, 0.2}, - {-1, 5, 3, 1.6, 0.2}, - {-1, 5, 3.4, 1.6, 0.4}, - {-1, 5.2, 3.5, 1.5, 0.2}, - {-1, 5.2, 3.4, 1.4, 0.2}, - {-1, 4.7, 3.2, 1.6, 0.2}, - {-1, 4.8, 3.1, 1.6, 0.2}, - {-1, 5.4, 3.4, 1.5, 0.4}, - {-1, 5.2, 4.1, 1.5, 0.1}, - {-1, 5.5, 4.2, 1.4, 0.2}, - {-1, 4.9, 3.1, 1.5, 0.1}, - {-1, 5, 3.2, 1.2, 0.2}, - {-1, 5.5, 3.5, 1.3, 0.2}, - {-1, 4.9, 3.1, 1.5, 0.1}, - {-1, 4.4, 3, 1.3, 0.2}, - {-1, 5.1, 3.4, 1.5, 0.2}, - {-1, 5, 3.5, 1.3, 0.3}, - {-1, 4.5, 2.3, 1.3, 0.3}, - {-1, 4.4, 3.2, 1.3, 0.2}, - {-1, 5, 3.5, 1.6, 0.6}, - {-1, 5.1, 3.8, 1.9, 0.4}, - {-1, 4.8, 3, 1.4, 0.3}, - {-1, 5.1, 3.8, 1.6, 0.2}, - {-1, 4.6, 3.2, 1.4, 0.2}, - {-1, 5.3, 3.7, 1.5, 0.2}, - {-1, 5, 3.3, 1.4, 0.2}, + {0, 5.1, 3.5, 1.4, 0.2}, + {0, 4.9, 3, 1.4, 0.2}, + {0, 4.7, 3.2, 1.3, 0.2}, + {0, 4.6, 3.1, 1.5, 0.2}, + {0, 5, 3.6, 1.4, 0.2}, + {0, 5.4, 3.9, 1.7, 0.4}, + {0, 4.6, 3.4, 1.4, 0.3}, + {0, 5, 3.4, 1.5, 0.2}, + {0, 4.4, 2.9, 1.4, 0.2}, + {0, 4.9, 
3.1, 1.5, 0.1}, + {0, 5.4, 3.7, 1.5, 0.2}, + {0, 4.8, 3.4, 1.6, 0.2}, + {0, 4.8, 3, 1.4, 0.1}, + {0, 4.3, 3, 1.1, 0.1}, + {0, 5.8, 4, 1.2, 0.2}, + {0, 5.7, 4.4, 1.5, 0.4}, + {0, 5.4, 3.9, 1.3, 0.4}, + {0, 5.1, 3.5, 1.4, 0.3}, + {0, 5.7, 3.8, 1.7, 0.3}, + {0, 5.1, 3.8, 1.5, 0.3}, + {0, 5.4, 3.4, 1.7, 0.2}, + {0, 5.1, 3.7, 1.5, 0.4}, + {0, 4.6, 3.6, 1, 0.2}, + {0, 5.1, 3.3, 1.7, 0.5}, + {0, 4.8, 3.4, 1.9, 0.2}, + {0, 5, 3, 1.6, 0.2}, + {0, 5, 3.4, 1.6, 0.4}, + {0, 5.2, 3.5, 1.5, 0.2}, + {0, 5.2, 3.4, 1.4, 0.2}, + {0, 4.7, 3.2, 1.6, 0.2}, + {0, 4.8, 3.1, 1.6, 0.2}, + {0, 5.4, 3.4, 1.5, 0.4}, + {0, 5.2, 4.1, 1.5, 0.1}, + {0, 5.5, 4.2, 1.4, 0.2}, + {0, 4.9, 3.1, 1.5, 0.1}, + {0, 5, 3.2, 1.2, 0.2}, + {0, 5.5, 3.5, 1.3, 0.2}, + {0, 4.9, 3.1, 1.5, 0.1}, + {0, 4.4, 3, 1.3, 0.2}, + {0, 5.1, 3.4, 1.5, 0.2}, + {0, 5, 3.5, 1.3, 0.3}, + {0, 4.5, 2.3, 1.3, 0.3}, + {0, 4.4, 3.2, 1.3, 0.2}, + {0, 5, 3.5, 1.6, 0.6}, + {0, 5.1, 3.8, 1.9, 0.4}, + {0, 4.8, 3, 1.4, 0.3}, + {0, 5.1, 3.8, 1.6, 0.2}, + {0, 4.6, 3.2, 1.4, 0.2}, + {0, 5.3, 3.7, 1.5, 0.2}, + {0, 5, 3.3, 1.4, 0.2}, {1, 7, 3.2, 4.7, 1.4}, {1, 6.4, 3.2, 4.5, 1.5}, {1, 6.9, 3.1, 4.9, 1.5}, @@ -218,5 +218,4 @@ public class SVMBinaryClassificationExample { {1, 5.1, 2.5, 3, 1.1}, {1, 5.7, 2.8, 4.1, 1.3}, }; - } http://git-wip-us.apache.org/repos/asf/ignite/blob/f566bedb/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java index 4771e4a..f5d2b28 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java @@ -31,11 +31,11 @@ public class SVMLinearBinaryClassificationModel implements Model<Vector, 
Double> /** */ private static final long serialVersionUID = -996984622291440226L; - /** Output label format. -1 and +1 for false value and raw distances from the separating hyperplane otherwise. */ + /** Output label format. '0' and '1' for false value and raw distances from the separating hyperplane otherwise. */ private boolean isKeepingRawLabels = false; - /** Threshold to assign +1 label to the observation if raw value more than this threshold. */ - private double threshold = 0.0; + /** Threshold to assign '1' label to the observation if raw value more than this threshold. */ + private double threshold = 0.5; /** Multiplier of the objects's vector required to make prediction. */ private Vector weights; @@ -99,7 +99,7 @@ public class SVMLinearBinaryClassificationModel implements Model<Vector, Double> if (isKeepingRawLabels) return res; else - return res - threshold > 0 ? 1.0 : -1.0; + return res - threshold > 0 ? 1.0 : 0; } /** http://git-wip-us.apache.org/repos/asf/ignite/blob/f566bedb/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java index 8fb98d2..2c621c8 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java @@ -35,8 +35,8 @@ import org.jetbrains.annotations.NotNull; /** * Base class for a soft-margin SVM linear classification trainer based on the communication-efficient distributed dual - * coordinate ascent algorithm (CoCoA) with hinge-loss function. <p> This trainer takes input as Labeled Dataset with -1 - * and +1 labels for two classes and makes binary classification. 
</p> The paper about this algorithm could be found + * coordinate ascent algorithm (CoCoA) with hinge-loss function. <p> This trainer takes input as Labeled Dataset with 0 + * and 1 labels for two classes and makes binary classification. </p> The paper about this algorithm could be found * here https://arxiv.org/abs/1409.1458. */ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrainer<SVMLinearBinaryClassificationModel> { @@ -73,9 +73,17 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai assert datasetBuilder != null; + IgniteBiFunction<K, V, Double> patchedLbExtractor = (k, v) -> { + final Double lb = lbExtractor.apply(k, v); + if (lb == 0.0) + return -1.0; + else + return lb; + }; + PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>( featureExtractor, - lbExtractor + patchedLbExtractor ); Vector weights; @@ -95,9 +103,8 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai final int weightVectorSizeWithIntercept = cols + 1; weights = initializeWeightsWithZeros(weightVectorSizeWithIntercept); - } else { + } else weights = getStateVector(mdl); - } for (int i = 0; i < this.getAmountOfIterations(); i++) { Vector deltaWeights = calculateUpdates(weights, dataset); @@ -126,13 +133,13 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai Vector weights = mdl.weights(); int stateVectorSize = weights.size() + 1; - Vector result = weights.isDense() ? + Vector res = weights.isDense() ? 
new DenseVector(stateVectorSize) : new SparseVector(stateVectorSize, StorageConstants.RANDOM_ACCESS_MODE); - result.set(0, intercept); - weights.nonZeroes().forEach(ith -> result.set(ith.index(), ith.get())); - return result; + res.set(0, intercept); + weights.nonZeroes().forEach(ith -> res.set(ith.index(), ith.get())); + return res; } /** */ http://git-wip-us.apache.org/repos/asf/ignite/blob/f566bedb/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java index aeee178..7cbb1dc 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java @@ -93,7 +93,7 @@ public class SVMLinearMultiClassClassificationTrainer if (lb.equals(clsLb)) return 1.0; else - return -1.0; + return 0.0; }; SVMLinearBinaryClassificationModel model; http://git-wip-us.apache.org/repos/asf/ignite/blob/f566bedb/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java index 263bb6d..d6f77c0 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java @@ -50,7 +50,7 @@ public class SVMBinaryTrainerTest extends TrainerTest { (k, v) -> v[0] ); - TestUtils.assertEquals(-1, mdl.apply(VectorUtils.of(100, 10)), PRECISION); + TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), 
PRECISION); TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION); } @@ -66,7 +66,7 @@ public class SVMBinaryTrainerTest extends TrainerTest { .withAmountOfIterations(1000) .withSeed(1234L); - SVMLinearBinaryClassificationModel originalModel = trainer.fit( + SVMLinearBinaryClassificationModel originalMdl = trainer.fit( cacheMock, parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), @@ -74,7 +74,7 @@ public class SVMBinaryTrainerTest extends TrainerTest { ); SVMLinearBinaryClassificationModel updatedOnSameDS = trainer.update( - originalModel, + originalMdl, cacheMock, parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), @@ -82,7 +82,7 @@ public class SVMBinaryTrainerTest extends TrainerTest { ); SVMLinearBinaryClassificationModel updatedOnEmptyDS = trainer.update( - originalModel, + originalMdl, new HashMap<Integer, double[]>(), parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), @@ -90,7 +90,7 @@ public class SVMBinaryTrainerTest extends TrainerTest { ); Vector v = VectorUtils.of(100, 10); - TestUtils.assertEquals(originalModel.apply(v), updatedOnSameDS.apply(v), PRECISION); - TestUtils.assertEquals(originalModel.apply(v), updatedOnEmptyDS.apply(v), PRECISION); + TestUtils.assertEquals(originalMdl.apply(v), updatedOnSameDS.apply(v), PRECISION); + TestUtils.assertEquals(originalMdl.apply(v), updatedOnEmptyDS.apply(v), PRECISION); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f566bedb/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java index 9a222c3..9c452f9 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java @@ -93,13 +93,13 @@ public class SVMModelTest { 
TestUtils.assertEquals(1.0, mdl.apply(observation), PRECISION); observation = new DenseVector(new double[]{-1.0, -1.0}); - TestUtils.assertEquals(-1.0, mdl.apply(observation), PRECISION); + TestUtils.assertEquals(0.0, mdl.apply(observation), PRECISION); observation = new DenseVector(new double[]{-2.0, 1.0}); - TestUtils.assertEquals(-1.0, mdl.apply(observation), PRECISION); + TestUtils.assertEquals(0.0, mdl.apply(observation), PRECISION); observation = new DenseVector(new double[]{-1.0, -2.0}); - TestUtils.assertEquals(-1.0, mdl.apply(observation), PRECISION); + TestUtils.assertEquals(0.0, mdl.apply(observation), PRECISION); final SVMLinearBinaryClassificationModel mdlWithNewData = mdl.withIntercept(-2.0).withWeights(new DenseVector(new double[] {-2.0, -2.0})); System.out.println("The SVM model is " + mdlWithNewData); @@ -116,7 +116,7 @@ public class SVMModelTest { SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0).withThreshold(5); Vector observation = new DenseVector(new double[]{1.0, 1.0}); - TestUtils.assertEquals(-1.0, mdl.apply(observation), PRECISION); + TestUtils.assertEquals(0.0, mdl.apply(observation), PRECISION); observation = new DenseVector(new double[]{3.0, 4.0}); TestUtils.assertEquals(1.0, mdl.apply(observation), PRECISION); http://git-wip-us.apache.org/repos/asf/ignite/blob/f566bedb/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java index e0c62af..7c4809f 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java @@ -70,7 +70,7 @@ public class SVMMultiClassTrainerTest extends TrainerTest { .withAmountOfIterations(100) 
.withSeed(1234L); - SVMLinearMultiClassClassificationModel originalModel = trainer.fit( + SVMLinearMultiClassClassificationModel originalMdl = trainer.fit( cacheMock, parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), @@ -78,7 +78,7 @@ public class SVMMultiClassTrainerTest extends TrainerTest { ); SVMLinearMultiClassClassificationModel updatedOnSameDS = trainer.update( - originalModel, + originalMdl, cacheMock, parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), @@ -86,7 +86,7 @@ public class SVMMultiClassTrainerTest extends TrainerTest { ); SVMLinearMultiClassClassificationModel updatedOnEmptyDS = trainer.update( - originalModel, + originalMdl, new HashMap<Integer, double[]>(), parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), @@ -94,7 +94,7 @@ public class SVMMultiClassTrainerTest extends TrainerTest { ); Vector v = VectorUtils.of(100, 10); - TestUtils.assertEquals(originalModel.apply(v), updatedOnSameDS.apply(v), PRECISION); - TestUtils.assertEquals(originalModel.apply(v), updatedOnEmptyDS.apply(v), PRECISION); + TestUtils.assertEquals(originalMdl.apply(v), updatedOnSameDS.apply(v), PRECISION); + TestUtils.assertEquals(originalMdl.apply(v), updatedOnEmptyDS.apply(v), PRECISION); } }
