IGNITE-8233: KNN and SVM algorithms don't work when a partition doesn't contain data.
this closes #3807 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/ee9ca06a Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/ee9ca06a Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/ee9ca06a Branch: refs/heads/ignite-7708 Commit: ee9ca06a8cbec6eec2b963fb1db1b3f383fc1837 Parents: 568c3e7 Author: dmitrievanthony <dmitrievanth...@gmail.com> Authored: Fri Apr 13 18:02:37 2018 +0300 Committer: Yury Babak <yba...@gridgain.com> Committed: Fri Apr 13 18:02:37 2018 +0300 ---------------------------------------------------------------------- .../dataset/impl/cache/CacheBasedDataset.java | 14 ++- .../dataset/impl/cache/util/ComputeUtils.java | 9 +- .../ml/dataset/impl/local/LocalDataset.java | 16 ++- .../dataset/impl/local/LocalDatasetBuilder.java | 8 +- .../classification/KNNClassificationModel.java | 32 +++-- .../impl/cache/CacheBasedDatasetTest.java | 9 +- .../ignite/ml/knn/KNNClassificationTest.java | 120 ++++++++++-------- .../apache/ignite/ml/knn/KNNRegressionTest.java | 122 +++++++++++-------- .../ignite/ml/knn/LabeledDatasetHelper.java | 45 ++----- .../ignite/ml/knn/LabeledDatasetTest.java | 54 ++++---- 10 files changed, 226 insertions(+), 203 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java index 463d496..7428faf 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java @@ -101,12 +101,16 @@ public class 
CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose partDataBuilder ); - R res = map.apply(ctx, data, part); + if (data != null) { + R res = map.apply(ctx, data, part); - // Saves partition context after update. - ComputeUtils.saveContext(Ignition.localIgnite(), datasetCacheName, part, ctx); + // Saves partition context after update. + ComputeUtils.saveContext(Ignition.localIgnite(), datasetCacheName, part, ctx); - return res; + return res; + } + + return null; }, reduce, identity); } @@ -125,7 +129,7 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose partDataBuilder ); - return map.apply(data, part); + return data != null ? map.apply(data, part) : null; }, reduce, identity); } http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java index 0785db2..ce2fcfd 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java @@ -163,9 +163,14 @@ public class ComputeUtils { qry.setPartition(part); long cnt = upstreamCache.localSizeLong(part); - try (QueryCursor<Cache.Entry<K, V>> cursor = upstreamCache.query(qry)) { - return partDataBuilder.build(new UpstreamCursorAdapter<>(cursor.iterator(), cnt), cnt, ctx); + + if (cnt > 0) { + try (QueryCursor<Cache.Entry<K, V>> cursor = upstreamCache.query(qry)) { + return partDataBuilder.build(new UpstreamCursorAdapter<>(cursor.iterator(), cnt), cnt, ctx); + } } + + return null; }); } 
http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java index c08b7de..e312b20 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java @@ -55,8 +55,12 @@ public class LocalDataset<C extends Serializable, D extends AutoCloseable> imple R identity) { R res = identity; - for (int part = 0; part < ctx.size(); part++) - res = reduce.apply(res, map.apply(ctx.get(part), data.get(part), part)); + for (int part = 0; part < ctx.size(); part++) { + D partData = data.get(part); + + if (partData != null) + res = reduce.apply(res, map.apply(ctx.get(part), partData, part)); + } return res; } @@ -65,8 +69,12 @@ public class LocalDataset<C extends Serializable, D extends AutoCloseable> imple @Override public <R> R compute(IgniteBiFunction<D, Integer, R> map, IgniteBinaryOperator<R> reduce, R identity) { R res = identity; - for (int part = 0; part < data.size(); part++) - res = reduce.apply(res, map.apply(data.get(part), part)); + for (int part = 0; part < data.size(); part++) { + D partData = data.get(part); + + if (partData != null) + res = reduce.apply(res, map.apply(partData, part)); + } return res; } http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java index 
0dc1ed6..cfc1801 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java @@ -69,16 +69,16 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> { for (int part = 0; part < partitions; part++) { int cnt = part == partitions - 1 ? upstreamMap.size() - ptr : Math.min(partSize, upstreamMap.size() - ptr); - C ctx = partCtxBuilder.build( + C ctx = cnt > 0 ? partCtxBuilder.build( new IteratorWindow<>(firstKeysIter, k -> new UpstreamEntry<>(k, upstreamMap.get(k)), cnt), cnt - ); + ) : null; - D data = partDataBuilder.build( + D data = cnt > 0 ? partDataBuilder.build( new IteratorWindow<>(secondKeysIter, k -> new UpstreamEntry<>(k, upstreamMap.get(k)), cnt), cnt, ctx - ); + ) : null; ctxList.add(ctx); dataList.add(data); http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java index 693b81d..0f0cc9f 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java @@ -151,19 +151,29 @@ public class KNNClassificationModel<K, V> implements Model<Vector, Double>, Expo */ @NotNull private LabeledVector[] getKClosestVectors(LabeledDataset<Double, LabeledVector> trainingData, TreeMap<Double, Set<Integer>> distanceIdxPairs) { - LabeledVector[] res = new LabeledVector[k]; - int i = 0; - final Iterator<Double> iter = distanceIdxPairs.keySet().iterator(); - while (i < k) { - double key = iter.next(); - Set<Integer> idxs = 
distanceIdxPairs.get(key); - for (Integer idx : idxs) { - res[i] = trainingData.getRow(idx); - i++; - if (i >= k) - break; // go to next while-loop iteration + LabeledVector[] res; + + if (trainingData.rowSize() <= k) { + res = new LabeledVector[trainingData.rowSize()]; + for (int i = 0; i < trainingData.rowSize(); i++) + res[i] = trainingData.getRow(i); + } + else { + res = new LabeledVector[k]; + int i = 0; + final Iterator<Double> iter = distanceIdxPairs.keySet().iterator(); + while (i < k) { + double key = iter.next(); + Set<Integer> idxs = distanceIdxPairs.get(key); + for (Integer idx : idxs) { + res[i] = trainingData.getRow(idx); + i++; + if (i >= k) + break; // go to next while-loop iteration + } } } + return res; } http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java index dc0e160..16ba044 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java @@ -38,6 +38,7 @@ import org.apache.ignite.internal.processors.cache.distributed.dht.GridDhtPartit import org.apache.ignite.internal.util.IgniteUtils; import org.apache.ignite.internal.util.typedef.G; import org.apache.ignite.lang.IgnitePredicate; +import org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; /** @@ -81,9 +82,9 @@ public class CacheBasedDatasetTest extends GridCommonAbstractTest { CacheBasedDatasetBuilder<Integer, String> builder = new CacheBasedDatasetBuilder<>(ignite, upstreamCache); - 
CacheBasedDataset<Integer, String, Long, AutoCloseable> dataset = builder.build( + CacheBasedDataset<Integer, String, Long, SimpleDatasetData> dataset = builder.build( (upstream, upstreamSize) -> upstreamSize, - (upstream, upstreamSize, ctx) -> null + (upstream, upstreamSize, ctx) -> new SimpleDatasetData(new double[0], 0) ); assertTrue("Before computation all partitions should not be reserved", @@ -133,9 +134,9 @@ public class CacheBasedDatasetTest extends GridCommonAbstractTest { CacheBasedDatasetBuilder<Integer, String> builder = new CacheBasedDatasetBuilder<>(ignite, upstreamCache); - CacheBasedDataset<Integer, String, Long, AutoCloseable> dataset = builder.build( + CacheBasedDataset<Integer, String, Long, SimpleDatasetData> dataset = builder.build( (upstream, upstreamSize) -> upstreamSize, - (upstream, upstreamSize, ctx) -> null + (upstream, upstreamSize, ctx) -> new SimpleDatasetData(new double[0], 0) ); assertTrue("Before computation all partitions should not be reserved", http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java index 0877fc0..004718e 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java @@ -17,11 +17,11 @@ package org.apache.ignite.ml.knn; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; -import org.junit.Assert; -import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.knn.classification.KNNClassificationModel; import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer; import 
org.apache.ignite.ml.knn.classification.KNNStrategy; @@ -29,121 +29,137 @@ import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static junit.framework.TestCase.assertEquals; /** Tests behaviour of KNNClassificationTest. */ +@RunWith(Parameterized.class) public class KNNClassificationTest { - /** Precision in test checks. */ - private static final double PRECISION = 1e-2; + /** Number of parts to be tested. */ + private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7, 100}; + + /** Number of partitions. */ + @Parameterized.Parameter + public int parts; + + /** Parameters. */ + @Parameterized.Parameters(name = "Data divided on {0} partitions, training with batch size {1}") + public static Iterable<Integer[]> data() { + List<Integer[]> res = new ArrayList<>(); + + for (int part : partsToBeTested) + res.add(new Integer[] {part}); + + return res; + } /** */ @Test - public void binaryClassificationTest() { - + public void testBinaryClassificationTest() { Map<Integer, double[]> data = new HashMap<>(); - data.put(0, new double[]{1.0, 1.0, 1.0}); - data.put(1, new double[]{1.0, 2.0, 1.0}); - data.put(2, new double[]{2.0, 1.0, 1.0}); - data.put(3, new double[]{-1.0, -1.0, 2.0}); - data.put(4, new double[]{-1.0, -2.0, 2.0}); - data.put(5, new double[]{-2.0, -1.0, 2.0}); + data.put(0, new double[] {1.0, 1.0, 1.0}); + data.put(1, new double[] {1.0, 2.0, 1.0}); + data.put(2, new double[] {2.0, 1.0, 1.0}); + data.put(3, new double[] {-1.0, -1.0, 2.0}); + data.put(4, new double[] {-1.0, -2.0, 2.0}); + data.put(5, new double[] {-2.0, -1.0, 2.0}); KNNClassificationTrainer trainer = new KNNClassificationTrainer(); KNNClassificationModel knnMdl = trainer.fit( data, - 2, + parts, (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), (k, v) -> v[2] 
).withK(3) .withDistanceMeasure(new EuclideanDistance()) .withStrategy(KNNStrategy.SIMPLE); - Vector firstVector = new DenseLocalOnHeapVector(new double[]{2.0, 2.0}); - Assert.assertEquals(knnMdl.apply(firstVector), 1.0, PRECISION); - Vector secondVector = new DenseLocalOnHeapVector(new double[]{-2.0, -2.0}); - Assert.assertEquals(knnMdl.apply(secondVector), 2.0, PRECISION); + Vector firstVector = new DenseLocalOnHeapVector(new double[] {2.0, 2.0}); + assertEquals(knnMdl.apply(firstVector), 1.0); + Vector secondVector = new DenseLocalOnHeapVector(new double[] {-2.0, -2.0}); + assertEquals(knnMdl.apply(secondVector), 2.0); } /** */ @Test - public void binaryClassificationWithSmallestKTest() { + public void testBinaryClassificationWithSmallestKTest() { Map<Integer, double[]> data = new HashMap<>(); - - data.put(0, new double[]{1.0, 1.0, 1.0}); - data.put(1, new double[]{1.0, 2.0, 1.0}); - data.put(2, new double[]{2.0, 1.0, 1.0}); - data.put(3, new double[]{-1.0, -1.0, 2.0}); - data.put(4, new double[]{-1.0, -2.0, 2.0}); - data.put(5, new double[]{-2.0, -1.0, 2.0}); + data.put(0, new double[] {1.0, 1.0, 1.0}); + data.put(1, new double[] {1.0, 2.0, 1.0}); + data.put(2, new double[] {2.0, 1.0, 1.0}); + data.put(3, new double[] {-1.0, -1.0, 2.0}); + data.put(4, new double[] {-1.0, -2.0, 2.0}); + data.put(5, new double[] {-2.0, -1.0, 2.0}); KNNClassificationTrainer trainer = new KNNClassificationTrainer(); KNNClassificationModel knnMdl = trainer.fit( data, - 2, + parts, (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), (k, v) -> v[2] ).withK(1) .withDistanceMeasure(new EuclideanDistance()) .withStrategy(KNNStrategy.SIMPLE); - Vector firstVector = new DenseLocalOnHeapVector(new double[]{2.0, 2.0}); - Assert.assertEquals(knnMdl.apply(firstVector), 1.0, PRECISION); - Vector secondVector = new DenseLocalOnHeapVector(new double[]{-2.0, -2.0}); - Assert.assertEquals(knnMdl.apply(secondVector), 2.0, PRECISION); + Vector firstVector = new DenseLocalOnHeapVector(new double[] {2.0, 
2.0}); + assertEquals(knnMdl.apply(firstVector), 1.0); + Vector secondVector = new DenseLocalOnHeapVector(new double[] {-2.0, -2.0}); + assertEquals(knnMdl.apply(secondVector), 2.0); } /** */ @Test - public void binaryClassificationFarPointsWithSimpleStrategy() { + public void testBinaryClassificationFarPointsWithSimpleStrategy() { Map<Integer, double[]> data = new HashMap<>(); - - data.put(0, new double[]{10.0, 10.0, 1.0}); - data.put(1, new double[]{10.0, 20.0, 1.0}); - data.put(2, new double[]{-1, -1, 1.0}); - data.put(3, new double[]{-2, -2, 2.0}); - data.put(4, new double[]{-1.0, -2.0, 2.0}); - data.put(5, new double[]{-2.0, -1.0, 2.0}); + data.put(0, new double[] {10.0, 10.0, 1.0}); + data.put(1, new double[] {10.0, 20.0, 1.0}); + data.put(2, new double[] {-1, -1, 1.0}); + data.put(3, new double[] {-2, -2, 2.0}); + data.put(4, new double[] {-1.0, -2.0, 2.0}); + data.put(5, new double[] {-2.0, -1.0, 2.0}); KNNClassificationTrainer trainer = new KNNClassificationTrainer(); KNNClassificationModel knnMdl = trainer.fit( data, - 2, + parts, (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), (k, v) -> v[2] ).withK(3) .withDistanceMeasure(new EuclideanDistance()) .withStrategy(KNNStrategy.SIMPLE); - Vector vector = new DenseLocalOnHeapVector(new double[]{-1.01, -1.01}); - Assert.assertEquals(knnMdl.apply(vector), 2.0, PRECISION); + Vector vector = new DenseLocalOnHeapVector(new double[] {-1.01, -1.01}); + assertEquals(knnMdl.apply(vector), 2.0); } /** */ @Test - public void binaryClassificationFarPointsWithWeightedStrategy() { + public void testBinaryClassificationFarPointsWithWeightedStrategy() { Map<Integer, double[]> data = new HashMap<>(); - - data.put(0, new double[]{10.0, 10.0, 1.0}); - data.put(1, new double[]{10.0, 20.0, 1.0}); - data.put(2, new double[]{-1, -1, 1.0}); - data.put(3, new double[]{-2, -2, 2.0}); - data.put(4, new double[]{-1.0, -2.0, 2.0}); - data.put(5, new double[]{-2.0, -1.0, 2.0}); + data.put(0, new double[] {10.0, 10.0, 1.0}); + data.put(1, 
new double[] {10.0, 20.0, 1.0}); + data.put(2, new double[] {-1, -1, 1.0}); + data.put(3, new double[] {-2, -2, 2.0}); + data.put(4, new double[] {-1.0, -2.0, 2.0}); + data.put(5, new double[] {-2.0, -1.0, 2.0}); KNNClassificationTrainer trainer = new KNNClassificationTrainer(); KNNClassificationModel knnMdl = trainer.fit( data, - 2, + parts, (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), (k, v) -> v[2] ).withK(3) .withDistanceMeasure(new EuclideanDistance()) .withStrategy(KNNStrategy.WEIGHTED); - Vector vector = new DenseLocalOnHeapVector(new double[]{-1.01, -1.01}); - Assert.assertEquals(knnMdl.apply(vector), 1.0, PRECISION); + Vector vector = new DenseLocalOnHeapVector(new double[] {-1.01, -1.01}); + assertEquals(knnMdl.apply(vector), 1.0); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java index ce9cae5..0c26ba9 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java @@ -17,6 +17,11 @@ package org.apache.ignite.ml.knn; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.knn.classification.KNNStrategy; import org.apache.ignite.ml.knn.regression.KNNRegressionModel; @@ -25,110 +30,125 @@ import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.junit.Assert; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; import 
org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; /** * Tests for {@link KNNRegressionTrainer}. */ +@RunWith(Parameterized.class) public class KNNRegressionTest { + /** Number of parts to be tested. */ + private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7, 100}; + + /** Number of partitions. */ + @Parameterized.Parameter + public int parts; + + /** Parameters. */ + @Parameterized.Parameters(name = "Data divided on {0} partitions, training with batch size {1}") + public static Iterable<Integer[]> data() { + List<Integer[]> res = new ArrayList<>(); + + for (int part : partsToBeTested) + res.add(new Integer[] {part}); + + return res; + } + /** */ @Test - public void simpleRegressionWithOneNeighbour() { + public void testSimpleRegressionWithOneNeighbour() { Map<Integer, double[]> data = new HashMap<>(); - - data.put(0, new double[]{11.0, 0, 0, 0, 0, 0}); - data.put(1, new double[]{12.0, 2.0, 0, 0, 0, 0}); - data.put(2, new double[]{13.0, 0, 3.0, 0, 0, 0}); - data.put(3, new double[]{14.0, 0, 0, 4.0, 0, 0}); - data.put(4, new double[]{15.0, 0, 0, 0, 5.0, 0}); - data.put(5, new double[]{16.0, 0, 0, 0, 0, 6.0}); + data.put(0, new double[] {11.0, 0, 0, 0, 0, 0}); + data.put(1, new double[] {12.0, 2.0, 0, 0, 0, 0}); + data.put(2, new double[] {13.0, 0, 3.0, 0, 0, 0}); + data.put(3, new double[] {14.0, 0, 0, 4.0, 0, 0}); + data.put(4, new double[] {15.0, 0, 0, 0, 5.0, 0}); + data.put(5, new double[] {16.0, 0, 0, 0, 0, 6.0}); KNNRegressionTrainer trainer = new KNNRegressionTrainer(); KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit( - new LocalDatasetBuilder<>(data, 2), + new LocalDatasetBuilder<>(data, parts), (k, v) -> Arrays.copyOfRange(v, 1, v.length), (k, v) -> v[0] ).withK(1) .withDistanceMeasure(new EuclideanDistance()) .withStrategy(KNNStrategy.SIMPLE); - Vector vector = new DenseLocalOnHeapVector(new double[]{0, 0, 0, 5.0, 0.0}); + Vector vector = new DenseLocalOnHeapVector(new double[] {0, 0, 
0, 5.0, 0.0}); System.out.println(knnMdl.apply(vector)); Assert.assertEquals(15, knnMdl.apply(vector), 1E-12); } /** */ @Test - public void longly() { + public void testLongly() { Map<Integer, double[]> data = new HashMap<>(); - - data.put(0, new double[]{60323, 83.0, 234289, 2356, 1590, 107608, 1947}); - data.put(1, new double[]{61122, 88.5, 259426, 2325, 1456, 108632, 1948}); - data.put(2, new double[]{60171, 88.2, 258054, 3682, 1616, 109773, 1949}); - data.put(3, new double[]{61187, 89.5, 284599, 3351, 1650, 110929, 1950}); - data.put(4, new double[]{63221, 96.2, 328975, 2099, 3099, 112075, 1951}); - data.put(5, new double[]{63639, 98.1, 346999, 1932, 3594, 113270, 1952}); - data.put(6, new double[]{64989, 99.0, 365385, 1870, 3547, 115094, 1953}); - data.put(7, new double[]{63761, 100.0, 363112, 3578, 3350, 116219, 1954}); - data.put(8, new double[]{66019, 101.2, 397469, 2904, 3048, 117388, 1955}); - data.put(9, new double[]{68169, 108.4, 442769, 2936, 2798, 120445, 1957}); - data.put(10, new double[]{66513, 110.8, 444546, 4681, 2637, 121950, 1958}); - data.put(11, new double[]{68655, 112.6, 482704, 3813, 2552, 123366, 1959}); - data.put(12, new double[]{69564, 114.2, 502601, 3931, 2514, 125368, 1960}); - data.put(13, new double[]{69331, 115.7, 518173, 4806, 2572, 127852, 1961}); - data.put(14, new double[]{70551, 116.9, 554894, 4007, 2827, 130081, 1962}); + data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608, 1947}); + data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632, 1948}); + data.put(2, new double[] {60171, 88.2, 258054, 3682, 1616, 109773, 1949}); + data.put(3, new double[] {61187, 89.5, 284599, 3351, 1650, 110929, 1950}); + data.put(4, new double[] {63221, 96.2, 328975, 2099, 3099, 112075, 1951}); + data.put(5, new double[] {63639, 98.1, 346999, 1932, 3594, 113270, 1952}); + data.put(6, new double[] {64989, 99.0, 365385, 1870, 3547, 115094, 1953}); + data.put(7, new double[] {63761, 100.0, 363112, 3578, 3350, 116219, 1954}); 
+ data.put(8, new double[] {66019, 101.2, 397469, 2904, 3048, 117388, 1955}); + data.put(9, new double[] {68169, 108.4, 442769, 2936, 2798, 120445, 1957}); + data.put(10, new double[] {66513, 110.8, 444546, 4681, 2637, 121950, 1958}); + data.put(11, new double[] {68655, 112.6, 482704, 3813, 2552, 123366, 1959}); + data.put(12, new double[] {69564, 114.2, 502601, 3931, 2514, 125368, 1960}); + data.put(13, new double[] {69331, 115.7, 518173, 4806, 2572, 127852, 1961}); + data.put(14, new double[] {70551, 116.9, 554894, 4007, 2827, 130081, 1962}); KNNRegressionTrainer trainer = new KNNRegressionTrainer(); KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit( - new LocalDatasetBuilder<>(data, 2), + new LocalDatasetBuilder<>(data, parts), (k, v) -> Arrays.copyOfRange(v, 1, v.length), (k, v) -> v[0] ).withK(3) .withDistanceMeasure(new EuclideanDistance()) .withStrategy(KNNStrategy.SIMPLE); - Vector vector = new DenseLocalOnHeapVector(new double[]{104.6, 419180, 2822, 2857, 118734, 1956}); + Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956}); System.out.println(knnMdl.apply(vector)); Assert.assertEquals(67857, knnMdl.apply(vector), 2000); } /** */ + @Test public void testLonglyWithWeightedStrategy() { Map<Integer, double[]> data = new HashMap<>(); - - data.put(0, new double[]{60323, 83.0, 234289, 2356, 1590, 107608, 1947}); - data.put(1, new double[]{61122, 88.5, 259426, 2325, 1456, 108632, 1948}); - data.put(2, new double[]{60171, 88.2, 258054, 3682, 1616, 109773, 1949}); - data.put(3, new double[]{61187, 89.5, 284599, 3351, 1650, 110929, 1950}); - data.put(4, new double[]{63221, 96.2, 328975, 2099, 3099, 112075, 1951}); - data.put(5, new double[]{63639, 98.1, 346999, 1932, 3594, 113270, 1952}); - data.put(6, new double[]{64989, 99.0, 365385, 1870, 3547, 115094, 1953}); - data.put(7, new double[]{63761, 100.0, 363112, 3578, 3350, 116219, 1954}); - data.put(8, new double[]{66019, 101.2, 397469, 2904, 3048, 117388, 
1955}); - data.put(9, new double[]{68169, 108.4, 442769, 2936, 2798, 120445, 1957}); - data.put(10, new double[]{66513, 110.8, 444546, 4681, 2637, 121950, 1958}); - data.put(11, new double[]{68655, 112.6, 482704, 3813, 2552, 123366, 1959}); - data.put(12, new double[]{69564, 114.2, 502601, 3931, 2514, 125368, 1960}); - data.put(13, new double[]{69331, 115.7, 518173, 4806, 2572, 127852, 1961}); - data.put(14, new double[]{70551, 116.9, 554894, 4007, 2827, 130081, 1962}); + data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608, 1947}); + data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632, 1948}); + data.put(2, new double[] {60171, 88.2, 258054, 3682, 1616, 109773, 1949}); + data.put(3, new double[] {61187, 89.5, 284599, 3351, 1650, 110929, 1950}); + data.put(4, new double[] {63221, 96.2, 328975, 2099, 3099, 112075, 1951}); + data.put(5, new double[] {63639, 98.1, 346999, 1932, 3594, 113270, 1952}); + data.put(6, new double[] {64989, 99.0, 365385, 1870, 3547, 115094, 1953}); + data.put(7, new double[] {63761, 100.0, 363112, 3578, 3350, 116219, 1954}); + data.put(8, new double[] {66019, 101.2, 397469, 2904, 3048, 117388, 1955}); + data.put(9, new double[] {68169, 108.4, 442769, 2936, 2798, 120445, 1957}); + data.put(10, new double[] {66513, 110.8, 444546, 4681, 2637, 121950, 1958}); + data.put(11, new double[] {68655, 112.6, 482704, 3813, 2552, 123366, 1959}); + data.put(12, new double[] {69564, 114.2, 502601, 3931, 2514, 125368, 1960}); + data.put(13, new double[] {69331, 115.7, 518173, 4806, 2572, 127852, 1961}); + data.put(14, new double[] {70551, 116.9, 554894, 4007, 2827, 130081, 1962}); KNNRegressionTrainer trainer = new KNNRegressionTrainer(); KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit( - new LocalDatasetBuilder<>(data, 2), + new LocalDatasetBuilder<>(data, parts), (k, v) -> Arrays.copyOfRange(v, 1, v.length), (k, v) -> v[0] ).withK(3) .withDistanceMeasure(new EuclideanDistance()) .withStrategy(KNNStrategy.SIMPLE); - 
Vector vector = new DenseLocalOnHeapVector(new double[]{104.6, 419180, 2822, 2857, 118734, 1956}); + Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956}); System.out.println(knnMdl.apply(vector)); Assert.assertEquals(67857, knnMdl.apply(vector), 2000); } http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java index a25b303..dbcdb99 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java @@ -21,64 +21,33 @@ import java.io.IOException; import java.net.URISyntaxException; import java.nio.file.Path; import java.nio.file.Paths; -import org.apache.ignite.Ignite; import org.apache.ignite.ml.structures.LabeledDataset; import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader; -import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; /** * Base class for decision trees test. */ -public class LabeledDatasetHelper extends GridCommonAbstractTest { - /** Count of nodes. */ - private static final int NODE_COUNT = 4; - +public class LabeledDatasetHelper { /** Separator. */ private static final String SEPARATOR = "\t"; - /** Grid instance. */ - protected Ignite ignite; - - /** - * Default constructor. 
- */ - public LabeledDatasetHelper() { - super(false); - } - - /** - * {@inheritDoc} - */ - @Override protected void beforeTest() throws Exception { - ignite = grid(NODE_COUNT); - } - - /** {@inheritDoc} */ - @Override protected void beforeTestsStarted() throws Exception { - for (int i = 1; i <= NODE_COUNT; i++) - startGrid(i); - } - - /** {@inheritDoc} */ - @Override protected void afterTestsStopped() throws Exception { - stopAllGrids(); - } - /** * Loads labeled dataset from file with .txt extension. * * @param rsrcPath path to dataset. * @return null if path is incorrect. */ - LabeledDataset loadDatasetFromTxt(String rsrcPath, boolean isFallOnBadData) { + public static LabeledDataset loadDatasetFromTxt(String rsrcPath, boolean isFallOnBadData) { try { - Path path = Paths.get(this.getClass().getClassLoader().getResource(rsrcPath).toURI()); + Path path = Paths.get(LabeledDatasetHelper.class.getClassLoader().getResource(rsrcPath).toURI()); try { return LabeledDatasetLoader.loadFromTxtFile(path, SEPARATOR, false, isFallOnBadData); - } catch (IOException e) { + } + catch (IOException e) { e.printStackTrace(); } - } catch (URISyntaxException e) { + } + catch (URISyntaxException e) { e.printStackTrace(); return null; } http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java index 77d40a6..e986740 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java @@ -21,7 +21,6 @@ import java.io.IOException; import java.net.URISyntaxException; import java.nio.file.Path; import java.nio.file.Paths; -import org.apache.ignite.internal.util.IgniteUtils; import 
org.apache.ignite.ml.math.ExternalizableTest; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.exceptions.CardinalityException; @@ -32,9 +31,13 @@ import org.apache.ignite.ml.structures.LabeledDataset; import org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair; import org.apache.ignite.ml.structures.LabeledVector; import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader; +import org.junit.Test; + +import static junit.framework.TestCase.assertEquals; +import static junit.framework.TestCase.fail; /** Tests behaviour of KNNClassificationTest. */ -public class LabeledDatasetTest extends LabeledDatasetHelper implements ExternalizableTest<LabeledDataset> { +public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> { /** */ private static final String KNN_IRIS_TXT = "datasets/knn/iris.txt"; @@ -51,9 +54,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External private static final String IRIS_MISSED_DATA = "datasets/knn/missed_data.txt"; /** */ + @Test public void testFeatureNames() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - double[][] mtx = new double[][] { {1.0, 1.0}, @@ -71,9 +73,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External } /** */ + @Test public void testAccessMethods() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - double[][] mtx = new double[][] { {1.0, 1.0}, @@ -98,9 +99,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External } /** */ + @Test public void testFailOnYNull() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - double[][] mtx = new double[][] { {1.0, 1.0}, @@ -122,9 +122,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External } /** */ + @Test public void testFailOnXNull() { - 
IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - double[][] mtx = new double[][] {}; double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0}; @@ -140,18 +139,17 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External } /** */ + @Test public void testLoadingCorrectTxtFile() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - LabeledDataset training = loadDatasetFromTxt(KNN_IRIS_TXT, false); + LabeledDataset training = LabeledDatasetHelper.loadDatasetFromTxt(KNN_IRIS_TXT, false); assertEquals(training.rowSize(), 150); } /** */ + @Test public void testLoadingEmptyFile() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - try { - loadDatasetFromTxt(EMPTY_TXT, false); + LabeledDatasetHelper.loadDatasetFromTxt(EMPTY_TXT, false); fail("EmptyFileException"); } catch (EmptyFileException e) { @@ -161,11 +159,10 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External } /** */ + @Test public void testLoadingFileWithFirstEmptyRow() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - try { - loadDatasetFromTxt(NO_DATA_TXT, false); + LabeledDatasetHelper.loadDatasetFromTxt(NO_DATA_TXT, false); fail("NoDataException"); } catch (NoDataException e) { @@ -175,19 +172,17 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External } /** */ + @Test public void testLoadingFileWithIncorrectData() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - - LabeledDataset training = loadDatasetFromTxt(IRIS_INCORRECT_TXT, false); + LabeledDataset training = LabeledDatasetHelper.loadDatasetFromTxt(IRIS_INCORRECT_TXT, false); assertEquals(149, training.rowSize()); } /** */ + @Test public void testFailOnLoadingFileWithIncorrectData() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - try { - 
loadDatasetFromTxt(IRIS_INCORRECT_TXT, true); + LabeledDatasetHelper.loadDatasetFromTxt(IRIS_INCORRECT_TXT, true); fail("FileParsingException"); } catch (FileParsingException e) { @@ -198,9 +193,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External } /** */ + @Test public void testLoadingFileWithMissedData() throws URISyntaxException, IOException { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - Path path = Paths.get(this.getClass().getClassLoader().getResource(IRIS_MISSED_DATA).toURI()); LabeledDataset training = LabeledDatasetLoader.loadFromTxtFile(path, ",", false, false); @@ -209,9 +203,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External } /** */ + @Test public void testSplitting() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - double[][] mtx = new double[][] { {1.0, 1.0}, @@ -246,9 +239,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External } /** */ + @Test public void testLabels() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - double[][] mtx = new double[][] { {1.0, 1.0}, @@ -267,8 +259,6 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External /** */ @Override public void testExternalization() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - double[][] mtx = new double[][] { {1.0, 1.0},