http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java index 9f60c48..a316014 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java @@ -18,7 +18,10 @@ package org.apache.ignite.ml; import org.apache.ignite.ml.clustering.ClusteringTestSuite; +import org.apache.ignite.ml.common.CommonTestSuite; +import org.apache.ignite.ml.composition.CompositionTestSuite; import org.apache.ignite.ml.dataset.DatasetTestSuite; +import org.apache.ignite.ml.environment.EnvironmentTestSuite; import org.apache.ignite.ml.genetic.GAGridTestSuite; import org.apache.ignite.ml.knn.KNNTestSuite; import org.apache.ignite.ml.math.MathImplMainTestSuite; @@ -26,13 +29,15 @@ import org.apache.ignite.ml.nn.MLPTestSuite; import org.apache.ignite.ml.preprocessing.PreprocessingTestSuite; import org.apache.ignite.ml.regressions.RegressionsTestSuite; import org.apache.ignite.ml.selection.SelectionTestSuite; +import org.apache.ignite.ml.structures.StructuresTestSuite; import org.apache.ignite.ml.svm.SVMTestSuite; import org.apache.ignite.ml.tree.DecisionTreeTestSuite; import org.junit.runner.RunWith; import org.junit.runners.Suite; /** - * Test suite for all module tests. + * Test suite for all module tests. IMPL NOTE tests in {@code org.apache.ignite.ml.tree.performance} are not + * included here because these are intended only for manual execution. 
*/ @RunWith(Suite.class) @Suite.SuiteClasses({ @@ -47,8 +52,12 @@ import org.junit.runners.Suite; DatasetTestSuite.class, PreprocessingTestSuite.class, GAGridTestSuite.class, - SelectionTestSuite.class + SelectionTestSuite.class, + CompositionTestSuite.class, + EnvironmentTestSuite.class, + StructuresTestSuite.class, + CommonTestSuite.class }) public class IgniteMLTestSuite { // No-op. -} \ No newline at end of file +}
http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java b/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java index a4591fb..4b472cc 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java @@ -18,7 +18,6 @@ package org.apache.ignite.ml; import java.util.stream.IntStream; -import org.apache.ignite.ml.math.Precision; import org.apache.ignite.ml.math.primitives.matrix.Matrix; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.junit.Assert; @@ -205,44 +204,6 @@ public class TestUtils { Assert.fail(out.toString()); } - /** - * Verifies that two float arrays are close (sup norm). - * - * @param msg The identifying message for the assertion error. - * @param exp Expected array. - * @param actual Actual array. - * @param tolerance Comparison tolerance value. - */ - public static void assertEquals(String msg, float[] exp, float[] actual, float tolerance) { - StringBuilder out = new StringBuilder(msg); - - if (exp.length != actual.length) { - out.append("\n Arrays not same length. \n"); - out.append("expected has length "); - out.append(exp.length); - out.append(" observed length = "); - out.append(actual.length); - Assert.fail(out.toString()); - } - - boolean failure = false; - - for (int i = 0; i < exp.length; i++) - if (!Precision.equalsIncludingNaN(exp[i], actual[i], tolerance)) { - failure = true; - out.append("\n Elements at index "); - out.append(i); - out.append(" differ. 
"); - out.append(" expected = "); - out.append(exp[i]); - out.append(" observed = "); - out.append(actual[i]); - } - - if (failure) - Assert.fail(out.toString()); - } - /** */ public static double maximumAbsoluteRowSum(Matrix mtx) { return IntStream.range(0, mtx.rowSize()).mapToObj(mtx::viewRow).map(v -> Math.abs(v.sum())).reduce(Math::max).get(); @@ -271,4 +232,97 @@ public class TestUtils { return true; } + + /** */ + private static class Precision { + /** Offset to order signed double numbers lexicographically. */ + private static final long SGN_MASK = 0x8000000000000000L; + + /** Positive zero bits. */ + private static final long POSITIVE_ZERO_DOUBLE_BITS = Double.doubleToRawLongBits(+0.0); + + /** Negative zero bits. */ + private static final long NEGATIVE_ZERO_DOUBLE_BITS = Double.doubleToRawLongBits(-0.0); + + /** + * Returns true if the arguments are both NaN, are equal or are within the range + * of allowed error (inclusive). + * + * @param x first value + * @param y second value + * @param eps the amount of absolute error to allow. + * @return {@code true} if the values are equal or within range of each other, or both are NaN. + * @since 2.2 + */ + static boolean equalsIncludingNaN(double x, double y, double eps) { + return equalsIncludingNaN(x, y) || (Math.abs(y - x) <= eps); + } + + /** + * Returns true if the arguments are both NaN or they are + * equal as defined by {@link #equals(double, double, int) equals(x, y, 1)}. + * + * @param x first value + * @param y second value + * @return {@code true} if the values are equal or both are NaN. + * @since 2.2 + */ + private static boolean equalsIncludingNaN(double x, double y) { + return (x != x || y != y) ? !(x != x ^ y != y) : equals(x, y, 1); + } + + /** + * Returns true if the arguments are equal or within the range of allowed + * error (inclusive). + * <p> + * Two float numbers are considered equal if there are {@code (maxUlps - 1)} + * (or fewer) floating point numbers between them, i.e. 
two adjacent + * floating point numbers are considered equal. + * </p> + * <p> + * Adapted from <a + * href="http://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/"> + * Bruce Dawson</a>. Returns {@code false} if either of the arguments is NaN. + * </p> + * + * @param x first value + * @param y second value + * @param maxUlps {@code (maxUlps - 1)} is the number of floating point values between {@code x} and {@code y}. + * @return {@code true} if there are fewer than {@code maxUlps} floating point values between {@code x} and {@code + * y}. + */ + private static boolean equals(final double x, final double y, final int maxUlps) { + + final long xInt = Double.doubleToRawLongBits(x); + final long yInt = Double.doubleToRawLongBits(y); + + final boolean isEqual; + if (((xInt ^ yInt) & SGN_MASK) == 0L) { + // number have same sign, there is no risk of overflow + isEqual = Math.abs(xInt - yInt) <= maxUlps; + } + else { + // number have opposite signs, take care of overflow + final long deltaPlus; + final long deltaMinus; + if (xInt < yInt) { + deltaPlus = yInt - POSITIVE_ZERO_DOUBLE_BITS; + deltaMinus = xInt - NEGATIVE_ZERO_DOUBLE_BITS; + } + else { + deltaPlus = xInt - POSITIVE_ZERO_DOUBLE_BITS; + deltaMinus = yInt - NEGATIVE_ZERO_DOUBLE_BITS; + } + + if (deltaPlus > maxUlps) + isEqual = false; + else + isEqual = deltaMinus <= (maxUlps - deltaPlus); + + } + + return isEqual && !Double.isNaN(x) && !Double.isNaN(y); + + } + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansModelTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansModelTest.java index 0d95d05..03e0e6d 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansModelTest.java +++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansModelTest.java @@ -46,6 +46,8 @@ public class KMeansModelTest { KMeansModel mdl = new KMeansModel(centers, distanceMeasure); + Assert.assertTrue(mdl.toString().contains("KMeansModel")); + Assert.assertEquals(mdl.apply(new DenseVector(new double[]{1.1, 1.1})), 0.0, PRECISION); Assert.assertEquals(mdl.apply(new DenseVector(new double[]{-1.1, 1.1})), 1.0, PRECISION); Assert.assertEquals(mdl.apply(new DenseVector(new double[]{1.1, -1.1})), 2.0, PRECISION); http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java index 8d2c341..420e4fb 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java @@ -30,6 +30,7 @@ import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; /** * Tests for {@link KMeansTrainer}. 
@@ -43,7 +44,6 @@ public class KMeansTrainerTest { */ @Test public void findOneClusters() { - Map<Integer, double[]> data = new HashMap<>(); data.put(0, new double[] {1.0, 1.0, 1.0}); data.put(1, new double[] {1.0, 2.0, 1.0}); @@ -54,11 +54,15 @@ public class KMeansTrainerTest { KMeansTrainer trainer = new KMeansTrainer() .withDistance(new EuclideanDistance()) - .withK(1) + .withK(10) .withMaxIterations(1) - .withEpsilon(PRECISION); + .withEpsilon(PRECISION) + .withSeed(2); + assertEquals(10, trainer.getK()); + assertEquals(2, trainer.getSeed()); + assertTrue(trainer.getDistance() instanceof EuclideanDistance); - KMeansModel knnMdl = trainer.fit( + KMeansModel knnMdl = trainer.withK(1).fit( new LocalDatasetBuilder<>(data, 2), (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[2] http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java new file mode 100644 index 0000000..c4d896c --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.common; + +import java.util.HashSet; +import java.util.Set; +import org.apache.ignite.ml.clustering.kmeans.KMeansModel; +import org.apache.ignite.ml.clustering.kmeans.KMeansModelFormat; +import org.apache.ignite.ml.knn.classification.KNNClassificationModel; +import org.apache.ignite.ml.knn.classification.KNNModelFormat; +import org.apache.ignite.ml.knn.classification.KNNStrategy; +import org.apache.ignite.ml.math.distances.EuclideanDistance; +import org.apache.ignite.ml.math.distances.HammingDistance; +import org.apache.ignite.ml.math.distances.ManhattanDistance; +import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; +import org.apache.ignite.ml.math.primitives.vector.impl.VectorizedViewMatrix; +import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; +import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; +import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel; +import org.apache.ignite.ml.structures.Dataset; +import org.apache.ignite.ml.structures.DatasetRow; +import org.apache.ignite.ml.structures.FeatureMetadata; +import org.apache.ignite.ml.structures.LabeledVector; +import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel; +import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static 
org.junit.Assert.assertNotEquals; + +/** + * Tests for equals and hashCode methods in classes that provide own implementations of these. + */ +public class CollectionsTest { + /** */ + @Test + @SuppressWarnings("unchecked") + public void test() { + test(new VectorizedViewMatrix(new DenseMatrix(2, 2), 1, 1, 1, 1), + new VectorizedViewMatrix(new DenseMatrix(3, 2), 2, 1, 1, 1)); + + specialTest(new ManhattanDistance(), new ManhattanDistance()); + + specialTest(new HammingDistance(), new HammingDistance()); + + specialTest(new EuclideanDistance(), new EuclideanDistance()); + + FeatureMetadata data = new FeatureMetadata("name2"); + data.setName("name1"); + test(data, new FeatureMetadata("name2")); + + test(new DatasetRow<>(new DenseVector()), new DatasetRow<>(new DenseVector(1))); + + test(new LabeledVector<>(new DenseVector(), null), new LabeledVector<>(new DenseVector(1), null)); + + test(new Dataset<DatasetRow<Vector>>(new DatasetRow[] {}, new FeatureMetadata[] {}), + new Dataset<DatasetRow<Vector>>(new DatasetRow[] {new DatasetRow()}, + new FeatureMetadata[] {new FeatureMetadata()})); + + test(new LogisticRegressionModel(new DenseVector(), 1.0), + new LogisticRegressionModel(new DenseVector(), 0.5)); + + test(new KMeansModelFormat(new Vector[] {}, new ManhattanDistance()), + new KMeansModelFormat(new Vector[] {}, new HammingDistance())); + + test(new KMeansModel(new Vector[] {}, new ManhattanDistance()), + new KMeansModel(new Vector[] {}, new HammingDistance())); + + test(new KNNModelFormat(1, new ManhattanDistance(), KNNStrategy.SIMPLE), + new KNNModelFormat(2, new ManhattanDistance(), KNNStrategy.SIMPLE)); + + test(new KNNClassificationModel(null).withK(1), new KNNClassificationModel(null).withK(2)); + + LogRegressionMultiClassModel mdl = new LogRegressionMultiClassModel(); + mdl.add(1, new LogisticRegressionModel(new DenseVector(), 1.0)); + test(mdl, new LogRegressionMultiClassModel()); + + test(new LinearRegressionModel(null, 1.0), new 
LinearRegressionModel(null, 0.5)); + + SVMLinearMultiClassClassificationModel mdl1 = new SVMLinearMultiClassClassificationModel(); + mdl1.add(1, new SVMLinearBinaryClassificationModel(new DenseVector(), 1.0)); + test(mdl1, new SVMLinearMultiClassClassificationModel()); + + test(new SVMLinearBinaryClassificationModel(null, 1.0), new SVMLinearBinaryClassificationModel(null, 0.5)); + } + + /** Test classes that have all instances equal (eg, metrics). */ + private <T> void specialTest(T o1, T o2) { + assertEquals(o1, o2); + + test(o1, new Object()); + } + + /** */ + private <T> void test(T o1, T o2) { + assertNotEquals(o1, null); + assertNotEquals(o2, null); + + assertEquals(o1, o1); + assertEquals(o2, o2); + + assertNotEquals(o1, o2); + + Set<T> set = new HashSet<>(); + set.add(o1); + set.add(o1); + assertEquals(1, set.size()); + + set.add(o2); + set.add(o2); + assertEquals(2, set.size()); + + set.remove(o1); + assertEquals(1, set.size()); + + set.remove(o2); + assertEquals(0, set.size()); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/common/CommonTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/common/CommonTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/common/CommonTestSuite.java new file mode 100644 index 0000000..5336bf8 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/common/CommonTestSuite.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.common; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; + +/** + * Test suite for all tests located in org.apache.ignite.ml.common package. + */ +@RunWith(Suite.class) +@Suite.SuiteClasses({ + CollectionsTest.class, + ExternalizeTest.class +}) +public class CommonTestSuite { +} http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/common/ExternalizeTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/common/ExternalizeTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/common/ExternalizeTest.java new file mode 100644 index 0000000..dc37ee8 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/common/ExternalizeTest.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.common; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import org.apache.ignite.ml.math.Destroyable; +import org.apache.ignite.ml.math.distances.EuclideanDistance; +import org.apache.ignite.ml.math.distances.HammingDistance; +import org.apache.ignite.ml.math.distances.ManhattanDistance; +import org.apache.ignite.ml.math.primitives.MathTestConstants; +import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.impl.DelegatingVector; +import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; +import org.apache.ignite.ml.math.primitives.vector.impl.VectorizedViewMatrix; +import org.apache.ignite.ml.structures.Dataset; +import org.apache.ignite.ml.structures.DatasetRow; +import org.apache.ignite.ml.structures.FeatureMetadata; +import org.apache.ignite.ml.structures.LabeledVector; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +/** + * Tests for externalizable classes. 
+ */ +public class ExternalizeTest { + /** */ + @Test + @SuppressWarnings("unchecked") + public void test() { + externalizeTest(new DelegatingVector(new DenseVector(1))); + + externalizeTest(new VectorizedViewMatrix(new DenseMatrix(2, 2), 1, 1, 1, 1)); + + externalizeTest(new ManhattanDistance()); + + externalizeTest(new HammingDistance()); + + externalizeTest(new EuclideanDistance()); + + externalizeTest(new FeatureMetadata()); + + externalizeTest(new VectorizedViewMatrix(new DenseMatrix(2, 2), 1, 1, 1, 1)); + + externalizeTest(new DatasetRow<>(new DenseVector())); + + externalizeTest(new LabeledVector<>(new DenseVector(), null)); + + externalizeTest(new Dataset<DatasetRow<Vector>>(new DatasetRow[] {}, new FeatureMetadata[] {})); + } + + /** */ + @SuppressWarnings("unchecked") + private <T> void externalizeTest(T initObj) { + T objRestored = null; + + try { + ByteArrayOutputStream byteArrOutputStream = new ByteArrayOutputStream(); + ObjectOutputStream objOutputStream = new ObjectOutputStream(byteArrOutputStream); + + objOutputStream.writeObject(initObj); + + ByteArrayInputStream byteArrInputStream = new ByteArrayInputStream(byteArrOutputStream.toByteArray()); + ObjectInputStream objInputStream = new ObjectInputStream(byteArrInputStream); + + objRestored = (T)objInputStream.readObject(); + + assertEquals(MathTestConstants.VAL_NOT_EQUALS, initObj, objRestored); + + assertEquals(MathTestConstants.VAL_NOT_EQUALS, 0, Integer.compare(initObj.hashCode(), objRestored.hashCode())); + } + catch (ClassNotFoundException | IOException e) { + fail(e + " [" + e.getMessage() + "]"); + } + finally { + if (objRestored instanceof Destroyable) + ((Destroyable)objRestored).destroy(); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java ---------------------------------------------------------------------- diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java new file mode 100644 index 0000000..8714eb2 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.composition; + +import org.apache.ignite.ml.composition.boosting.GDBTrainerTest; +import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregatorTest; +import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregatorTest; +import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregatorTest; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; + +/** + * Test suite for all tests located in org.apache.ignite.ml.composition package. 
+ */ +@RunWith(Suite.class) +@Suite.SuiteClasses({ + GDBTrainerTest.class, + MeanValuePredictionsAggregatorTest.class, + OnMajorityPredictionsAggregatorTest.class, + WeightedPredictionsAggregatorTest.class +}) +public class CompositionTestSuite { +} http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java index bef5e9b..53f7934 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java @@ -52,7 +52,9 @@ public class GDBTrainerTest { learningSample.put(i, new double[] {xs[i], ys[i]}); } - DatasetTrainer<Model<Vector, Double>, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 3, 0.0); + DatasetTrainer<Model<Vector, Double>, Double> trainer + = new GDBRegressionOnTreesTrainer(1.0, 2000, 3, 0.0).withUseIndex(true); + Model<Vector, Double> mdl = trainer.fit( learningSample, 1, (k, v) -> VectorUtils.of(v[0]), @@ -72,6 +74,10 @@ public class GDBTrainerTest { assertTrue(mdl instanceof ModelsComposition); ModelsComposition composition = (ModelsComposition)mdl; + assertTrue(composition.toString().length() > 0); + assertTrue(composition.toString(true).length() > 0); + assertTrue(composition.toString(false).length() > 0); + composition.getModels().forEach(m -> assertTrue(m instanceof DecisionTreeConditionalNode)); assertEquals(2000, composition.getModels().size()); @@ -94,7 +100,9 @@ public class GDBTrainerTest { for (int i = 0; i < sampleSize; i++) learningSample.put(i, new double[] {xs[i], ys[i]}); - DatasetTrainer<Model<Vector, Double>, Double> trainer = new
GDBBinaryClassifierOnTreesTrainer(0.3, 500, 3, 0.0); + DatasetTrainer<Model<Vector, Double>, Double> trainer + = new GDBBinaryClassifierOnTreesTrainer(0.3, 500, 3, 0.0).withUseIndex(true); + Model<Vector, Double> mdl = trainer.fit( learningSample, 1, (k, v) -> VectorUtils.of(v[0]), http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregatorTest.java index ae0b166..96e249f 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregatorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregatorTest.java @@ -20,6 +20,7 @@ package org.apache.ignite.ml.composition.predictionsaggregator; import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; /** */ public class WeightedPredictionsAggregatorTest { @@ -28,6 +29,10 @@ public class WeightedPredictionsAggregatorTest { public void testApply1() { WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(new double[] {}); assertEquals(0.0, aggregator.apply(new double[] {}), 0.001); + + assertTrue(aggregator.toString().length() > 0); + assertTrue(aggregator.toString(true).length() > 0); + assertTrue(aggregator.toString(false).length() > 0); } /** */ http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/dataset/DatasetTestSuite.java ---------------------------------------------------------------------- diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/DatasetTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/DatasetTestSuite.java index 3be79a4..52ba705 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/DatasetTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/DatasetTestSuite.java @@ -24,6 +24,8 @@ import org.apache.ignite.ml.dataset.impl.cache.util.DatasetAffinityFunctionWrapp import org.apache.ignite.ml.dataset.impl.cache.util.PartitionDataStorageTest; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilderTest; import org.apache.ignite.ml.dataset.primitive.DatasetWrapperTest; +import org.apache.ignite.ml.dataset.primitive.SimpleDatasetTest; +import org.apache.ignite.ml.dataset.primitive.SimpleLabeledDatasetTest; import org.junit.runner.RunWith; import org.junit.runners.Suite; @@ -38,7 +40,9 @@ import org.junit.runners.Suite; PartitionDataStorageTest.class, CacheBasedDatasetBuilderTest.class, CacheBasedDatasetTest.class, - LocalDatasetBuilderTest.class + LocalDatasetBuilderTest.class, + SimpleDatasetTest.class, + SimpleLabeledDatasetTest.class }) public class DatasetTestSuite { // No-op. 
http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java index 2e39e65..d96c935 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java @@ -63,7 +63,7 @@ public class CacheBasedDatasetTest extends GridCommonAbstractTest { } /** {@inheritDoc} */ - @Override protected void beforeTest() throws Exception { + @Override protected void beforeTest() { /* Grid instance. */ ignite = grid(NODE_COUNT); ignite.configuration().setPeerClassLoadingEnabled(true); @@ -87,6 +87,9 @@ public class CacheBasedDatasetTest extends GridCommonAbstractTest { (upstream, upstreamSize, ctx) -> new SimpleDatasetData(new double[0], 0) ); + assertEquals("Upstream cache name from dataset", + upstreamCache.getName(), dataset.getUpstreamCache().getName()); + assertTrue("Before computation all partitions should not be reserved", areAllPartitionsNotReserved(upstreamCache.getName(), dataset.getDatasetCache().getName())); http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilderTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilderTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilderTest.java index 8a5eb3a..3b12300 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilderTest.java +++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilderTest.java @@ -38,18 +38,10 @@ public class LocalDatasetBuilderTest { LocalDatasetBuilder<Integer, Integer> builder = new LocalDatasetBuilder<>(data, 10); - LocalDataset<Serializable, TestPartitionData> dataset = builder.build( - (upstream, upstreamSize) -> null, - (upstream, upstreamSize, ctx) -> { - int[] arr = new int[Math.toIntExact(upstreamSize)]; + LocalDataset<Serializable, TestPartitionData> dataset = buildDataset(builder); - int ptr = 0; - while (upstream.hasNext()) - arr[ptr++] = upstream.next().getValue(); - - return new TestPartitionData(arr); - } - ); + assertEquals(10, dataset.getCtx().size()); + assertEquals(10, dataset.getData().size()); AtomicLong cnt = new AtomicLong(); @@ -76,18 +68,7 @@ public class LocalDatasetBuilderTest { LocalDatasetBuilder<Integer, Integer> builder = new LocalDatasetBuilder<>(data, (k, v) -> k % 2 == 0,10); - LocalDataset<Serializable, TestPartitionData> dataset = builder.build( - (upstream, upstreamSize) -> null, - (upstream, upstreamSize, ctx) -> { - int[] arr = new int[Math.toIntExact(upstreamSize)]; - - int ptr = 0; - while (upstream.hasNext()) - arr[ptr++] = upstream.next().getValue(); - - return new TestPartitionData(arr); - } - ); + LocalDataset<Serializable, TestPartitionData> dataset = buildDataset(builder); AtomicLong cnt = new AtomicLong(); @@ -105,6 +86,23 @@ public class LocalDatasetBuilderTest { assertEquals(10, cnt.intValue()); } + /** */ + private LocalDataset<Serializable, TestPartitionData> buildDataset( + LocalDatasetBuilder<Integer, Integer> builder) { + return builder.build( + (upstream, upstreamSize) -> null, + (upstream, upstreamSize, ctx) -> { + int[] arr = new int[Math.toIntExact(upstreamSize)]; + + int ptr = 0; + while (upstream.hasNext()) + arr[ptr++] = upstream.next().getValue(); + + return new TestPartitionData(arr); + } + ); + } + /** * Test partition {@code data}. 
*/ @@ -122,7 +120,7 @@ public class LocalDatasetBuilderTest { } /** {@inheritDoc} */ - @Override public void close() throws Exception { + @Override public void close() { // Do nothing, GC will clean up. } } http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleDatasetTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleDatasetTest.java new file mode 100644 index 0000000..eaa03d2 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleDatasetTest.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.dataset.primitive; + +import java.util.HashMap; +import java.util.Map; +import org.apache.ignite.ml.dataset.DatasetFactory; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; + +/** + * Tests for {@link SimpleDataset}. 
+ */ +public class SimpleDatasetTest { + /** Basic test for SimpleDataset features. IMPL NOTE derived from LocalDatasetExample. */ + @Test + public void basicTest() throws Exception { + Map<Integer, DataPoint> dataPoints = new HashMap<Integer, DataPoint>() {{ + put(1, new DataPoint(42, 10000)); + put(2, new DataPoint(32, 64000)); + put(3, new DataPoint(53, 120000)); + put(4, new DataPoint(24, 70000)); + }}; + + // Creates a local simple dataset containing features and providing standard dataset API. + try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset( + dataPoints, + 2, + (k, v) -> VectorUtils.of(v.getAge(), v.getSalary()) + )) { + assertArrayEquals("Mean values.", new double[] {37.75, 66000.0}, dataset.mean(), 0); + + assertArrayEquals("Standard deviation values.", + new double[] {10.871407452579449, 38961.519477556314}, dataset.std(), 0); + + double[][] covExp = new double[][] { + new double[] {118.1875, 135500.0}, + new double[] {135500.0, 1.518E9} + }; + double[][] cov = dataset.cov(); + int rowCov = 0; + for (double[] row : cov) + assertArrayEquals("Covariance matrix row " + rowCov, + covExp[rowCov++], row, 0); + + + double[][] corrExp = new double[][] { + new double[] {1.0000000000000002, 0.31990250167874007}, + new double[] {0.31990250167874007, 1.0} + }; + double[][] corr = dataset.corr(); + int rowCorr = 0; + for (double[] row : corr) + assertArrayEquals("Correlation matrix row " + rowCorr, + corrExp[rowCorr++], row, 0); + } + } + + /** */ + private static class DataPoint { + /** Age. */ + private final double age; + + /** Salary. */ + private final double salary; + + /** + * Constructs a new instance of person. + * + * @param age Age. + * @param salary Salary. 
+ */ + DataPoint(double age, double salary) { + this.age = age; + this.salary = salary; + } + + /** */ + double getAge() { + return age; + } + + /** */ + double getSalary() { + return salary; + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleLabeledDatasetTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleLabeledDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleLabeledDatasetTest.java new file mode 100644 index 0000000..f7b0f13 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleLabeledDatasetTest.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.dataset.primitive; + +import java.util.HashMap; +import java.util.Map; +import org.apache.ignite.ml.dataset.DatasetFactory; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertNull; + +/** + * Tests for {@link SimpleLabeledDataset}. + */ +public class SimpleLabeledDatasetTest { + /** Basic test for SimpleLabeledDataset features. */ + @Test + public void basicTest() throws Exception { + Map<Integer, DataPoint> dataPoints = new HashMap<Integer, DataPoint>() {{ + put(5, new DataPoint(42, 10000)); + put(6, new DataPoint(32, 64000)); + put(7, new DataPoint(53, 120000)); + put(8, new DataPoint(24, 70000)); + }}; + + double[][] actualFeatures = new double[2][]; + double[][] actualLabels = new double[2][]; + int[] actualRows = new int[2]; + + // Creates a local simple dataset containing features and providing standard dataset API. + try (SimpleLabeledDataset<?> dataset = DatasetFactory.createSimpleLabeledDataset( + dataPoints, + 2, + (k, v) -> VectorUtils.of(v.getAge(), v.getSalary()), + (k, v) -> new double[] {k, v.getAge(), v.getSalary()} + )) { + assertNull(dataset.compute((data, partIdx) -> { + actualFeatures[partIdx] = data.getFeatures(); + actualLabels[partIdx] = data.getLabels(); + actualRows[partIdx] = data.getRows(); + return null; + }, (k, v) -> null)); + } + + double[][] expFeatures = new double[][] { + new double[] {42.0, 32.0, 10000.0, 64000.0}, + new double[] {53.0, 24.0, 120000.0, 70000.0} + }; + int rowFeat = 0; + for (double[] row : actualFeatures) + assertArrayEquals("Features partition index " + rowFeat, + expFeatures[rowFeat++], row, 0); + + double[][] expLabels = new double[][] { + new double[] {5.0, 6.0, 42.0, 32.0, 10000.0, 64000.0}, + new double[] {7.0, 8.0, 53.0, 24.0, 120000.0, 70000.0} + }; + int rowLbl = 0; + for (double[] row : actualLabels) + assertArrayEquals("Labels partition 
index " + rowLbl, + expLabels[rowLbl++], row, 0); + + assertArrayEquals("Rows per partitions", new int[] {2, 2}, actualRows); + } + + /** */ + private static class DataPoint { + /** Age. */ + private final double age; + + /** Salary. */ + private final double salary; + + /** + * Constructs a new instance of person. + * + * @param age Age. + * @param salary Salary. + */ + DataPoint(double age, double salary) { + this.age = age; + this.salary = salary; + } + + /** */ + double getAge() { + return age; + } + + /** */ + double getSalary() { + return salary; + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/environment/EnvironmentTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/environment/EnvironmentTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/environment/EnvironmentTestSuite.java new file mode 100644 index 0000000..527cc3c --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/environment/EnvironmentTestSuite.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.environment; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; + +/** + * Test suite for all tests located in org.apache.ignite.ml.trees package. + */ +@RunWith(Suite.class) [email protected]({ + LearningEnvironmentBuilderTest.class, + LearningEnvironmentTest.class +}) +public class EnvironmentTestSuite { +} http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilderTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilderTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilderTest.java new file mode 100644 index 0000000..56f262b --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilderTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.environment; + +import org.apache.ignite.logger.NullLogger; +import org.apache.ignite.ml.environment.logging.ConsoleLogger; +import org.apache.ignite.ml.environment.logging.CustomMLLogger; +import org.apache.ignite.ml.environment.logging.MLLogger; +import org.apache.ignite.ml.environment.logging.NoOpLogger; +import org.apache.ignite.ml.environment.parallelism.DefaultParallelismStrategy; +import org.apache.ignite.ml.environment.parallelism.NoParallelismStrategy; +import org.junit.Test; + +import static org.apache.ignite.ml.environment.parallelism.ParallelismStrategy.Type.NO_PARALLELISM; +import static org.apache.ignite.ml.environment.parallelism.ParallelismStrategy.Type.ON_DEFAULT_POOL; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + + +/** + * Tests for {@link LearningEnvironmentBuilder}. + */ +public class LearningEnvironmentBuilderTest { + /** */ + @Test + public void basic() { + LearningEnvironment env = LearningEnvironment.DEFAULT; + + assertNotNull("Strategy", env.parallelismStrategy()); + assertNotNull("Logger", env.logger()); + assertNotNull("Logger for class", env.logger(this.getClass())); + } + + /** */ + @Test + public void withParallelismStrategy() { + assertTrue(LearningEnvironment.builder().withParallelismStrategy(NoParallelismStrategy.INSTANCE).build() + .parallelismStrategy() instanceof NoParallelismStrategy); + + assertTrue(LearningEnvironment.builder().withParallelismStrategy(new DefaultParallelismStrategy()).build() + .parallelismStrategy() instanceof DefaultParallelismStrategy); + } + + /** */ + @Test + public void withParallelismStrategyType() { + assertTrue(LearningEnvironment.builder().withParallelismStrategy(NO_PARALLELISM).build() + .parallelismStrategy() instanceof NoParallelismStrategy); + + assertTrue(LearningEnvironment.builder().withParallelismStrategy(ON_DEFAULT_POOL).build() + .parallelismStrategy() instanceof DefaultParallelismStrategy); + } + + /** */ 
+ @Test + public void withLoggingFactory() { + assertTrue(LearningEnvironment.builder().withLoggingFactory(ConsoleLogger.factory(MLLogger.VerboseLevel.HIGH)) + .build().logger() instanceof ConsoleLogger); + + assertTrue(LearningEnvironment.builder().withLoggingFactory(ConsoleLogger.factory(MLLogger.VerboseLevel.HIGH)) + .build().logger(this.getClass()) instanceof ConsoleLogger); + + assertTrue(LearningEnvironment.builder().withLoggingFactory(NoOpLogger.factory()) + .build().logger() instanceof NoOpLogger); + + assertTrue(LearningEnvironment.builder().withLoggingFactory(NoOpLogger.factory()) + .build().logger(this.getClass()) instanceof NoOpLogger); + + assertTrue(LearningEnvironment.builder().withLoggingFactory(CustomMLLogger.factory(new NullLogger())) + .build().logger() instanceof CustomMLLogger); + + assertTrue(LearningEnvironment.builder().withLoggingFactory(CustomMLLogger.factory(new NullLogger())) + .build().logger(this.getClass()) instanceof CustomMLLogger); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java new file mode 100644 index 0000000..9f8bab7 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.environment; + +import java.util.Arrays; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicReference; +import javax.cache.Cache; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.cache.query.QueryCursor; +import org.apache.ignite.cache.query.ScanQuery; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.composition.ModelsComposition; +import org.apache.ignite.ml.environment.logging.ConsoleLogger; +import org.apache.ignite.ml.environment.logging.MLLogger; +import org.apache.ignite.ml.environment.parallelism.ParallelismStrategy; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.tree.randomforest.RandomForestRegressionTrainer; +import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; +import org.apache.ignite.thread.IgniteThread; + +/** + * Tests for {@link LearningEnvironment} that require to start the whole Ignite infrastructure. IMPL NOTE based on + * RandomForestRegressionExample example. + */ +public class LearningEnvironmentTest extends GridCommonAbstractTest { + /** Number of nodes in grid */ + private static final int NODE_COUNT = 1; + + /** Ignite instance. 
*/ + private Ignite ignite; + + /** {@inheritDoc} */ + @Override protected void beforeTestsStarted() throws Exception { + for (int i = 1; i <= NODE_COUNT; i++) + startGrid(i); + } + + /** {@inheritDoc} */ + @Override protected void afterTestsStopped() { + stopAllGrids(); + } + + /** + * {@inheritDoc} + */ + @Override protected void beforeTest() { + /* Grid instance. */ + ignite = grid(NODE_COUNT); + ignite.configuration().setPeerClassLoadingEnabled(true); + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + } + + /** */ + public void testBasic() throws InterruptedException { + AtomicReference<Integer> actualAmount = new AtomicReference<>(null); + AtomicReference<Double> actualMse = new AtomicReference<>(null); + AtomicReference<Double> actualMae = new AtomicReference<>(null); + + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + LearningEnvironmentTest.class.getSimpleName(), () -> { + IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); + + RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(13, 4, 101, 0.3, 2, 0); + + trainer.setEnvironment(LearningEnvironment.builder() + .withParallelismStrategy(ParallelismStrategy.Type.ON_DEFAULT_POOL) + .withLoggingFactory(ConsoleLogger.factory(MLLogger.VerboseLevel.LOW)) + .build() + ); + + ModelsComposition randomForest = trainer.fit(ignite, dataCache, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[v.length - 1] + ); + + double mse = 0.0; + double mae = 0.0; + int totalAmount = 0; + + try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, double[]> observation : observations) { + double difference = estimatePrediction(randomForest, observation); + + mse += Math.pow(difference, 2.0); + mae += Math.abs(difference); + + totalAmount++; + } + } + + actualAmount.set(totalAmount); + + mse = mse / totalAmount; + 
actualMse.set(mse); + + mae = mae / totalAmount; + actualMae.set(mae); + }); + + igniteThread.start(); + igniteThread.join(); + + assertEquals("Total amount", 23, (int)actualAmount.get()); + assertTrue("Mean squared error (MSE)", actualMse.get() > 0); + assertTrue("Mean absolute error (MAE)", actualMae.get() > 0); + } + + /** */ + private double estimatePrediction(ModelsComposition randomForest, Cache.Entry<Integer, double[]> observation) { + double[] val = observation.getValue(); + double[] inputs = Arrays.copyOfRange(val, 0, val.length - 1); + double groundTruth = val[val.length - 1]; + + double prediction = randomForest.apply(VectorUtils.of(inputs)); + + return prediction - groundTruth; + } + + /** + * Fills cache with data and returns it. + * + * @param ignite Ignite instance. + * @return Filled Ignite Cache. + */ + private IgniteCache<Integer, double[]> getTestCache(Ignite ignite) { + CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>(); + cacheConfiguration.setName(UUID.randomUUID().toString()); + cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10)); + + IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration); + + for (int i = 0; i < data.length; i++) + cache.put(i, data[i]); + + return cache; + } + + /** + * Part of the Boston housing dataset. 
+ */ + private static final double[][] data = { + {0.02731,0.00,7.070,0,0.4690,6.4210,78.90,4.9671,2,242.0,17.80,396.90,9.14,21.60}, + {0.02729,0.00,7.070,0,0.4690,7.1850,61.10,4.9671,2,242.0,17.80,392.83,4.03,34.70}, + {0.03237,0.00,2.180,0,0.4580,6.9980,45.80,6.0622,3,222.0,18.70,394.63,2.94,33.40}, + {0.06905,0.00,2.180,0,0.4580,7.1470,54.20,6.0622,3,222.0,18.70,396.90,5.33,36.20}, + {0.02985,0.00,2.180,0,0.4580,6.4300,58.70,6.0622,3,222.0,18.70,394.12,5.21,28.70}, + {0.08829,12.50,7.870,0,0.5240,6.0120,66.60,5.5605,5,311.0,15.20,395.60,12.43,22.90}, + {0.14455,12.50,7.870,0,0.5240,6.1720,96.10,5.9505,5,311.0,15.20,396.90,19.15,27.10}, + {0.21124,12.50,7.870,0,0.5240,5.6310,100.00,6.0821,5,311.0,15.20,386.63,29.93,16.50}, + {0.17004,12.50,7.870,0,0.5240,6.0040,85.90,6.5921,5,311.0,15.20,386.71,17.10,18.90}, + {0.22489,12.50,7.870,0,0.5240,6.3770,94.30,6.3467,5,311.0,15.20,392.52,20.45,15.00}, + {0.11747,12.50,7.870,0,0.5240,6.0090,82.90,6.2267,5,311.0,15.20,396.90,13.27,18.90}, + {0.09378,12.50,7.870,0,0.5240,5.8890,39.00,5.4509,5,311.0,15.20,390.50,15.71,21.70}, + {0.62976,0.00,8.140,0,0.5380,5.9490,61.80,4.7075,4,307.0,21.00,396.90,8.26,20.40}, + {0.63796,0.00,8.140,0,0.5380,6.0960,84.50,4.4619,4,307.0,21.00,380.02,10.26,18.20}, + {0.62739,0.00,8.140,0,0.5380,5.8340,56.50,4.4986,4,307.0,21.00,395.62,8.47,19.90}, + {1.05393,0.00,8.140,0,0.5380,5.9350,29.30,4.4986,4,307.0,21.00,386.85,6.58,23.10}, + {0.78420,0.00,8.140,0,0.5380,5.9900,81.70,4.2579,4,307.0,21.00,386.75,14.67,17.50}, + {0.80271,0.00,8.140,0,0.5380,5.4560,36.60,3.7965,4,307.0,21.00,288.99,11.69,20.20}, + {0.72580,0.00,8.140,0,0.5380,5.7270,69.50,3.7965,4,307.0,21.00,390.95,11.28,18.20}, + {1.25179,0.00,8.140,0,0.5380,5.5700,98.10,3.7979,4,307.0,21.00,376.57,21.02,13.60}, + {0.85204,0.00,8.140,0,0.5380,5.9650,89.20,4.0123,4,307.0,21.00,392.53,13.83,19.60}, + {1.23247,0.00,8.140,0,0.5380,6.1420,91.70,3.9769,4,307.0,21.00,396.90,18.72,15.20}, + 
{0.98843,0.00,8.140,0,0.5380,5.8130,100.00,4.0952,4,307.0,21.00,394.54,19.88,14.50} + }; + +} + http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java index ab1ecee..aeb2414 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java @@ -34,8 +34,9 @@ import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import static junit.framework.TestCase.assertEquals; +import static junit.framework.TestCase.assertTrue; -/** Tests behaviour of KNNClassificationTest. */ +/** Tests behaviour of KNNClassification. */ @RunWith(Parameterized.class) public class KNNClassificationTest { /** Number of parts to be tested. 
*/ @@ -58,7 +59,7 @@ public class KNNClassificationTest { /** */ @Test - public void testBinaryClassificationTest() { + public void testBinaryClassification() { Map<Integer, double[]> data = new HashMap<>(); data.put(0, new double[] {1.0, 1.0, 1.0}); data.put(1, new double[] {1.0, 2.0, 1.0}); @@ -78,6 +79,10 @@ public class KNNClassificationTest { .withDistanceMeasure(new EuclideanDistance()) .withStrategy(KNNStrategy.SIMPLE); + assertTrue(knnMdl.toString().length() > 0); + assertTrue(knnMdl.toString(true).length() > 0); + assertTrue(knnMdl.toString(false).length() > 0); + Vector firstVector = new DenseVector(new double[] {2.0, 2.0}); assertEquals(knnMdl.apply(firstVector), 1.0); Vector secondVector = new DenseVector(new double[] {-2.0, -2.0}); @@ -86,7 +91,7 @@ public class KNNClassificationTest { /** */ @Test - public void testBinaryClassificationWithSmallestKTest() { + public void testBinaryClassificationWithSmallestK() { Map<Integer, double[]> data = new HashMap<>(); data.put(0, new double[] {1.0, 1.0, 1.0}); data.put(1, new double[] {1.0, 2.0, 1.0}); http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java index 586e6c8..7d57ec9 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java @@ -87,41 +87,17 @@ public class KNNRegressionTest { /** */ @Test public void testLongly() { - Map<Integer, double[]> data = new HashMap<>(); - data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608, 1947}); - data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632, 1948}); - data.put(2, new double[] {60171, 88.2, 258054, 3682, 1616, 109773, 1949}); - 
data.put(3, new double[] {61187, 89.5, 284599, 3351, 1650, 110929, 1950}); - data.put(4, new double[] {63221, 96.2, 328975, 2099, 3099, 112075, 1951}); - data.put(5, new double[] {63639, 98.1, 346999, 1932, 3594, 113270, 1952}); - data.put(6, new double[] {64989, 99.0, 365385, 1870, 3547, 115094, 1953}); - data.put(7, new double[] {63761, 100.0, 363112, 3578, 3350, 116219, 1954}); - data.put(8, new double[] {66019, 101.2, 397469, 2904, 3048, 117388, 1955}); - data.put(9, new double[] {68169, 108.4, 442769, 2936, 2798, 120445, 1957}); - data.put(10, new double[] {66513, 110.8, 444546, 4681, 2637, 121950, 1958}); - data.put(11, new double[] {68655, 112.6, 482704, 3813, 2552, 123366, 1959}); - data.put(12, new double[] {69564, 114.2, 502601, 3931, 2514, 125368, 1960}); - data.put(13, new double[] {69331, 115.7, 518173, 4806, 2572, 127852, 1961}); - data.put(14, new double[] {70551, 116.9, 554894, 4007, 2827, 130081, 1962}); - - KNNRegressionTrainer trainer = new KNNRegressionTrainer(); - - KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit( - new LocalDatasetBuilder<>(data, parts), - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ).withK(3) - .withDistanceMeasure(new EuclideanDistance()) - .withStrategy(KNNStrategy.SIMPLE); - - Vector vector = new DenseVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956}); - System.out.println(knnMdl.apply(vector)); - Assert.assertEquals(67857, knnMdl.apply(vector), 2000); + testLongly(KNNStrategy.SIMPLE); } /** */ @Test public void testLonglyWithWeightedStrategy() { + testLongly(KNNStrategy.WEIGHTED); + } + + /** */ + private void testLongly(KNNStrategy stgy) { Map<Integer, double[]> data = new HashMap<>(); data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608, 1947}); data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632, 1948}); @@ -147,10 +123,16 @@ public class KNNRegressionTest { (k, v) -> v[0] ).withK(3) .withDistanceMeasure(new EuclideanDistance()) - 
.withStrategy(KNNStrategy.SIMPLE); + .withStrategy(stgy); Vector vector = new DenseVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956}); - System.out.println(knnMdl.apply(vector)); + + Assert.assertNotNull(knnMdl.apply(vector)); + Assert.assertEquals(67857, knnMdl.apply(vector), 2000); + + Assert.assertTrue(knnMdl.toString().contains(stgy.name())); + Assert.assertTrue(knnMdl.toString(true).contains(stgy.name())); + Assert.assertTrue(knnMdl.toString(false).contains(stgy.name())); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java index a029e49..9867fbe 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java @@ -21,11 +21,13 @@ import java.io.IOException; import java.net.URISyntaxException; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Objects; import org.apache.ignite.ml.math.ExternalizableTest; import org.apache.ignite.ml.math.exceptions.CardinalityException; import org.apache.ignite.ml.math.exceptions.NoDataException; import org.apache.ignite.ml.math.exceptions.knn.EmptyFileException; import org.apache.ignite.ml.math.exceptions.knn.FileParsingException; +import org.apache.ignite.ml.math.exceptions.knn.NoLabelVectorException; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.structures.LabeledDataset; import org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair; @@ -36,7 +38,7 @@ import org.junit.Test; import static junit.framework.TestCase.assertEquals; import static junit.framework.TestCase.fail; -/** Tests behaviour of KNNClassificationTest. 
*/ +/** Tests behaviour of LabeledDataset. */ public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> { /** */ private static final String KNN_IRIS_TXT = "datasets/knn/iris.txt"; @@ -90,12 +92,22 @@ public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> { assertEquals(dataset.colSize(), 2); assertEquals(dataset.rowSize(), 6); + assertEquals(dataset.label(0), lbs[0], 0); + + assertEquals(dataset.copy().colSize(), 2); + + @SuppressWarnings("unchecked") final LabeledVector<Vector, Double> row = (LabeledVector<Vector, Double>)dataset.getRow(0); assertEquals(row.features().get(0), 1.0); assertEquals(row.label(), 1.0); dataset.setLabel(0, 2.0); assertEquals(row.label(), 2.0); + + assertEquals(0, new LabeledDataset().rowSize()); + assertEquals(1, new LabeledDataset(1, 2).rowSize()); + assertEquals(1, new LabeledDataset(1, 2, true).rowSize()); + assertEquals(1, new LabeledDataset(1, 2, null, true).rowSize()); } /** */ @@ -142,7 +154,7 @@ public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> { @Test public void testLoadingCorrectTxtFile() { LabeledDataset training = LabeledDatasetHelper.loadDatasetFromTxt(KNN_IRIS_TXT, false); - assertEquals(training.rowSize(), 150); + assertEquals(Objects.requireNonNull(training).rowSize(), 150); } /** */ @@ -175,7 +187,7 @@ public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> { @Test public void testLoadingFileWithIncorrectData() { LabeledDataset training = LabeledDatasetHelper.loadDatasetFromTxt(IRIS_INCORRECT_TXT, false); - assertEquals(149, training.rowSize()); + assertEquals(149, Objects.requireNonNull(training).rowSize()); } /** */ @@ -195,7 +207,7 @@ public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> { /** */ @Test public void testLoadingFileWithMissedData() throws URISyntaxException, IOException { - Path path = Paths.get(this.getClass().getClassLoader().getResource(IRIS_MISSED_DATA).toURI()); + Path path = 
Paths.get(Objects.requireNonNull(this.getClass().getClassLoader().getResource(IRIS_MISSED_DATA)).toURI()); LabeledDataset training = LabeledDatasetLoader.loadFromTxtFile(path, ",", false, false); @@ -258,6 +270,13 @@ public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> { } /** */ + @Test(expected = NoLabelVectorException.class) + @SuppressWarnings("unchecked") + public void testSetLabelInvalid() { + new LabeledDataset(new LabeledVector[1]).setLabel(0, 2.0); + } + + /** */ @Override public void testExternalization() { double[][] mtx = new double[][] { http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/math/BlasTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/BlasTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/BlasTest.java index 61bde69..3bd7240 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/math/BlasTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/BlasTest.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.math; import java.util.Arrays; import java.util.function.BiPredicate; +import org.apache.ignite.ml.math.exceptions.NonSquareMatrixException; import org.apache.ignite.ml.math.primitives.matrix.Matrix; import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix; import org.apache.ignite.ml.math.primitives.matrix.impl.SparseMatrix; @@ -282,6 +283,15 @@ public class BlasTest { Assert.assertEquals(exp, y); } + /** Tests 'syr' operation for non-square dense matrix A. */ + @Test(expected = NonSquareMatrixException.class) + public void testSyrNonSquareMatrix() { + double alpha = 3.0; + DenseMatrix a = new DenseMatrix(new double[][] {{10.0, 11.0, 12.0}, {0.0, 1.0, 2.0}}, 2); + Vector x = new DenseVector(new double[] {1.0, 2.0}); + new Blas().syr(alpha, x, a); + } + /** * Create a sparse vector from array. 
* http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/DistanceTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/DistanceTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/DistanceTest.java index ed2ca11..4cfb092 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/DistanceTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/DistanceTest.java @@ -35,26 +35,31 @@ public class DistanceTest { private Vector v2; /** */ + private double[] data2; + + /** */ @Before public void setup() { + data2 = new double[] {2.0, 1.0, 0.0}; v1 = new DenseVector(new double[] {0.0, 0.0, 0.0}); - v2 = new DenseVector(new double[] {2.0, 1.0, 0.0}); + v2 = new DenseVector(data2); } /** */ @Test - public void euclideanDistance() throws Exception { - + public void euclideanDistance() { double expRes = Math.pow(5, 0.5); DistanceMeasure distanceMeasure = new EuclideanDistance(); Assert.assertEquals(expRes, distanceMeasure.compute(v1, v2), PRECISION); + + Assert.assertEquals(expRes, new EuclideanDistance().compute(v1, data2), PRECISION); } /** */ @Test - public void manhattanDistance() throws Exception { + public void manhattanDistance() { double expRes = 3; DistanceMeasure distanceMeasure = new ManhattanDistance(); @@ -64,7 +69,7 @@ public class DistanceTest { /** */ @Test - public void hammingDistance() throws Exception { + public void hammingDistance() { double expRes = 2; DistanceMeasure distanceMeasure = new HammingDistance(); @@ -72,4 +77,15 @@ public class DistanceTest { Assert.assertEquals(expRes, distanceMeasure.compute(v1, v2), PRECISION); } + /** */ + @Test(expected = UnsupportedOperationException.class) + public void manhattanDistance2() { + new ManhattanDistance().compute(v1, data2); + } + + /** */ + @Test(expected = 
UnsupportedOperationException.class) + public void hammingDistance2() { + new HammingDistance().compute(v1, data2); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java index 5d1dac3..6af03df 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java @@ -29,6 +29,8 @@ import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; /** * Tests for {@link LSQROnHeap}. @@ -73,7 +75,17 @@ public class LSQROnHeapTest { LSQRResult res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, null); + assertEquals(3, res.getIterations()); + assertEquals(1, res.getIsstop()); + assertEquals(7.240617907140957E-14, res.getR1norm(), 0.0001); + assertEquals(7.240617907140957E-14, res.getR2norm(), 0.0001); + assertEquals(6.344288770224759, res.getAnorm(), 0.0001); + assertEquals(40.540617492419464, res.getAcond(), 0.0001); + assertEquals(3.4072322214704627E-13, res.getArnorm(), 0.0001); + assertEquals(3.000000000000001, res.getXnorm(), 0.0001); + assertArrayEquals(new double[]{0.0, 0.0, 0.0}, res.getVar(), 1e-6); assertArrayEquals(new double[]{1, -2, -2}, res.getX(), 1e-6); + assertTrue(res.toString().length() > 0); } /** Tests solving simple linear system with specified x0. 
*/ @@ -97,6 +109,8 @@ public class LSQROnHeapTest { LSQRResult res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, new double[] {999, 999, 999}); + assertEquals(3, res.getIterations()); + assertArrayEquals(new double[]{1, -2, -2}, res.getX(), 1e-6); } @@ -126,6 +140,8 @@ public class LSQROnHeapTest { )) { LSQRResult res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, null); + assertEquals(8, res.getIterations()); + assertArrayEquals(new double[]{72.26948107, 15.95144674, 24.07403921, 66.73038781}, res.getX(), 1e-6); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/math/primitives/vector/DelegatingVectorConstructorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/primitives/vector/DelegatingVectorConstructorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/primitives/vector/DelegatingVectorConstructorTest.java index 6b44c38..fbe6db8 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/math/primitives/vector/DelegatingVectorConstructorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/primitives/vector/DelegatingVectorConstructorTest.java @@ -33,12 +33,14 @@ public class DelegatingVectorConstructorTest { public void basicTest() { final Vector parent = new DenseVector(new double[] {0, 1}); - final Vector delegate = new DelegatingVector(parent); + final DelegatingVector delegate = new DelegatingVector(parent); final int size = parent.size(); assertEquals("Delegate size differs from expected.", size, delegate.size()); + assertEquals("Delegate vector differs from expected.", parent, delegate.getVector()); + for (int idx = 0; idx < size; idx++) assertDelegate(parent, delegate, idx); } http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/math/primitives/vector/SparseVectorConstructorTest.java 
---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/primitives/vector/SparseVectorConstructorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/primitives/vector/SparseVectorConstructorTest.java index b53a952..1a6956f 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/math/primitives/vector/SparseVectorConstructorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/primitives/vector/SparseVectorConstructorTest.java @@ -17,11 +17,15 @@ package org.apache.ignite.ml.math.primitives.vector; +import java.util.HashMap; +import java.util.Map; import org.apache.ignite.ml.math.StorageConstants; import org.apache.ignite.ml.math.primitives.vector.impl.SparseVector; import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; /** */ public class SparseVectorConstructorTest { @@ -52,4 +56,21 @@ public class SparseVectorConstructorTest { new SparseVector(1, StorageConstants.SEQUENTIAL_ACCESS_MODE).size()); } + + /** */ + @Test + public void noParamsCtorTest() { + assertNotNull(new SparseVector().nonZeroSpliterator()); + } + + /** */ + @Test + public void mapCtorTest() { + Map<Integer, Double> map = new HashMap<Integer, Double>() {{ + put(1, 1.); + }}; + + assertTrue("Copy true", new SparseVector(map, true).isRandomAccess()); + assertTrue("Copy false", new SparseVector(map, false).isRandomAccess()); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/nn/LossFunctionsTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/LossFunctionsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/LossFunctionsTest.java new file mode 100644 index 0000000..bef05ec --- /dev/null +++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/nn/LossFunctionsTest.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn; + +import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; +import org.apache.ignite.ml.optimization.LossFunctions; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertNotNull; + +/** + * Tests for {@link LossFunctions}. 
+ */ +public class LossFunctionsTest { + /** */ + @Test + public void testMSE() { + IgniteDifferentiableVectorToDoubleFunction f = LossFunctions.MSE.apply(new DenseVector(new double[] {2.0, 1.0})); + + assertNotNull(f); + + test(new double[] {1.0, 3.0}, f); + } + + /** */ + @Test + public void testLOG() { + IgniteDifferentiableVectorToDoubleFunction f = LossFunctions.LOG.apply(new DenseVector(new double[] {2.0, 1.0})); + + assertNotNull(f); + + test(new double[] {1.0, 3.0}, f); + } + + /** */ + @Test + public void testL2() { + IgniteDifferentiableVectorToDoubleFunction f = LossFunctions.L2.apply(new DenseVector(new double[] {2.0, 1.0})); + + assertNotNull(f); + + test(new double[] {1.0, 3.0}, f); + } + + /** */ + @Test + public void testL1() { + IgniteDifferentiableVectorToDoubleFunction f = LossFunctions.L1.apply(new DenseVector(new double[] {2.0, 1.0})); + + assertNotNull(f); + + test(new double[] {1.0, 3.0}, f); + } + + /** */ + @Test + public void testHINGE() { + IgniteDifferentiableVectorToDoubleFunction f = LossFunctions.HINGE.apply(new DenseVector(new double[] {2.0, 1.0})); + + assertNotNull(f); + + test(new double[] {1.0, 3.0}, f); + } + + /** */ + private void test(double[] expData, IgniteDifferentiableVectorToDoubleFunction f) { + verify(expData, f.differential(new DenseVector(new double[] {3.0, 4.0}))); + } + + /** */ + private void verify(double[] expData, Vector actual) { + assertArrayEquals(expData, new DenseVector(actual.size()).assign(actual).getStorage().data(), 0); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java index 51620b7..b0c0c95 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java +++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java @@ -58,13 +58,13 @@ public class MLPTest { withAddedLayer(2, true, Activators.SIGMOID). withAddedLayer(1, true, Activators.SIGMOID); - MultilayerPerceptron mlp = new MultilayerPerceptron(conf, new MLPConstInitializer(1, 2)); + MultilayerPerceptron mlp1 = new MultilayerPerceptron(conf, new MLPConstInitializer(1, 2)); - mlp.setWeights(1, new DenseMatrix(new double[][] {{20.0, 20.0}, {-20.0, -20.0}})); - mlp.setBiases(1, new DenseVector(new double[] {-10.0, 30.0})); + mlp1.setWeights(1, new DenseMatrix(new double[][] {{20.0, 20.0}, {-20.0, -20.0}})); + mlp1.setBiases(1, new DenseVector(new double[] {-10.0, 30.0})); - mlp.setWeights(2, new DenseMatrix(new double[][] {{20.0, 20.0}})); - mlp.setBiases(2, new DenseVector(new double[] {-30.0})); + MultilayerPerceptron mlp2 = mlp1.setWeights(2, new DenseMatrix(new double[][] {{20.0, 20.0}})); + MultilayerPerceptron mlp = mlp2.setBiases(2, new DenseVector(new double[] {-30.0})); Matrix input = new DenseMatrix(new double[][] {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}}); @@ -106,6 +106,39 @@ public class MLPTest { } /** + * Test three layer MLP. + */ + @Test + public void testStackedTwiceMLP() { + int firstLayerNeuronsCnt = 3; + int secondLayerNeuronsCnt = 2; + int thirdLayerNeuronsCnt = 4; + MLPConstInitializer initer = new MLPConstInitializer(1, 2); + + MLPArchitecture mlpLayer1Conf = new MLPArchitecture(4). + withAddedLayer(firstLayerNeuronsCnt, true, Activators.SIGMOID); + MLPArchitecture mlpLayer2Conf = new MLPArchitecture(firstLayerNeuronsCnt). + withAddedLayer(secondLayerNeuronsCnt, false, Activators.SIGMOID); + MLPArchitecture mlpLayer3Conf = new MLPArchitecture(secondLayerNeuronsCnt). 
+ withAddedLayer(thirdLayerNeuronsCnt, false, Activators.SIGMOID); + + MultilayerPerceptron mlp1 = new MultilayerPerceptron(mlpLayer1Conf, initer); + MultilayerPerceptron mlp2 = new MultilayerPerceptron(mlpLayer2Conf, initer); + MultilayerPerceptron mlp3 = new MultilayerPerceptron(mlpLayer3Conf, initer); + + Assert.assertEquals(1., mlp1.weight(1, 0, 1), 0); + + MultilayerPerceptron stackedMLP = mlp1.add(mlp2).add(mlp3); + + Assert.assertTrue(stackedMLP.toString().length() > 0); + Assert.assertTrue(stackedMLP.toString(true).length() > 0); + Assert.assertTrue(stackedMLP.toString(false).length() > 0); + + Assert.assertEquals(4, stackedMLP.architecture().outputSize()); + Assert.assertEquals(8, stackedMLP.architecture().layersCount()); + } + + /** * Test parameters count works well. */ @Test @@ -169,10 +202,10 @@ public class MLPTest { MLPArchitecture conf = new MLPArchitecture(inputSize). withAddedLayer(firstLayerNeuronsCnt, false, Activators.SIGMOID); - MultilayerPerceptron mlp = new MultilayerPerceptron(conf); + MultilayerPerceptron mlp1 = new MultilayerPerceptron(conf); - mlp.setWeight(1, 0, 0, w10); - mlp.setWeight(1, 1, 0, w11); + mlp1.setWeight(1, 0, 0, w10); + MultilayerPerceptron mlp = mlp1.setWeight(1, 1, 0, w11); double x0 = 1.0; double x1 = 3.0; @@ -205,4 +238,26 @@ public class MLPTest { Assert.assertEquals(mlp.architecture().parametersCount(), grad.size()); Assert.assertEquals(trueGrad, grad); } + + /** + * Test methods related to per-neuron bias. + */ + @Test + public void testNeuronBias() { + int inputSize = 3; + int firstLayerNeuronsCnt = 2; + int secondLayerNeurons = 1; + + MLPArchitecture conf = new MLPArchitecture(inputSize). + withAddedLayer(firstLayerNeuronsCnt, false, Activators.SIGMOID). 
+ withAddedLayer(secondLayerNeurons, true, Activators.SIGMOID); + + MultilayerPerceptron mlp1 = new MultilayerPerceptron(conf, new MLPConstInitializer(100, 200)); + + MultilayerPerceptron mlp = mlp1.setBias(2, 0, 1.); + Assert.assertEquals(1., mlp.bias(2, 0), 0); + + mlp.setBias(2, 0, 0.5); + Assert.assertEquals(0.5, mlp.bias(2, 0), 0); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTestSuite.java index 2e41813..3f98ba5 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTestSuite.java @@ -27,7 +27,8 @@ import org.junit.runners.Suite; @Suite.SuiteClasses({ MLPTest.class, MLPTrainerTest.class, - MLPTrainerIntegrationTest.class + MLPTrainerIntegrationTest.class, + LossFunctionsTest.class }) public class MLPTestSuite { // No-op.
