Repository: ignite
Updated Branches:
refs/heads/master 568c3e79e -> ee9ca06a8
Advertising
IGNITE-8233: KNN and SVM algorithms don't work when partition doesn't contain data.
this closes #3807
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/ee9ca06a
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/ee9ca06a
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/ee9ca06a
Branch: refs/heads/master
Commit: ee9ca06a8cbec6eec2b963fb1db1b3f383fc1837
Parents: 568c3e7
Author: dmitrievanthony <dmitrievanth...@gmail.com>
Authored: Fri Apr 13 18:02:37 2018 +0300
Committer: Yury Babak <yba...@gridgain.com>
Committed: Fri Apr 13 18:02:37 2018 +0300
----------------------------------------------------------------------
.../dataset/impl/cache/CacheBasedDataset.java | 14 ++-
.../dataset/impl/cache/util/ComputeUtils.java | 9 +-
.../ml/dataset/impl/local/LocalDataset.java | 16 ++-
.../dataset/impl/local/LocalDatasetBuilder.java | 8 +-
.../classification/KNNClassificationModel.java | 32 +++--
.../impl/cache/CacheBasedDatasetTest.java | 9 +-
.../ignite/ml/knn/KNNClassificationTest.java | 120 ++++++++++--------
.../apache/ignite/ml/knn/KNNRegressionTest.java | 122 +++++++++++--------
.../ignite/ml/knn/LabeledDatasetHelper.java | 45 ++-----
.../ignite/ml/knn/LabeledDatasetTest.java | 54 ++++----
10 files changed, 226 insertions(+), 203 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
index 463d496..7428faf 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
@@ -101,12 +101,16 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
partDataBuilder
);
- R res = map.apply(ctx, data, part);
+ if (data != null) {
+ R res = map.apply(ctx, data, part);
- // Saves partition context after update.
- ComputeUtils.saveContext(Ignition.localIgnite(), datasetCacheName,
part, ctx);
+ // Saves partition context after update.
+ ComputeUtils.saveContext(Ignition.localIgnite(),
datasetCacheName, part, ctx);
- return res;
+ return res;
+ }
+
+ return null;
}, reduce, identity);
}
@@ -125,7 +129,7 @@ public class CacheBasedDataset<K, V, C extends
Serializable, D extends AutoClose
partDataBuilder
);
- return map.apply(data, part);
+ return data != null ? map.apply(data, part) : null;
}, reduce, identity);
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
index 0785db2..ce2fcfd 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
@@ -163,9 +163,14 @@ public class ComputeUtils {
qry.setPartition(part);
long cnt = upstreamCache.localSizeLong(part);
- try (QueryCursor<Cache.Entry<K, V>> cursor =
upstreamCache.query(qry)) {
- return partDataBuilder.build(new
UpstreamCursorAdapter<>(cursor.iterator(), cnt), cnt, ctx);
+
+ if (cnt > 0) {
+ try (QueryCursor<Cache.Entry<K, V>> cursor =
upstreamCache.query(qry)) {
+ return partDataBuilder.build(new
UpstreamCursorAdapter<>(cursor.iterator(), cnt), cnt, ctx);
+ }
}
+
+ return null;
});
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java
index c08b7de..e312b20 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java
@@ -55,8 +55,12 @@ public class LocalDataset<C extends Serializable, D extends AutoCloseable> imple
R identity) {
R res = identity;
- for (int part = 0; part < ctx.size(); part++)
- res = reduce.apply(res, map.apply(ctx.get(part), data.get(part),
part));
+ for (int part = 0; part < ctx.size(); part++) {
+ D partData = data.get(part);
+
+ if (partData != null)
+ res = reduce.apply(res, map.apply(ctx.get(part), partData,
part));
+ }
return res;
}
@@ -65,8 +69,12 @@ public class LocalDataset<C extends Serializable, D extends
AutoCloseable> imple
@Override public <R> R compute(IgniteBiFunction<D, Integer, R> map,
IgniteBinaryOperator<R> reduce, R identity) {
R res = identity;
- for (int part = 0; part < data.size(); part++)
- res = reduce.apply(res, map.apply(data.get(part), part));
+ for (int part = 0; part < data.size(); part++) {
+ D partData = data.get(part);
+
+ if (partData != null)
+ res = reduce.apply(res, map.apply(partData, part));
+ }
return res;
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java
index 0dc1ed6..cfc1801 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java
@@ -69,16 +69,16 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
for (int part = 0; part < partitions; part++) {
int cnt = part == partitions - 1 ? upstreamMap.size() - ptr :
Math.min(partSize, upstreamMap.size() - ptr);
- C ctx = partCtxBuilder.build(
+ C ctx = cnt > 0 ? partCtxBuilder.build(
new IteratorWindow<>(firstKeysIter, k -> new
UpstreamEntry<>(k, upstreamMap.get(k)), cnt),
cnt
- );
+ ) : null;
- D data = partDataBuilder.build(
+ D data = cnt > 0 ? partDataBuilder.build(
new IteratorWindow<>(secondKeysIter, k -> new
UpstreamEntry<>(k, upstreamMap.get(k)), cnt),
cnt,
ctx
- );
+ ) : null;
ctxList.add(ctx);
dataList.add(data);
http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
index 693b81d..0f0cc9f 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
@@ -151,19 +151,29 @@ public class KNNClassificationModel<K, V> implements Model<Vector, Double>, Expo
*/
@NotNull private LabeledVector[] getKClosestVectors(LabeledDataset<Double,
LabeledVector> trainingData,
TreeMap<Double, Set<Integer>> distanceIdxPairs) {
- LabeledVector[] res = new LabeledVector[k];
- int i = 0;
- final Iterator<Double> iter = distanceIdxPairs.keySet().iterator();
- while (i < k) {
- double key = iter.next();
- Set<Integer> idxs = distanceIdxPairs.get(key);
- for (Integer idx : idxs) {
- res[i] = trainingData.getRow(idx);
- i++;
- if (i >= k)
- break; // go to next while-loop iteration
+ LabeledVector[] res;
+
+ if (trainingData.rowSize() <= k) {
+ res = new LabeledVector[trainingData.rowSize()];
+ for (int i = 0; i < trainingData.rowSize(); i++)
+ res[i] = trainingData.getRow(i);
+ }
+ else {
+ res = new LabeledVector[k];
+ int i = 0;
+ final Iterator<Double> iter = distanceIdxPairs.keySet().iterator();
+ while (i < k) {
+ double key = iter.next();
+ Set<Integer> idxs = distanceIdxPairs.get(key);
+ for (Integer idx : idxs) {
+ res[i] = trainingData.getRow(idx);
+ i++;
+ if (i >= k)
+ break; // go to next while-loop iteration
+ }
}
}
+
return res;
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java
index dc0e160..16ba044 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java
@@ -38,6 +38,7 @@ import org.apache.ignite.internal.processors.cache.distributed.dht.GridDhtPartit
import org.apache.ignite.internal.util.IgniteUtils;
import org.apache.ignite.internal.util.typedef.G;
import org.apache.ignite.lang.IgnitePredicate;
+import org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData;
import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
/**
@@ -81,9 +82,9 @@ public class CacheBasedDatasetTest extends
GridCommonAbstractTest {
CacheBasedDatasetBuilder<Integer, String> builder = new
CacheBasedDatasetBuilder<>(ignite, upstreamCache);
- CacheBasedDataset<Integer, String, Long, AutoCloseable> dataset =
builder.build(
+ CacheBasedDataset<Integer, String, Long, SimpleDatasetData> dataset =
builder.build(
(upstream, upstreamSize) -> upstreamSize,
- (upstream, upstreamSize, ctx) -> null
+ (upstream, upstreamSize, ctx) -> new SimpleDatasetData(new
double[0], 0)
);
assertTrue("Before computation all partitions should not be reserved",
@@ -133,9 +134,9 @@ public class CacheBasedDatasetTest extends
GridCommonAbstractTest {
CacheBasedDatasetBuilder<Integer, String> builder = new
CacheBasedDatasetBuilder<>(ignite, upstreamCache);
- CacheBasedDataset<Integer, String, Long, AutoCloseable> dataset =
builder.build(
+ CacheBasedDataset<Integer, String, Long, SimpleDatasetData> dataset =
builder.build(
(upstream, upstreamSize) -> upstreamSize,
- (upstream, upstreamSize, ctx) -> null
+ (upstream, upstreamSize, ctx) -> new SimpleDatasetData(new
double[0], 0)
);
assertTrue("Before computation all partitions should not be reserved",
http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
index 0877fc0..004718e 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
@@ -17,11 +17,11 @@
package org.apache.ignite.ml.knn;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
-import org.junit.Assert;
-import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
import org.apache.ignite.ml.knn.classification.KNNStrategy;
@@ -29,121 +29,137 @@ import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import static junit.framework.TestCase.assertEquals;
/** Tests behaviour of KNNClassificationTest. */
+@RunWith(Parameterized.class)
public class KNNClassificationTest {
- /** Precision in test checks. */
- private static final double PRECISION = 1e-2;
+ /** Number of parts to be tested. */
+ private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7,
100};
+
+ /** Number of partitions. */
+ @Parameterized.Parameter
+ public int parts;
+
+ /** Parameters. */
+ @Parameterized.Parameters(name = "Data divided on {0} partitions, training
with batch size {1}")
+ public static Iterable<Integer[]> data() {
+ List<Integer[]> res = new ArrayList<>();
+
+ for (int part : partsToBeTested)
+ res.add(new Integer[] {part});
+
+ return res;
+ }
/** */
@Test
- public void binaryClassificationTest() {
-
+ public void testBinaryClassificationTest() {
Map<Integer, double[]> data = new HashMap<>();
- data.put(0, new double[]{1.0, 1.0, 1.0});
- data.put(1, new double[]{1.0, 2.0, 1.0});
- data.put(2, new double[]{2.0, 1.0, 1.0});
- data.put(3, new double[]{-1.0, -1.0, 2.0});
- data.put(4, new double[]{-1.0, -2.0, 2.0});
- data.put(5, new double[]{-2.0, -1.0, 2.0});
+ data.put(0, new double[] {1.0, 1.0, 1.0});
+ data.put(1, new double[] {1.0, 2.0, 1.0});
+ data.put(2, new double[] {2.0, 1.0, 1.0});
+ data.put(3, new double[] {-1.0, -1.0, 2.0});
+ data.put(4, new double[] {-1.0, -2.0, 2.0});
+ data.put(5, new double[] {-2.0, -1.0, 2.0});
KNNClassificationTrainer trainer = new KNNClassificationTrainer();
KNNClassificationModel knnMdl = trainer.fit(
data,
- 2,
+ parts,
(k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
(k, v) -> v[2]
).withK(3)
.withDistanceMeasure(new EuclideanDistance())
.withStrategy(KNNStrategy.SIMPLE);
- Vector firstVector = new DenseLocalOnHeapVector(new double[]{2.0,
2.0});
- Assert.assertEquals(knnMdl.apply(firstVector), 1.0, PRECISION);
- Vector secondVector = new DenseLocalOnHeapVector(new double[]{-2.0,
-2.0});
- Assert.assertEquals(knnMdl.apply(secondVector), 2.0, PRECISION);
+ Vector firstVector = new DenseLocalOnHeapVector(new double[] {2.0,
2.0});
+ assertEquals(knnMdl.apply(firstVector), 1.0);
+ Vector secondVector = new DenseLocalOnHeapVector(new double[] {-2.0,
-2.0});
+ assertEquals(knnMdl.apply(secondVector), 2.0);
}
/** */
@Test
- public void binaryClassificationWithSmallestKTest() {
+ public void testBinaryClassificationWithSmallestKTest() {
Map<Integer, double[]> data = new HashMap<>();
-
- data.put(0, new double[]{1.0, 1.0, 1.0});
- data.put(1, new double[]{1.0, 2.0, 1.0});
- data.put(2, new double[]{2.0, 1.0, 1.0});
- data.put(3, new double[]{-1.0, -1.0, 2.0});
- data.put(4, new double[]{-1.0, -2.0, 2.0});
- data.put(5, new double[]{-2.0, -1.0, 2.0});
+ data.put(0, new double[] {1.0, 1.0, 1.0});
+ data.put(1, new double[] {1.0, 2.0, 1.0});
+ data.put(2, new double[] {2.0, 1.0, 1.0});
+ data.put(3, new double[] {-1.0, -1.0, 2.0});
+ data.put(4, new double[] {-1.0, -2.0, 2.0});
+ data.put(5, new double[] {-2.0, -1.0, 2.0});
KNNClassificationTrainer trainer = new KNNClassificationTrainer();
KNNClassificationModel knnMdl = trainer.fit(
data,
- 2,
+ parts,
(k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
(k, v) -> v[2]
).withK(1)
.withDistanceMeasure(new EuclideanDistance())
.withStrategy(KNNStrategy.SIMPLE);
- Vector firstVector = new DenseLocalOnHeapVector(new double[]{2.0,
2.0});
- Assert.assertEquals(knnMdl.apply(firstVector), 1.0, PRECISION);
- Vector secondVector = new DenseLocalOnHeapVector(new double[]{-2.0,
-2.0});
- Assert.assertEquals(knnMdl.apply(secondVector), 2.0, PRECISION);
+ Vector firstVector = new DenseLocalOnHeapVector(new double[] {2.0,
2.0});
+ assertEquals(knnMdl.apply(firstVector), 1.0);
+ Vector secondVector = new DenseLocalOnHeapVector(new double[] {-2.0,
-2.0});
+ assertEquals(knnMdl.apply(secondVector), 2.0);
}
/** */
@Test
- public void binaryClassificationFarPointsWithSimpleStrategy() {
+ public void testBinaryClassificationFarPointsWithSimpleStrategy() {
Map<Integer, double[]> data = new HashMap<>();
-
- data.put(0, new double[]{10.0, 10.0, 1.0});
- data.put(1, new double[]{10.0, 20.0, 1.0});
- data.put(2, new double[]{-1, -1, 1.0});
- data.put(3, new double[]{-2, -2, 2.0});
- data.put(4, new double[]{-1.0, -2.0, 2.0});
- data.put(5, new double[]{-2.0, -1.0, 2.0});
+ data.put(0, new double[] {10.0, 10.0, 1.0});
+ data.put(1, new double[] {10.0, 20.0, 1.0});
+ data.put(2, new double[] {-1, -1, 1.0});
+ data.put(3, new double[] {-2, -2, 2.0});
+ data.put(4, new double[] {-1.0, -2.0, 2.0});
+ data.put(5, new double[] {-2.0, -1.0, 2.0});
KNNClassificationTrainer trainer = new KNNClassificationTrainer();
KNNClassificationModel knnMdl = trainer.fit(
data,
- 2,
+ parts,
(k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
(k, v) -> v[2]
).withK(3)
.withDistanceMeasure(new EuclideanDistance())
.withStrategy(KNNStrategy.SIMPLE);
- Vector vector = new DenseLocalOnHeapVector(new double[]{-1.01, -1.01});
- Assert.assertEquals(knnMdl.apply(vector), 2.0, PRECISION);
+ Vector vector = new DenseLocalOnHeapVector(new double[] {-1.01,
-1.01});
+ assertEquals(knnMdl.apply(vector), 2.0);
}
/** */
@Test
- public void binaryClassificationFarPointsWithWeightedStrategy() {
+ public void testBinaryClassificationFarPointsWithWeightedStrategy() {
Map<Integer, double[]> data = new HashMap<>();
-
- data.put(0, new double[]{10.0, 10.0, 1.0});
- data.put(1, new double[]{10.0, 20.0, 1.0});
- data.put(2, new double[]{-1, -1, 1.0});
- data.put(3, new double[]{-2, -2, 2.0});
- data.put(4, new double[]{-1.0, -2.0, 2.0});
- data.put(5, new double[]{-2.0, -1.0, 2.0});
+ data.put(0, new double[] {10.0, 10.0, 1.0});
+ data.put(1, new double[] {10.0, 20.0, 1.0});
+ data.put(2, new double[] {-1, -1, 1.0});
+ data.put(3, new double[] {-2, -2, 2.0});
+ data.put(4, new double[] {-1.0, -2.0, 2.0});
+ data.put(5, new double[] {-2.0, -1.0, 2.0});
KNNClassificationTrainer trainer = new KNNClassificationTrainer();
KNNClassificationModel knnMdl = trainer.fit(
data,
- 2,
+ parts,
(k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
(k, v) -> v[2]
).withK(3)
.withDistanceMeasure(new EuclideanDistance())
.withStrategy(KNNStrategy.WEIGHTED);
- Vector vector = new DenseLocalOnHeapVector(new double[]{-1.01, -1.01});
- Assert.assertEquals(knnMdl.apply(vector), 1.0, PRECISION);
+ Vector vector = new DenseLocalOnHeapVector(new double[] {-1.01,
-1.01});
+ assertEquals(knnMdl.apply(vector), 1.0);
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
index ce9cae5..0c26ba9 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
@@ -17,6 +17,11 @@
package org.apache.ignite.ml.knn;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.knn.classification.KNNStrategy;
import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
@@ -25,110 +30,125 @@ import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.junit.Assert;
-
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.Map;
import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
/**
* Tests for {@link KNNRegressionTrainer}.
*/
+@RunWith(Parameterized.class)
public class KNNRegressionTest {
+ /** Number of parts to be tested. */
+ private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7,
100};
+
+ /** Number of partitions. */
+ @Parameterized.Parameter
+ public int parts;
+
+ /** Parameters. */
+ @Parameterized.Parameters(name = "Data divided on {0} partitions, training
with batch size {1}")
+ public static Iterable<Integer[]> data() {
+ List<Integer[]> res = new ArrayList<>();
+
+ for (int part : partsToBeTested)
+ res.add(new Integer[] {part});
+
+ return res;
+ }
+
/** */
@Test
- public void simpleRegressionWithOneNeighbour() {
+ public void testSimpleRegressionWithOneNeighbour() {
Map<Integer, double[]> data = new HashMap<>();
-
- data.put(0, new double[]{11.0, 0, 0, 0, 0, 0});
- data.put(1, new double[]{12.0, 2.0, 0, 0, 0, 0});
- data.put(2, new double[]{13.0, 0, 3.0, 0, 0, 0});
- data.put(3, new double[]{14.0, 0, 0, 4.0, 0, 0});
- data.put(4, new double[]{15.0, 0, 0, 0, 5.0, 0});
- data.put(5, new double[]{16.0, 0, 0, 0, 0, 6.0});
+ data.put(0, new double[] {11.0, 0, 0, 0, 0, 0});
+ data.put(1, new double[] {12.0, 2.0, 0, 0, 0, 0});
+ data.put(2, new double[] {13.0, 0, 3.0, 0, 0, 0});
+ data.put(3, new double[] {14.0, 0, 0, 4.0, 0, 0});
+ data.put(4, new double[] {15.0, 0, 0, 0, 5.0, 0});
+ data.put(5, new double[] {16.0, 0, 0, 0, 0, 6.0});
KNNRegressionTrainer trainer = new KNNRegressionTrainer();
KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit(
- new LocalDatasetBuilder<>(data, 2),
+ new LocalDatasetBuilder<>(data, parts),
(k, v) -> Arrays.copyOfRange(v, 1, v.length),
(k, v) -> v[0]
).withK(1)
.withDistanceMeasure(new EuclideanDistance())
.withStrategy(KNNStrategy.SIMPLE);
- Vector vector = new DenseLocalOnHeapVector(new double[]{0, 0, 0, 5.0,
0.0});
+ Vector vector = new DenseLocalOnHeapVector(new double[] {0, 0, 0, 5.0,
0.0});
System.out.println(knnMdl.apply(vector));
Assert.assertEquals(15, knnMdl.apply(vector), 1E-12);
}
/** */
@Test
- public void longly() {
+ public void testLongly() {
Map<Integer, double[]> data = new HashMap<>();
-
- data.put(0, new double[]{60323, 83.0, 234289, 2356, 1590, 107608,
1947});
- data.put(1, new double[]{61122, 88.5, 259426, 2325, 1456, 108632,
1948});
- data.put(2, new double[]{60171, 88.2, 258054, 3682, 1616, 109773,
1949});
- data.put(3, new double[]{61187, 89.5, 284599, 3351, 1650, 110929,
1950});
- data.put(4, new double[]{63221, 96.2, 328975, 2099, 3099, 112075,
1951});
- data.put(5, new double[]{63639, 98.1, 346999, 1932, 3594, 113270,
1952});
- data.put(6, new double[]{64989, 99.0, 365385, 1870, 3547, 115094,
1953});
- data.put(7, new double[]{63761, 100.0, 363112, 3578, 3350, 116219,
1954});
- data.put(8, new double[]{66019, 101.2, 397469, 2904, 3048, 117388,
1955});
- data.put(9, new double[]{68169, 108.4, 442769, 2936, 2798, 120445,
1957});
- data.put(10, new double[]{66513, 110.8, 444546, 4681, 2637, 121950,
1958});
- data.put(11, new double[]{68655, 112.6, 482704, 3813, 2552, 123366,
1959});
- data.put(12, new double[]{69564, 114.2, 502601, 3931, 2514, 125368,
1960});
- data.put(13, new double[]{69331, 115.7, 518173, 4806, 2572, 127852,
1961});
- data.put(14, new double[]{70551, 116.9, 554894, 4007, 2827, 130081,
1962});
+ data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608,
1947});
+ data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632,
1948});
+ data.put(2, new double[] {60171, 88.2, 258054, 3682, 1616, 109773,
1949});
+ data.put(3, new double[] {61187, 89.5, 284599, 3351, 1650, 110929,
1950});
+ data.put(4, new double[] {63221, 96.2, 328975, 2099, 3099, 112075,
1951});
+ data.put(5, new double[] {63639, 98.1, 346999, 1932, 3594, 113270,
1952});
+ data.put(6, new double[] {64989, 99.0, 365385, 1870, 3547, 115094,
1953});
+ data.put(7, new double[] {63761, 100.0, 363112, 3578, 3350, 116219,
1954});
+ data.put(8, new double[] {66019, 101.2, 397469, 2904, 3048, 117388,
1955});
+ data.put(9, new double[] {68169, 108.4, 442769, 2936, 2798, 120445,
1957});
+ data.put(10, new double[] {66513, 110.8, 444546, 4681, 2637, 121950,
1958});
+ data.put(11, new double[] {68655, 112.6, 482704, 3813, 2552, 123366,
1959});
+ data.put(12, new double[] {69564, 114.2, 502601, 3931, 2514, 125368,
1960});
+ data.put(13, new double[] {69331, 115.7, 518173, 4806, 2572, 127852,
1961});
+ data.put(14, new double[] {70551, 116.9, 554894, 4007, 2827, 130081,
1962});
KNNRegressionTrainer trainer = new KNNRegressionTrainer();
KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit(
- new LocalDatasetBuilder<>(data, 2),
+ new LocalDatasetBuilder<>(data, parts),
(k, v) -> Arrays.copyOfRange(v, 1, v.length),
(k, v) -> v[0]
).withK(3)
.withDistanceMeasure(new EuclideanDistance())
.withStrategy(KNNStrategy.SIMPLE);
- Vector vector = new DenseLocalOnHeapVector(new double[]{104.6, 419180,
2822, 2857, 118734, 1956});
+ Vector vector = new DenseLocalOnHeapVector(new double[] {104.6,
419180, 2822, 2857, 118734, 1956});
System.out.println(knnMdl.apply(vector));
Assert.assertEquals(67857, knnMdl.apply(vector), 2000);
}
/** */
+ @Test
public void testLonglyWithWeightedStrategy() {
Map<Integer, double[]> data = new HashMap<>();
-
- data.put(0, new double[]{60323, 83.0, 234289, 2356, 1590, 107608,
1947});
- data.put(1, new double[]{61122, 88.5, 259426, 2325, 1456, 108632,
1948});
- data.put(2, new double[]{60171, 88.2, 258054, 3682, 1616, 109773,
1949});
- data.put(3, new double[]{61187, 89.5, 284599, 3351, 1650, 110929,
1950});
- data.put(4, new double[]{63221, 96.2, 328975, 2099, 3099, 112075,
1951});
- data.put(5, new double[]{63639, 98.1, 346999, 1932, 3594, 113270,
1952});
- data.put(6, new double[]{64989, 99.0, 365385, 1870, 3547, 115094,
1953});
- data.put(7, new double[]{63761, 100.0, 363112, 3578, 3350, 116219,
1954});
- data.put(8, new double[]{66019, 101.2, 397469, 2904, 3048, 117388,
1955});
- data.put(9, new double[]{68169, 108.4, 442769, 2936, 2798, 120445,
1957});
- data.put(10, new double[]{66513, 110.8, 444546, 4681, 2637, 121950,
1958});
- data.put(11, new double[]{68655, 112.6, 482704, 3813, 2552, 123366,
1959});
- data.put(12, new double[]{69564, 114.2, 502601, 3931, 2514, 125368,
1960});
- data.put(13, new double[]{69331, 115.7, 518173, 4806, 2572, 127852,
1961});
- data.put(14, new double[]{70551, 116.9, 554894, 4007, 2827, 130081,
1962});
+ data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608,
1947});
+ data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632,
1948});
+ data.put(2, new double[] {60171, 88.2, 258054, 3682, 1616, 109773,
1949});
+ data.put(3, new double[] {61187, 89.5, 284599, 3351, 1650, 110929,
1950});
+ data.put(4, new double[] {63221, 96.2, 328975, 2099, 3099, 112075,
1951});
+ data.put(5, new double[] {63639, 98.1, 346999, 1932, 3594, 113270,
1952});
+ data.put(6, new double[] {64989, 99.0, 365385, 1870, 3547, 115094,
1953});
+ data.put(7, new double[] {63761, 100.0, 363112, 3578, 3350, 116219,
1954});
+ data.put(8, new double[] {66019, 101.2, 397469, 2904, 3048, 117388,
1955});
+ data.put(9, new double[] {68169, 108.4, 442769, 2936, 2798, 120445,
1957});
+ data.put(10, new double[] {66513, 110.8, 444546, 4681, 2637, 121950,
1958});
+ data.put(11, new double[] {68655, 112.6, 482704, 3813, 2552, 123366,
1959});
+ data.put(12, new double[] {69564, 114.2, 502601, 3931, 2514, 125368,
1960});
+ data.put(13, new double[] {69331, 115.7, 518173, 4806, 2572, 127852,
1961});
+ data.put(14, new double[] {70551, 116.9, 554894, 4007, 2827, 130081,
1962});
KNNRegressionTrainer trainer = new KNNRegressionTrainer();
KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit(
- new LocalDatasetBuilder<>(data, 2),
+ new LocalDatasetBuilder<>(data, parts),
(k, v) -> Arrays.copyOfRange(v, 1, v.length),
(k, v) -> v[0]
).withK(3)
.withDistanceMeasure(new EuclideanDistance())
.withStrategy(KNNStrategy.SIMPLE);
- Vector vector = new DenseLocalOnHeapVector(new double[]{104.6, 419180,
2822, 2857, 118734, 1956});
+ Vector vector = new DenseLocalOnHeapVector(new double[] {104.6,
419180, 2822, 2857, 118734, 1956});
System.out.println(knnMdl.apply(vector));
Assert.assertEquals(67857, knnMdl.apply(vector), 2000);
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
index a25b303..dbcdb99 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
@@ -21,64 +21,33 @@ import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Path;
import java.nio.file.Paths;
-import org.apache.ignite.Ignite;
import org.apache.ignite.ml.structures.LabeledDataset;
import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader;
-import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
/**
* Base class for decision trees test.
*/
-public class LabeledDatasetHelper extends GridCommonAbstractTest {
- /** Count of nodes. */
- private static final int NODE_COUNT = 4;
-
+public class LabeledDatasetHelper {
/** Separator. */
private static final String SEPARATOR = "\t";
- /** Grid instance. */
- protected Ignite ignite;
-
- /**
- * Default constructor.
- */
- public LabeledDatasetHelper() {
- super(false);
- }
-
- /**
- * {@inheritDoc}
- */
- @Override protected void beforeTest() throws Exception {
- ignite = grid(NODE_COUNT);
- }
-
- /** {@inheritDoc} */
- @Override protected void beforeTestsStarted() throws Exception {
- for (int i = 1; i <= NODE_COUNT; i++)
- startGrid(i);
- }
-
- /** {@inheritDoc} */
- @Override protected void afterTestsStopped() throws Exception {
- stopAllGrids();
- }
-
/**
* Loads labeled dataset from file with .txt extension.
*
* @param rsrcPath path to dataset.
* @return null if path is incorrect.
*/
- LabeledDataset loadDatasetFromTxt(String rsrcPath, boolean
isFallOnBadData) {
+ public static LabeledDataset loadDatasetFromTxt(String rsrcPath, boolean
isFallOnBadData) {
try {
- Path path =
Paths.get(this.getClass().getClassLoader().getResource(rsrcPath).toURI());
+ Path path =
Paths.get(LabeledDatasetHelper.class.getClassLoader().getResource(rsrcPath).toURI());
try {
return LabeledDatasetLoader.loadFromTxtFile(path, SEPARATOR,
false, isFallOnBadData);
- } catch (IOException e) {
+ }
+ catch (IOException e) {
e.printStackTrace();
}
- } catch (URISyntaxException e) {
+ }
+ catch (URISyntaxException e) {
e.printStackTrace();
return null;
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
index 77d40a6..e986740 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
@@ -21,7 +21,6 @@ import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Path;
import java.nio.file.Paths;
-import org.apache.ignite.internal.util.IgniteUtils;
import org.apache.ignite.ml.math.ExternalizableTest;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.exceptions.CardinalityException;
@@ -32,9 +31,13 @@ import org.apache.ignite.ml.structures.LabeledDataset;
import org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader;
+import org.junit.Test;
+
+import static junit.framework.TestCase.assertEquals;
+import static junit.framework.TestCase.fail;
/** Tests behaviour of KNNClassificationTest. */
-public class LabeledDatasetTest extends LabeledDatasetHelper implements ExternalizableTest<LabeledDataset> {
+public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> {
/** */
private static final String KNN_IRIS_TXT = "datasets/knn/iris.txt";
@@ -51,9 +54,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
private static final String IRIS_MISSED_DATA = "datasets/knn/missed_data.txt";
/** */
+ @Test
public void testFeatureNames() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
double[][] mtx =
new double[][] {
{1.0, 1.0},
@@ -71,9 +73,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
}
/** */
+ @Test
public void testAccessMethods() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
double[][] mtx =
new double[][] {
{1.0, 1.0},
@@ -98,9 +99,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
}
/** */
+ @Test
public void testFailOnYNull() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
double[][] mtx =
new double[][] {
{1.0, 1.0},
@@ -122,9 +122,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
}
/** */
+ @Test
public void testFailOnXNull() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
double[][] mtx =
new double[][] {};
double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
@@ -140,18 +139,17 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
}
/** */
+ @Test
public void testLoadingCorrectTxtFile() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
- LabeledDataset training = loadDatasetFromTxt(KNN_IRIS_TXT, false);
+ LabeledDataset training = LabeledDatasetHelper.loadDatasetFromTxt(KNN_IRIS_TXT, false);
assertEquals(training.rowSize(), 150);
}
/** */
+ @Test
public void testLoadingEmptyFile() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
try {
- loadDatasetFromTxt(EMPTY_TXT, false);
+ LabeledDatasetHelper.loadDatasetFromTxt(EMPTY_TXT, false);
fail("EmptyFileException");
}
catch (EmptyFileException e) {
@@ -161,11 +159,10 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
}
/** */
+ @Test
public void testLoadingFileWithFirstEmptyRow() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
try {
- loadDatasetFromTxt(NO_DATA_TXT, false);
+ LabeledDatasetHelper.loadDatasetFromTxt(NO_DATA_TXT, false);
fail("NoDataException");
}
catch (NoDataException e) {
@@ -175,19 +172,17 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
}
/** */
+ @Test
public void testLoadingFileWithIncorrectData() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
- LabeledDataset training = loadDatasetFromTxt(IRIS_INCORRECT_TXT, false);
+ LabeledDataset training = LabeledDatasetHelper.loadDatasetFromTxt(IRIS_INCORRECT_TXT, false);
assertEquals(149, training.rowSize());
}
/** */
+ @Test
public void testFailOnLoadingFileWithIncorrectData() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
try {
- loadDatasetFromTxt(IRIS_INCORRECT_TXT, true);
+ LabeledDatasetHelper.loadDatasetFromTxt(IRIS_INCORRECT_TXT, true);
fail("FileParsingException");
}
catch (FileParsingException e) {
@@ -198,9 +193,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
}
/** */
+ @Test
public void testLoadingFileWithMissedData() throws URISyntaxException, IOException {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
Path path = Paths.get(this.getClass().getClassLoader().getResource(IRIS_MISSED_DATA).toURI());
LabeledDataset training = LabeledDatasetLoader.loadFromTxtFile(path, ",", false, false);
@@ -209,9 +203,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
}
/** */
+ @Test
public void testSplitting() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
double[][] mtx =
new double[][] {
{1.0, 1.0},
@@ -246,9 +239,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
}
/** */
+ @Test
public void testLabels() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
double[][] mtx =
new double[][] {
{1.0, 1.0},
@@ -267,8 +259,6 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
/** */
@Override public void testExternalization() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
double[][] mtx =
new double[][] {
{1.0, 1.0},