http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java index d9b6f7a..7236820 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java @@ -27,8 +27,6 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; import org.junit.Test; /**
http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java index d6f77c0..ccde0d7 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java @@ -27,7 +27,7 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Test; /** - * Tests for {@link SVMLinearBinaryClassificationTrainer}. + * Tests for {@link SVMLinearClassificationTrainer}. */ public class SVMBinaryTrainerTest extends TrainerTest { /** @@ -40,10 +40,10 @@ public class SVMBinaryTrainerTest extends TrainerTest { for (int i = 0; i < twoLinearlySeparableClasses.length; i++) cacheMock.put(i, twoLinearlySeparableClasses[i]); - SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer() + SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer() .withSeed(1234L); - SVMLinearBinaryClassificationModel mdl = trainer.fit( + SVMLinearClassificationModel mdl = trainer.fit( cacheMock, parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), @@ -62,18 +62,18 @@ public class SVMBinaryTrainerTest extends TrainerTest { for (int i = 0; i < twoLinearlySeparableClasses.length; i++) cacheMock.put(i, twoLinearlySeparableClasses[i]); - SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer() + SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer() .withAmountOfIterations(1000) .withSeed(1234L); - SVMLinearBinaryClassificationModel originalMdl = trainer.fit( + SVMLinearClassificationModel originalMdl = trainer.fit( cacheMock, parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), (k, v) -> v[0] ); - SVMLinearBinaryClassificationModel updatedOnSameDS = trainer.update( + SVMLinearClassificationModel updatedOnSameDS = trainer.update( originalMdl, cacheMock, parts, @@ -81,7 +81,7 @@ public class SVMBinaryTrainerTest extends TrainerTest { (k, v) -> v[0] ); - SVMLinearBinaryClassificationModel updatedOnEmptyDS = trainer.update( + SVMLinearClassificationModel updatedOnEmptyDS = trainer.update( originalMdl, new HashMap<Integer, double[]>(), parts, http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java index 9c452f9..3bac790 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java @@ -36,7 +36,7 @@ public class SVMModelTest { @Test public void testPredictWithRawLabels() { Vector weights = new DenseVector(new double[]{2.0, 3.0}); - SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0).withRawLabels(true); + SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(weights, 1.0).withRawLabels(true); Vector observation = new DenseVector(new double[]{1.0, 1.0}); TestUtils.assertEquals(1.0 + 2.0 * 1.0 + 3.0 * 1.0, mdl.apply(observation), PRECISION); @@ -55,36 +55,16 @@ public class SVMModelTest { Assert.assertTrue(mdl.isKeepingRawLabels()); - Assert.assertTrue(mdl.toString().length() > 0); - Assert.assertTrue(mdl.toString(true).length() > 0); - Assert.assertTrue(mdl.toString(false).length() > 0); - } - - - /** */ - @Test - public void testPredictWithMultiClasses() { - Vector weights1 = new DenseVector(new double[]{10.0, 0.0}); - Vector weights2 = new DenseVector(new double[]{0.0, 10.0}); - Vector weights3 = new DenseVector(new double[]{-1.0, -1.0}); - SVMLinearMultiClassClassificationModel mdl = new SVMLinearMultiClassClassificationModel(); - mdl.add(1, new SVMLinearBinaryClassificationModel(weights1, 0.0).withRawLabels(true)); - mdl.add(2, new SVMLinearBinaryClassificationModel(weights2, 0.0).withRawLabels(true)); - mdl.add(2, new SVMLinearBinaryClassificationModel(weights3, 0.0).withRawLabels(true)); - - Assert.assertTrue(mdl.toString().length() > 0); - Assert.assertTrue(mdl.toString(true).length() > 0); - Assert.assertTrue(mdl.toString(false).length() > 0); - - Vector observation = new DenseVector(new double[]{1.0, 1.0}); - TestUtils.assertEquals( 1.0, mdl.apply(observation), PRECISION); + Assert.assertTrue(!mdl.toString().isEmpty()); + Assert.assertTrue(!mdl.toString(true).isEmpty()); + Assert.assertTrue(!mdl.toString(false).isEmpty()); } /** */ @Test public void testPredictWithErasedLabels() { Vector weights = new DenseVector(new double[]{1.0, 1.0}); - SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0); + SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(weights, 1.0); Vector observation = new DenseVector(new double[]{1.0, 1.0}); TestUtils.assertEquals(1.0, mdl.apply(observation), PRECISION); @@ -101,7 +81,7 @@ public class SVMModelTest { observation = new DenseVector(new double[]{-1.0, -2.0}); TestUtils.assertEquals(0.0, mdl.apply(observation), PRECISION); - final SVMLinearBinaryClassificationModel mdlWithNewData = mdl.withIntercept(-2.0).withWeights(new DenseVector(new double[] {-2.0, -2.0})); + final SVMLinearClassificationModel mdlWithNewData = mdl.withIntercept(-2.0).withWeights(new DenseVector(new double[] {-2.0, -2.0})); System.out.println("The SVM model is " + mdlWithNewData); observation = new DenseVector(new double[]{-1.0, -2.0}); @@ -113,7 +93,7 @@ public class SVMModelTest { @Test public void testPredictWithErasedLabelsAndChangedThreshold() { Vector weights = new DenseVector(new double[]{1.0, 1.0}); - SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0).withThreshold(5); + SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(weights, 1.0).withThreshold(5); Vector observation = new DenseVector(new double[]{1.0, 1.0}); TestUtils.assertEquals(0.0, mdl.apply(observation), PRECISION); @@ -129,7 +109,7 @@ public class SVMModelTest { public void testPredictOnAnObservationWithWrongCardinality() { Vector weights = new DenseVector(new double[]{2.0, 3.0}); - SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0); + SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(weights, 1.0); Vector observation = new DenseVector(new double[]{1.0}); http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java deleted file mode 100644 index 7c4809f..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.svm; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; -import org.apache.ignite.ml.TestUtils; -import org.apache.ignite.ml.common.TrainerTest; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.math.primitives.vector.VectorUtils; -import org.junit.Test; - -/** - * Tests for {@link SVMLinearBinaryClassificationTrainer}. - */ -public class SVMMultiClassTrainerTest extends TrainerTest { - /** - * Test trainer on 4 sets grouped around of square vertices. - */ - @Test - public void testTrainWithTheLinearlySeparableCase() { - Map<Integer, double[]> cacheMock = new HashMap<>(); - - for (int i = 0; i < twoLinearlySeparableClasses.length; i++) - cacheMock.put(i, twoLinearlySeparableClasses[i]); - - SVMLinearMultiClassClassificationTrainer trainer = new SVMLinearMultiClassClassificationTrainer() - .withLambda(0.3) - .withAmountOfLocIterations(10) - .withAmountOfIterations(20) - .withSeed(1234L); - - SVMLinearMultiClassClassificationModel mdl = trainer.fit( - cacheMock, - parts, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); - TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), PRECISION); - TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION); - } - - /** */ - @Test - public void testUpdate() { - Map<Integer, double[]> cacheMock = new HashMap<>(); - - for (int i = 0; i < twoLinearlySeparableClasses.length; i++) - cacheMock.put(i, twoLinearlySeparableClasses[i]); - - SVMLinearMultiClassClassificationTrainer trainer = new SVMLinearMultiClassClassificationTrainer() - .withLambda(0.3) - .withAmountOfLocIterations(10) - .withAmountOfIterations(100) - .withSeed(1234L); - - SVMLinearMultiClassClassificationModel originalMdl = trainer.fit( - cacheMock, - parts, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); - - SVMLinearMultiClassClassificationModel updatedOnSameDS = trainer.update( - originalMdl, - cacheMock, - parts, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); - - SVMLinearMultiClassClassificationModel updatedOnEmptyDS = trainer.update( - originalMdl, - new HashMap<Integer, double[]>(), - parts, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); - - Vector v = VectorUtils.of(100, 10); - TestUtils.assertEquals(originalMdl.apply(v), updatedOnSameDS.apply(v), PRECISION); - TestUtils.assertEquals(originalMdl.apply(v), updatedOnEmptyDS.apply(v), PRECISION); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java index df7263f..a2aea6e 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java @@ -27,7 +27,6 @@ import org.junit.runners.Suite; @Suite.SuiteClasses({ SVMModelTest.class, SVMBinaryTrainerTest.class, - SVMMultiClassTrainerTest.class, }) public class SVMTestSuite { // No-op. http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java index 1b96ce2..31fe8b3 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java @@ -37,8 +37,8 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer; import org.junit.Test; /**
