Repository: ignite Updated Branches: refs/heads/master 8fdf26599 -> 26e405281
http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java index cbaab37..ad4aaf1 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java @@ -22,8 +22,8 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.ThreadLocalRandom; import org.apache.ignite.ml.TestUtils; -import org.apache.ignite.ml.math.VectorUtils; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; @@ -97,7 +97,7 @@ public class LogisticRegressionSGDTrainerTest { (k, v) -> v[0] ); - TestUtils.assertEquals(0, mdl.apply(new DenseLocalOnHeapVector(new double[]{100, 10})), PRECISION); - TestUtils.assertEquals(1, mdl.apply(new DenseLocalOnHeapVector(new double[]{10, 100})), PRECISION); + TestUtils.assertEquals(0, mdl.apply(new DenseVector(new double[]{100, 10})), PRECISION); + TestUtils.assertEquals(1, mdl.apply(new DenseVector(new double[]{10, 100})), PRECISION); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java index 1980489..90918d8 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java @@ -19,7 +19,7 @@ package org.apache.ignite.ml.selection.cv; import java.util.HashMap; import java.util.Map; -import org.apache.ignite.ml.math.VectorUtils; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; import org.apache.ignite.ml.tree.DecisionTreeNode; http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursorTest.java index 7ad3998..8d02077 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursorTest.java @@ -21,7 +21,7 @@ import java.util.UUID; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.internal.util.IgniteUtils; -import org.apache.ignite.ml.math.VectorUtils; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.selection.scoring.LabelPair; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursorTest.java index f998dc9..682d6d3 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursorTest.java @@ -19,7 +19,7 @@ package org.apache.ignite.ml.selection.scoring.cursor; import java.util.HashMap; import java.util.Map; -import org.apache.ignite.ml.math.VectorUtils; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.selection.scoring.LabelPair; import org.junit.Test; http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java index d37bd47..ae94dd2 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java @@ -22,8 +22,8 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.ThreadLocalRandom; import org.apache.ignite.ml.TestUtils; -import org.apache.ignite.ml.math.VectorUtils; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; import org.junit.Test; /** @@ -68,7 +68,7 @@ public class SVMBinaryTrainerTest { (k, v) -> v[0] ); - TestUtils.assertEquals(-1, mdl.apply(new DenseLocalOnHeapVector(new double[]{100, 10})), PRECISION); - TestUtils.assertEquals(1, mdl.apply(new DenseLocalOnHeapVector(new double[]{10, 100})), PRECISION); + TestUtils.assertEquals(-1, mdl.apply(new DenseVector(new double[]{100, 10})), PRECISION); + TestUtils.assertEquals(1, mdl.apply(new DenseVector(new double[]{10, 100})), PRECISION); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java index 9092873..e88e16e 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java @@ -18,9 +18,9 @@ package org.apache.ignite.ml.svm; import org.apache.ignite.ml.TestUtils; -import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.exceptions.CardinalityException; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; import org.junit.Assert; import org.junit.Test; @@ -35,22 +35,22 @@ public class SVMModelTest { /** */ @Test public void testPredictWithRawLabels() { - Vector weights = new DenseLocalOnHeapVector(new double[]{2.0, 3.0}); + Vector weights = new DenseVector(new double[]{2.0, 3.0}); SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0).withRawLabels(true); - Vector observation = new DenseLocalOnHeapVector(new double[]{1.0, 1.0}); + Vector observation = new DenseVector(new double[]{1.0, 1.0}); TestUtils.assertEquals(1.0 + 2.0 * 1.0 + 3.0 * 1.0, mdl.apply(observation), PRECISION); - observation = new DenseLocalOnHeapVector(new double[]{2.0, 1.0}); + observation = new DenseVector(new double[]{2.0, 1.0}); TestUtils.assertEquals(1.0 + 2.0 * 2.0 + 3.0 * 1.0, mdl.apply(observation), PRECISION); - observation = new DenseLocalOnHeapVector(new double[]{1.0, 2.0}); + observation = new DenseVector(new double[]{1.0, 2.0}); TestUtils.assertEquals(1.0 + 2.0 * 1.0 + 3.0 * 2.0, mdl.apply(observation), PRECISION); - observation = new DenseLocalOnHeapVector(new double[]{-2.0, 1.0}); + observation = new DenseVector(new double[]{-2.0, 1.0}); TestUtils.assertEquals(1.0 - 2.0 * 2.0 + 3.0 * 1.0, mdl.apply(observation), PRECISION); - observation = new DenseLocalOnHeapVector(new double[]{1.0, -2.0}); + observation = new DenseVector(new double[]{1.0, -2.0}); TestUtils.assertEquals(1.0 + 2.0 * 1.0 - 3.0 * 2.0, mdl.apply(observation), PRECISION); Assert.assertEquals(true, mdl.isKeepingRawLabels()); @@ -60,43 +60,43 @@ public class SVMModelTest { /** */ @Test public void testPredictWithMultiClasses() { - Vector weights1 = new DenseLocalOnHeapVector(new double[]{10.0, 0.0}); - Vector weights2 = new DenseLocalOnHeapVector(new double[]{0.0, 10.0}); - Vector weights3 = new DenseLocalOnHeapVector(new double[]{-1.0, -1.0}); + Vector weights1 = new DenseVector(new double[]{10.0, 0.0}); + Vector weights2 = new DenseVector(new double[]{0.0, 10.0}); + Vector weights3 = new DenseVector(new double[]{-1.0, -1.0}); SVMLinearMultiClassClassificationModel mdl = new SVMLinearMultiClassClassificationModel(); mdl.add(1, new SVMLinearBinaryClassificationModel(weights1, 0.0).withRawLabels(true)); mdl.add(2, new SVMLinearBinaryClassificationModel(weights2, 0.0).withRawLabels(true)); mdl.add(2, new SVMLinearBinaryClassificationModel(weights3, 0.0).withRawLabels(true)); - Vector observation = new DenseLocalOnHeapVector(new double[]{1.0, 1.0}); + Vector observation = new DenseVector(new double[]{1.0, 1.0}); TestUtils.assertEquals( 1.0, mdl.apply(observation), PRECISION); } /** */ @Test public void testPredictWithErasedLabels() { - Vector weights = new DenseLocalOnHeapVector(new double[]{1.0, 1.0}); + Vector weights = new DenseVector(new double[]{1.0, 1.0}); SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0); - Vector observation = new DenseLocalOnHeapVector(new double[]{1.0, 1.0}); + Vector observation = new DenseVector(new double[]{1.0, 1.0}); TestUtils.assertEquals(1.0, mdl.apply(observation), PRECISION); - observation = new DenseLocalOnHeapVector(new double[]{3.0, 4.0}); + observation = new DenseVector(new double[]{3.0, 4.0}); TestUtils.assertEquals(1.0, mdl.apply(observation), PRECISION); - observation = new DenseLocalOnHeapVector(new double[]{-1.0, -1.0}); + observation = new DenseVector(new double[]{-1.0, -1.0}); TestUtils.assertEquals(-1.0, mdl.apply(observation), PRECISION); - observation = new DenseLocalOnHeapVector(new double[]{-2.0, 1.0}); + observation = new DenseVector(new double[]{-2.0, 1.0}); TestUtils.assertEquals(-1.0, mdl.apply(observation), PRECISION); - observation = new DenseLocalOnHeapVector(new double[]{-1.0, -2.0}); + observation = new DenseVector(new double[]{-1.0, -2.0}); TestUtils.assertEquals(-1.0, mdl.apply(observation), PRECISION); - final SVMLinearBinaryClassificationModel mdlWithNewData = mdl.withIntercept(-2.0).withWeights(new DenseLocalOnHeapVector(new double[] {-2.0, -2.0})); + final SVMLinearBinaryClassificationModel mdlWithNewData = mdl.withIntercept(-2.0).withWeights(new DenseVector(new double[] {-2.0, -2.0})); System.out.println("The SVM model is " + mdlWithNewData); - observation = new DenseLocalOnHeapVector(new double[]{-1.0, -2.0}); + observation = new DenseVector(new double[]{-1.0, -2.0}); TestUtils.assertEquals(1.0, mdl.apply(observation), PRECISION); TestUtils.assertEquals(-2.0, mdl.intercept(), PRECISION); } @@ -104,13 +104,13 @@ public class SVMModelTest { /** */ @Test public void testPredictWithErasedLabelsAndChangedThreshold() { - Vector weights = new DenseLocalOnHeapVector(new double[]{1.0, 1.0}); + Vector weights = new DenseVector(new double[]{1.0, 1.0}); SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0).withThreshold(5); - Vector observation = new DenseLocalOnHeapVector(new double[]{1.0, 1.0}); + Vector observation = new DenseVector(new double[]{1.0, 1.0}); TestUtils.assertEquals(-1.0, mdl.apply(observation), PRECISION); - observation = new DenseLocalOnHeapVector(new double[]{3.0, 4.0}); + observation = new DenseVector(new double[]{3.0, 4.0}); TestUtils.assertEquals(1.0, mdl.apply(observation), PRECISION); TestUtils.assertEquals(5, mdl.threshold(), PRECISION); @@ -119,11 +119,11 @@ public class SVMModelTest { /** */ @Test(expected = CardinalityException.class) public void testPredictOnAnObservationWithWrongCardinality() { - Vector weights = new DenseLocalOnHeapVector(new double[]{2.0, 3.0}); + Vector weights = new DenseVector(new double[]{2.0, 3.0}); SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0); - Vector observation = new DenseLocalOnHeapVector(new double[]{1.0}); + Vector observation = new DenseVector(new double[]{1.0}); mdl.apply(observation); } http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java index 27c0cd0..b12b266 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java @@ -22,8 +22,8 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.ThreadLocalRandom; import org.apache.ignite.ml.TestUtils; -import org.apache.ignite.ml.math.VectorUtils; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; import org.junit.Test; /** @@ -71,7 +71,7 @@ public class SVMMultiClassTrainerTest { (k, v) -> v[0] ); - TestUtils.assertEquals(-1, mdl.apply(new DenseLocalOnHeapVector(new double[]{100, 10})), PRECISION); - TestUtils.assertEquals(1, mdl.apply(new DenseLocalOnHeapVector(new double[]{10, 100})), PRECISION); + TestUtils.assertEquals(-1, mdl.apply(new DenseVector(new double[]{100, 10})), PRECISION); + TestUtils.assertEquals(1, mdl.apply(new DenseVector(new double[]{10, 100})), PRECISION); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java index da0a702..aadc8a7 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java @@ -24,7 +24,7 @@ import org.apache.ignite.IgniteCache; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.internal.util.IgniteUtils; -import org.apache.ignite.ml.math.VectorUtils; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; /** http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java index 109fa6e..de40b48 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java @@ -23,7 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Random; -import org.apache.ignite.ml.math.VectorUtils; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java index 11b75cd..a190685 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java @@ -24,7 +24,7 @@ import org.apache.ignite.IgniteCache; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.internal.util.IgniteUtils; -import org.apache.ignite.ml.math.VectorUtils; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; /** http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java index a552f85..f69da4f 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java @@ -23,7 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Random; -import org.apache.ignite.ml.math.VectorUtils; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java index e11a669..ca513ed 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java @@ -23,8 +23,8 @@ import org.apache.ignite.IgniteCache; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.internal.util.IgniteUtils; -import org.apache.ignite.ml.math.VectorUtils; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; import org.apache.ignite.ml.nn.performance.MnistMLPTestUtil; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; import org.apache.ignite.ml.tree.DecisionTreeNode; @@ -92,7 +92,7 @@ public class DecisionTreeMNISTIntegrationTest extends GridCommonAbstractTest { int incorrectAnswers = 0; for (MnistUtils.MnistLabeledImage e : MnistMLPTestUtil.loadTestSet(10_000)) { - double res = mdl.apply(new DenseLocalOnHeapVector(e.getPixels())); + double res = mdl.apply(new DenseVector(e.getPixels())); if (res == e.getLabel()) correctAnswers++; http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java index 67456ea..8a3f201 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java @@ -20,8 +20,8 @@ package org.apache.ignite.ml.tree.performance; import java.io.IOException; import java.util.HashMap; import java.util.Map; -import org.apache.ignite.ml.math.VectorUtils; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; import org.apache.ignite.ml.nn.performance.MnistMLPTestUtil; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; import org.apache.ignite.ml.tree.DecisionTreeNode; @@ -61,7 +61,7 @@ public class DecisionTreeMNISTTest { int incorrectAnswers = 0; for (MnistUtils.MnistLabeledImage e : MnistMLPTestUtil.loadTestSet(10_000)) { - double res = mdl.apply(new DenseLocalOnHeapVector(e.getPixels())); + double res = mdl.apply(new DenseVector(e.getPixels())); if (res == e.getLabel()) correctAnswers++; http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java index eab9152..055223b 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java @@ -24,7 +24,7 @@ import java.util.Map; import org.apache.ignite.ml.composition.ModelOnFeaturesSubspace; import org.apache.ignite.ml.composition.ModelsComposition; import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator; -import org.apache.ignite.ml.math.VectorUtils; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.tree.DecisionTreeConditionalNode; import org.junit.Test; import org.junit.runner.RunWith; @@ -56,7 +56,8 @@ public class RandomForestClassifierTrainerTest { } /** */ - @Test public void testFit() { + @Test + public void testFit() { int sampleSize = 1000; Map<double[], Double> sample = new HashMap<>(); for (int i = 0; i < sampleSize; i++) { @@ -69,13 +70,15 @@ public class RandomForestClassifierTrainerTest { } RandomForestClassifierTrainer trainer = new RandomForestClassifierTrainer(4, 3, 5, 0.3, 4, 0.1); - ModelsComposition model = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); - model.getModels().forEach(m -> { + + ModelsComposition mdl = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); + + mdl.getModels().forEach(m -> { assertTrue(m instanceof ModelOnFeaturesSubspace); assertTrue(((ModelOnFeaturesSubspace) m).getMdl() instanceof DecisionTreeConditionalNode); }); - assertTrue(model.getPredictionsAggregator() instanceof OnMajorityPredictionsAggregator); - assertEquals(5, model.getModels().size()); + assertTrue(mdl.getPredictionsAggregator() instanceof OnMajorityPredictionsAggregator); + assertEquals(5, mdl.getModels().size()); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java index 0e32e42..1421e0a 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java @@ -24,7 +24,7 @@ import java.util.Map; import org.apache.ignite.ml.composition.ModelOnFeaturesSubspace; import org.apache.ignite.ml.composition.ModelsComposition; import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator; -import org.apache.ignite.ml.math.VectorUtils; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.tree.DecisionTreeConditionalNode; import org.junit.Test; import org.junit.runner.RunWith; @@ -56,7 +56,8 @@ public class RandomForestRegressionTrainerTest { } /** */ - @Test public void testFit() { + @Test + public void testFit() { int sampleSize = 1000; Map<Double, double[]> sample = new HashMap<>(); for (int i = 0; i < sampleSize; i++) { @@ -69,13 +70,15 @@ public class RandomForestRegressionTrainerTest { } RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(4, 3, 5, 0.3, 4, 0.1); - ModelsComposition model = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(v), (k, v) -> k); - model.getModels().forEach(m -> { + + ModelsComposition mdl = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(v), (k, v) -> k); + + mdl.getModels().forEach(m -> { assertTrue(m instanceof ModelOnFeaturesSubspace); assertTrue(((ModelOnFeaturesSubspace) m).getMdl() instanceof DecisionTreeConditionalNode); }); - assertTrue(model.getPredictionsAggregator() instanceof MeanValuePredictionsAggregator); - assertEquals(5, model.getModels().size()); + assertTrue(mdl.getPredictionsAggregator() instanceof MeanValuePredictionsAggregator); + assertEquals(5, mdl.getModels().size()); } }
