IGNITE-7829: Adapt kNN regression example to the new Partitioned Dataset. This closes #3798
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/8550d61b Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/8550d61b Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/8550d61b Branch: refs/heads/ignite-7708 Commit: 8550d61b6b39625579eb7f69f4d1218b78f7cc5b Parents: 9be3357 Author: zaleslaw <zaleslaw....@gmail.com> Authored: Fri Apr 13 12:49:56 2018 +0300 Committer: YuriBabak <y.ch...@gmail.com> Committed: Fri Apr 13 12:49:56 2018 +0300 ---------------------------------------------------------------------- .../ml/knn/KNNClassificationExample.java | 4 +- .../examples/ml/knn/KNNRegressionExample.java | 310 +++++++++++++++++++ .../java/org/apache/ignite/ml/knn/KNNUtils.java | 10 +- .../classification/KNNClassificationModel.java | 9 +- .../ml/knn/partitions/KNNPartitionContext.java | 28 -- .../ignite/ml/knn/partitions/package-info.java | 22 -- .../ml/knn/regression/KNNRegressionModel.java | 7 +- .../partition/LabelPartitionContext.java | 28 -- .../LabelPartitionDataBuilderOnHeap.java | 1 - .../svm/SVMLinearBinaryClassificationModel.java | 3 + .../SVMLinearBinaryClassificationTrainer.java | 9 +- .../SVMLinearMultiClassClassificationModel.java | 3 + ...VMLinearMultiClassClassificationTrainer.java | 8 +- .../ignite/ml/svm/SVMPartitionContext.java | 28 -- .../org/apache/ignite/ml/knn/BaseKNNTest.java | 89 ------ .../ignite/ml/knn/KNNClassificationTest.java | 110 +++---- .../apache/ignite/ml/knn/KNNRegressionTest.java | 104 +++---- .../ignite/ml/knn/LabeledDatasetHelper.java | 87 ++++++ .../ignite/ml/knn/LabeledDatasetTest.java | 2 +- 19 files changed, 536 insertions(+), 326 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java ---------------------------------------------------------------------- diff --git 
a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java index 39a8431..15375a1 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java @@ -80,7 +80,7 @@ public class KNNClassificationExample { double prediction = knnMdl.apply(new DenseLocalOnHeapVector(inputs)); totalAmount++; - if(groundTruth != prediction) + if (groundTruth != prediction) amountOfErrors++; System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); @@ -89,7 +89,7 @@ public class KNNClassificationExample { System.out.println(">>> ---------------------------------"); System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); - System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); + System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount)); } }); http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java new file mode 100644 index 0000000..76a07cd --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.knn; + +import java.util.Arrays; +import java.util.UUID; +import javax.cache.Cache; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.cache.query.QueryCursor; +import org.apache.ignite.cache.query.ScanQuery; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; +import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer; +import org.apache.ignite.ml.knn.classification.KNNStrategy; +import org.apache.ignite.ml.knn.regression.KNNRegressionModel; +import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer; +import org.apache.ignite.ml.math.distances.ManhattanDistance; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.thread.IgniteThread; + +/** + * Run kNN regression trainer over distributed dataset. + * + * @see KNNClassificationTrainer + */ +public class KNNRegressionExample { + /** Run example. */ + public static void main(String[] args) throws InterruptedException { + System.out.println(); + System.out.println(">>> kNN regression over cached dataset usage example started."); + // Start ignite grid. 
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + KNNRegressionExample.class.getSimpleName(), () -> { + IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); + + KNNRegressionTrainer trainer = new KNNRegressionTrainer(); + + KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit( + new CacheBasedDatasetBuilder<>(ignite, dataCache), + (k, v) -> Arrays.copyOfRange(v, 1, v.length), + (k, v) -> v[0] + ).withK(5) + .withDistanceMeasure(new ManhattanDistance()) + .withStrategy(KNNStrategy.WEIGHTED); + + int totalAmount = 0; + // Calculate mean squared error (MSE) + double mse = 0.0; + // Calculate mean absolute error (MAE) + double mae = 0.0; + + try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, double[]> observation : observations) { + double[] val = observation.getValue(); + double[] inputs = Arrays.copyOfRange(val, 1, val.length); + double groundTruth = val[0]; + + double prediction = knnMdl.apply(new DenseLocalOnHeapVector(inputs)); + + mse += Math.pow(prediction - groundTruth, 2.0); + mae += Math.abs(prediction - groundTruth); + + totalAmount++; + } + + mse = mse / totalAmount; + System.out.println("\n>>> Mean squared error (MSE) " + mse); + + mae = mae / totalAmount; + System.out.println("\n>>> Mean absolute error (MAE) " + mae); + } + }); + + igniteThread.start(); + igniteThread.join(); + } + } + + /** + * Fills cache with data and returns it. + * + * @param ignite Ignite instance. + * @return Filled Ignite Cache. 
+ */ + private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) { + CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>(); + cacheConfiguration.setName("TEST_" + UUID.randomUUID()); + cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10)); + + IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration); + + for (int i = 0; i < data.length; i++) + cache.put(i, data[i]); + + return cache; + } + + /** The Iris dataset. */ + private static final double[][] data = { + {199, 125, 256, 6000, 256, 16, 128}, + {253, 29, 8000, 32000, 32, 8, 32}, + {132, 29, 8000, 16000, 32, 8, 16}, + {290, 26, 8000, 32000, 64, 8, 32}, + {381, 23, 16000, 32000, 64, 16, 32}, + {749, 23, 16000, 64000, 64, 16, 32}, + {1238, 23, 32000, 64000, 128, 32, 64}, + {23, 400, 1000, 3000, 0, 1, 2}, + {24, 400, 512, 3500, 4, 1, 6}, + {70, 60, 2000, 8000, 65, 1, 8}, + {117, 50, 4000, 16000, 65, 1, 8}, + {15, 350, 64, 64, 0, 1, 4}, + {64, 200, 512, 16000, 0, 4, 32}, + {23, 167, 524, 2000, 8, 4, 15}, + {29, 143, 512, 5000, 0, 7, 32}, + {22, 143, 1000, 2000, 0, 5, 16}, + {124, 110, 5000, 5000, 142, 8, 64}, + {35, 143, 1500, 6300, 0, 5, 32}, + {39, 143, 3100, 6200, 0, 5, 20}, + {40, 143, 2300, 6200, 0, 6, 64}, + {45, 110, 3100, 6200, 0, 6, 64}, + {28, 320, 128, 6000, 0, 1, 12}, + {21, 320, 512, 2000, 4, 1, 3}, + {28, 320, 256, 6000, 0, 1, 6}, + {22, 320, 256, 3000, 4, 1, 3}, + {28, 320, 512, 5000, 4, 1, 5}, + {27, 320, 256, 5000, 4, 1, 6}, + {102, 25, 1310, 2620, 131, 12, 24}, + {74, 50, 2620, 10480, 30, 12, 24}, + {138, 56, 5240, 20970, 30, 12, 24}, + {136, 64, 5240, 20970, 30, 12, 24}, + {23, 50, 500, 2000, 8, 1, 4}, + {29, 50, 1000, 4000, 8, 1, 5}, + {44, 50, 2000, 8000, 8, 1, 5}, + {30, 50, 1000, 4000, 8, 3, 5}, + {41, 50, 1000, 8000, 8, 3, 5}, + {74, 50, 2000, 16000, 8, 3, 5}, + {54, 133, 1000, 12000, 9, 3, 12}, + {41, 133, 1000, 8000, 9, 3, 12}, + {18, 810, 512, 512, 8, 1, 1}, + {28, 810, 1000, 5000, 0, 1, 1}, + {36, 
320, 512, 8000, 4, 1, 5}, + {38, 200, 512, 8000, 8, 1, 8}, + {34, 700, 384, 8000, 0, 1, 1}, + {19, 700, 256, 2000, 0, 1, 1}, + {72, 140, 1000, 16000, 16, 1, 3}, + {36, 200, 1000, 8000, 0, 1, 2}, + {30, 110, 1000, 4000, 16, 1, 2}, + {56, 110, 1000, 12000, 16, 1, 2}, + {42, 220, 1000, 8000, 16, 1, 2}, + {34, 800, 256, 8000, 0, 1, 4}, + {19, 125, 512, 1000, 0, 8, 20}, + {75, 75, 2000, 8000, 64, 1, 38}, + {113, 75, 2000, 16000, 64, 1, 38}, + {157, 75, 2000, 16000, 128, 1, 38}, + {18, 90, 256, 1000, 0, 3, 10}, + {20, 105, 256, 2000, 0, 3, 10}, + {28, 105, 1000, 4000, 0, 3, 24}, + {33, 105, 2000, 4000, 8, 3, 19}, + {47, 75, 2000, 8000, 8, 3, 24}, + {54, 75, 3000, 8000, 8, 3, 48}, + {20, 175, 256, 2000, 0, 3, 24}, + {23, 300, 768, 3000, 0, 6, 24}, + {25, 300, 768, 3000, 6, 6, 24}, + {52, 300, 768, 12000, 6, 6, 24}, + {27, 300, 768, 4500, 0, 1, 24}, + {50, 300, 384, 12000, 6, 1, 24}, + {18, 300, 192, 768, 6, 6, 24}, + {53, 180, 768, 12000, 6, 1, 31}, + {23, 330, 1000, 3000, 0, 2, 4}, + {30, 300, 1000, 4000, 8, 3, 64}, + {73, 300, 1000, 16000, 8, 2, 112}, + {20, 330, 1000, 2000, 0, 1, 2}, + {25, 330, 1000, 4000, 0, 3, 6}, + {28, 140, 2000, 4000, 0, 3, 6}, + {29, 140, 2000, 4000, 0, 4, 8}, + {32, 140, 2000, 4000, 8, 1, 20}, + {175, 140, 2000, 32000, 32, 1, 20}, + {57, 140, 2000, 8000, 32, 1, 54}, + {181, 140, 2000, 32000, 32, 1, 54}, + {32, 140, 2000, 4000, 8, 1, 20}, + {82, 57, 4000, 16000, 1, 6, 12}, + {171, 57, 4000, 24000, 64, 12, 16}, + {361, 26, 16000, 32000, 64, 16, 24}, + {350, 26, 16000, 32000, 64, 8, 24}, + {220, 26, 8000, 32000, 0, 8, 24}, + {113, 26, 8000, 16000, 0, 8, 16}, + {15, 480, 96, 512, 0, 1, 1}, + {21, 203, 1000, 2000, 0, 1, 5}, + {35, 115, 512, 6000, 16, 1, 6}, + {18, 1100, 512, 1500, 0, 1, 1}, + {20, 1100, 768, 2000, 0, 1, 1}, + {20, 600, 768, 2000, 0, 1, 1}, + {28, 400, 2000, 4000, 0, 1, 1}, + {45, 400, 4000, 8000, 0, 1, 1}, + {18, 900, 1000, 1000, 0, 1, 2}, + {17, 900, 512, 1000, 0, 1, 2}, + {26, 900, 1000, 4000, 4, 1, 2}, + {28, 900, 1000, 4000, 8, 
1, 2}, + {28, 900, 2000, 4000, 0, 3, 6}, + {31, 225, 2000, 4000, 8, 3, 6}, + {42, 180, 2000, 8000, 8, 1, 6}, + {76, 185, 2000, 16000, 16, 1, 6}, + {76, 180, 2000, 16000, 16, 1, 6}, + {26, 225, 1000, 4000, 2, 3, 6}, + {59, 25, 2000, 12000, 8, 1, 4}, + {65, 25, 2000, 12000, 16, 3, 5}, + {101, 17, 4000, 16000, 8, 6, 12}, + {116, 17, 4000, 16000, 32, 6, 12}, + {18, 1500, 768, 1000, 0, 0, 0}, + {20, 1500, 768, 2000, 0, 0, 0}, + {20, 800, 768, 2000, 0, 0, 0}, + {30, 50, 2000, 4000, 0, 3, 6}, + {44, 50, 2000, 8000, 8, 3, 6}, + {82, 50, 2000, 16000, 24, 1, 6}, + {128, 50, 8000, 16000, 48, 1, 10}, + {37, 100, 1000, 8000, 0, 2, 6}, + {46, 100, 1000, 8000, 24, 2, 6}, + {46, 100, 1000, 8000, 24, 3, 6}, + {80, 50, 2000, 16000, 12, 3, 16}, + {88, 50, 2000, 16000, 24, 6, 16}, + {33, 150, 512, 4000, 0, 8, 128}, + {46, 115, 2000, 8000, 16, 1, 3}, + {29, 115, 2000, 4000, 2, 1, 5}, + {53, 92, 2000, 8000, 32, 1, 6}, + {41, 92, 2000, 8000, 4, 1, 6}, + {86, 75, 4000, 16000, 16, 1, 6}, + {95, 60, 4000, 16000, 32, 1, 6}, + {107, 60, 2000, 16000, 64, 5, 8}, + {117, 60, 4000, 16000, 64, 5, 8}, + {119, 50, 4000, 16000, 64, 5, 10}, + {120, 72, 4000, 16000, 64, 8, 16}, + {48, 72, 2000, 8000, 16, 6, 8}, + {126, 40, 8000, 16000, 32, 8, 16}, + {266, 40, 8000, 32000, 64, 8, 24}, + {270, 35, 8000, 32000, 64, 8, 24}, + {426, 38, 16000, 32000, 128, 16, 32}, + {151, 48, 4000, 24000, 32, 8, 24}, + {267, 38, 8000, 32000, 64, 8, 24}, + {603, 30, 16000, 32000, 256, 16, 24}, + {19, 112, 1000, 1000, 0, 1, 4}, + {21, 84, 1000, 2000, 0, 1, 6}, + {26, 56, 1000, 4000, 0, 1, 6}, + {35, 56, 2000, 6000, 0, 1, 8}, + {41, 56, 2000, 8000, 0, 1, 8}, + {47, 56, 4000, 8000, 0, 1, 8}, + {62, 56, 4000, 12000, 0, 1, 8}, + {78, 56, 4000, 16000, 0, 1, 8}, + {80, 38, 4000, 8000, 32, 16, 32}, + {142, 38, 8000, 16000, 64, 4, 8}, + {281, 38, 8000, 24000, 160, 4, 8}, + {190, 38, 4000, 16000, 128, 16, 32}, + {21, 200, 1000, 2000, 0, 1, 2}, + {25, 200, 1000, 4000, 0, 1, 4}, + {67, 200, 2000, 8000, 64, 1, 5}, + {24, 250, 512, 4000, 
0, 1, 7}, + {24, 250, 512, 4000, 0, 4, 7}, + {64, 250, 1000, 16000, 1, 1, 8}, + {25, 160, 512, 4000, 2, 1, 5}, + {20, 160, 512, 2000, 2, 3, 8}, + {29, 160, 1000, 4000, 8, 1, 14}, + {43, 160, 1000, 8000, 16, 1, 14}, + {53, 160, 2000, 8000, 32, 1, 13}, + {19, 240, 512, 1000, 8, 1, 3}, + {22, 240, 512, 2000, 8, 1, 5}, + {31, 105, 2000, 4000, 8, 3, 8}, + {41, 105, 2000, 6000, 16, 6, 16}, + {47, 105, 2000, 8000, 16, 4, 14}, + {99, 52, 4000, 16000, 32, 4, 12}, + {67, 70, 4000, 12000, 8, 6, 8}, + {81, 59, 4000, 12000, 32, 6, 12}, + {149, 59, 8000, 16000, 64, 12, 24}, + {183, 26, 8000, 24000, 32, 8, 16}, + {275, 26, 8000, 32000, 64, 12, 16}, + {382, 26, 8000, 32000, 128, 24, 32}, + {56, 116, 2000, 8000, 32, 5, 28}, + {182, 50, 2000, 32000, 24, 6, 26}, + {227, 50, 2000, 32000, 48, 26, 52}, + {341, 50, 2000, 32000, 112, 52, 104}, + {360, 50, 4000, 32000, 112, 52, 104}, + {919, 30, 8000, 64000, 96, 12, 176}, + {978, 30, 8000, 64000, 128, 12, 176}, + {24, 180, 262, 4000, 0, 1, 3}, + {37, 124, 1000, 8000, 0, 1, 8}, + {50, 98, 1000, 8000, 32, 2, 8}, + {41, 125, 2000, 8000, 0, 2, 14}, + {47, 480, 512, 8000, 32, 0, 0}, + {25, 480, 1000, 4000, 0, 0, 0} + }; +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java index 88fa70f..716eb52 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java @@ -20,7 +20,7 @@ package org.apache.ignite.ml.knn; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.PartitionDataBuilder; -import org.apache.ignite.ml.knn.partitions.KNNPartitionContext; +import 
org.apache.ignite.ml.dataset.primitive.context.EmptyContext; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.structures.LabeledDataset; import org.apache.ignite.ml.structures.LabeledVector; @@ -39,18 +39,18 @@ public class KNNUtils { * @param lbExtractor Label extractor. * @return Dataset. */ - @Nullable public static <K, V> Dataset<KNNPartitionContext, LabeledDataset<Double, LabeledVector>> buildDataset(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { - PartitionDataBuilder<K, V, KNNPartitionContext, LabeledDataset<Double, LabeledVector>> partDataBuilder + @Nullable public static <K, V> Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> buildDataset(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + PartitionDataBuilder<K, V, EmptyContext, LabeledDataset<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>( featureExtractor, lbExtractor ); - Dataset<KNNPartitionContext, LabeledDataset<Double, LabeledVector>> dataset = null; + Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset = null; if (datasetBuilder != null) { dataset = datasetBuilder.build( - (upstream, upstreamSize) -> new KNNPartitionContext(), + (upstream, upstreamSize) -> new EmptyContext(), partDataBuilder ); } http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java index 373f822..693b81d 100644 --- 
a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java @@ -32,7 +32,7 @@ import org.apache.ignite.ml.Exportable; import org.apache.ignite.ml.Exporter; import org.apache.ignite.ml.Model; import org.apache.ignite.ml.dataset.Dataset; -import org.apache.ignite.ml.knn.partitions.KNNPartitionContext; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.distances.DistanceMeasure; import org.apache.ignite.ml.math.distances.EuclideanDistance; @@ -44,6 +44,9 @@ import org.jetbrains.annotations.NotNull; * kNN algorithm model to solve multi-class classification task. */ public class KNNClassificationModel<K, V> implements Model<Vector, Double>, Exportable<KNNModelFormat> { + /** */ + private static final long serialVersionUID = -127386523291350345L; + /** Amount of nearest neighbors. */ protected int k = 5; @@ -54,13 +57,13 @@ public class KNNClassificationModel<K, V> implements Model<Vector, Double>, Expo protected KNNStrategy stgy = KNNStrategy.SIMPLE; /** Dataset. */ - private Dataset<KNNPartitionContext, LabeledDataset<Double, LabeledVector>> dataset; + private Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset; /** * Builds the model via prepared dataset. * @param dataset Specially prepared object to run algorithm over it. 
*/ - public KNNClassificationModel(Dataset<KNNPartitionContext, LabeledDataset<Double, LabeledVector>> dataset) { + public KNNClassificationModel(Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset) { this.dataset = dataset; } http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/KNNPartitionContext.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/KNNPartitionContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/KNNPartitionContext.java deleted file mode 100644 index 0081612..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/KNNPartitionContext.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.knn.partitions; - -import java.io.Serializable; - -/** - * Partition context of the kNN classification algorithm. 
- */ -public class KNNPartitionContext implements Serializable { - /** */ - private static final long serialVersionUID = -7212307112344430126L; -} http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/package-info.java deleted file mode 100644 index 951a849..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * <!-- Package description. --> - * Contains helper classes for kNN classification algorithms. 
- */ -package org.apache.ignite.ml.knn.partitions; http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionModel.java index cabc143..f5def43 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionModel.java @@ -17,8 +17,8 @@ package org.apache.ignite.ml.knn.regression; import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; import org.apache.ignite.ml.knn.classification.KNNClassificationModel; -import org.apache.ignite.ml.knn.partitions.KNNPartitionContext; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException; import org.apache.ignite.ml.structures.LabeledDataset; @@ -38,11 +38,14 @@ import java.util.List; * </ul> */ public class KNNRegressionModel<K,V> extends KNNClassificationModel<K,V> { + /** */ + private static final long serialVersionUID = -721836321291120543L; + /** * Builds the model via prepared dataset. * @param dataset Specially prepared object to run algorithm over it. 
*/ - public KNNRegressionModel(Dataset<KNNPartitionContext, LabeledDataset<Double, LabeledVector>> dataset) { + public KNNRegressionModel(Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset) { super(dataset); } http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionContext.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionContext.java deleted file mode 100644 index 1069ff8..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionContext.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.structures.partition; - -import java.io.Serializable; - -/** - * Base partition context. 
- */ -public class LabelPartitionContext implements Serializable { - /** */ - private static final long serialVersionUID = -7412302212344430126L; -} http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java index 14c053e..4fba028 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java @@ -22,7 +22,6 @@ import java.util.Iterator; import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.dataset.UpstreamEntry; import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.structures.LabeledDataset; /** * Partition data builder that builds {@link LabelPartitionDataOnHeap}. http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java index dace8c6..f806fb8 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java @@ -28,6 +28,9 @@ import org.apache.ignite.ml.math.Vector; * Base class for SVM linear classification model. 
*/ public class SVMLinearBinaryClassificationModel implements Model<Vector, Double>, Exportable<SVMLinearBinaryClassificationModel>, Serializable { + /** */ + private static final long serialVersionUID = -996984622291440226L; + /** Output label format. -1 and +1 for false value and raw distances from the separating hyperplane otherwise. */ private boolean isKeepingRawLabels = false; http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java index 7f11e20..d56848c 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java @@ -18,6 +18,7 @@ package org.apache.ignite.ml.svm; import java.util.concurrent.ThreadLocalRandom; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap; import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; import org.apache.ignite.ml.dataset.Dataset; @@ -59,15 +60,15 @@ public class SVMLinearBinaryClassificationTrainer implements SingleLabelDatasetT assert datasetBuilder != null; - PartitionDataBuilder<K, V, SVMPartitionContext, LabeledDataset<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>( + PartitionDataBuilder<K, V, EmptyContext, LabeledDataset<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>( featureExtractor, lbExtractor ); Vector weights; - try(Dataset<SVMPartitionContext, LabeledDataset<Double, LabeledVector>> dataset = datasetBuilder.build( - 
(upstream, upstreamSize) -> new SVMPartitionContext(), + try(Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset = datasetBuilder.build( + (upstream, upstreamSize) -> new EmptyContext(), partDataBuilder )) { final int cols = dataset.compute(data -> data.colSize(), (a, b) -> a == null ? b : a); @@ -90,7 +91,7 @@ public class SVMLinearBinaryClassificationTrainer implements SingleLabelDatasetT } /** */ - private Vector calculateUpdates(Vector weights, Dataset<SVMPartitionContext, LabeledDataset<Double, LabeledVector>> dataset) { + private Vector calculateUpdates(Vector weights, Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset) { return dataset.compute(data -> { Vector copiedWeights = weights.copy(); Vector deltaWeights = initializeWeightsWithZeros(weights.size()); http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java index 5879ef0..bbec791 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java @@ -29,6 +29,9 @@ import org.apache.ignite.ml.math.Vector; /** Base class for multi-classification model for set of SVM classifiers. */ public class SVMLinearMultiClassClassificationModel implements Model<Vector, Double>, Exportable<SVMLinearMultiClassClassificationModel>, Serializable { + /** */ + private static final long serialVersionUID = -667986511191350227L; + /** List of models associated with each class. 
*/ private Map<Double, SVMLinearBinaryClassificationModel> models; http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java index 88c342d..4e081c6 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java @@ -24,12 +24,12 @@ import java.util.List; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.structures.partition.LabelPartitionContext; import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap; import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap; @@ -89,12 +89,12 @@ public class SVMLinearMultiClassClassificationTrainer private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Double> lbExtractor) { assert datasetBuilder != null; - PartitionDataBuilder<K, V, LabelPartitionContext, LabelPartitionDataOnHeap> partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor); + PartitionDataBuilder<K, V, EmptyContext, LabelPartitionDataOnHeap> partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor); List<Double> res = new 
ArrayList<>(); - try (Dataset<LabelPartitionContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build( - (upstream, upstreamSize) -> new LabelPartitionContext(), + try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build( + (upstream, upstreamSize) -> new EmptyContext(), partDataBuilder )) { final Set<Double> clsLabels = dataset.compute(data -> { http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMPartitionContext.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMPartitionContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMPartitionContext.java deleted file mode 100644 index 0aee0fb..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMPartitionContext.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.svm; - -import java.io.Serializable; - -/** - * Partition context of the SVM classification algorithm. 
- */ -public class SVMPartitionContext implements Serializable { - /** */ - private static final long serialVersionUID = -7212307112344430126L; -} http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/test/java/org/apache/ignite/ml/knn/BaseKNNTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/BaseKNNTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/BaseKNNTest.java deleted file mode 100644 index aeac2cf..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/BaseKNNTest.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.knn; - -import java.io.IOException; -import java.net.URISyntaxException; -import java.nio.file.Path; -import java.nio.file.Paths; -import org.apache.ignite.Ignite; -import org.apache.ignite.ml.structures.LabeledDataset; -import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader; -import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; - -/** - * Base class for decision trees test. - */ -public class BaseKNNTest extends GridCommonAbstractTest { - /** Count of nodes. 
*/ - private static final int NODE_COUNT = 4; - - /** Separator. */ - private static final String SEPARATOR = "\t"; - - /** Grid instance. */ - protected Ignite ignite; - - /** - * Default constructor. - */ - public BaseKNNTest() { - super(false); - } - - /** - * {@inheritDoc} - */ - @Override protected void beforeTest() throws Exception { - ignite = grid(NODE_COUNT); - } - - /** {@inheritDoc} */ - @Override protected void beforeTestsStarted() throws Exception { - for (int i = 1; i <= NODE_COUNT; i++) - startGrid(i); - } - - /** {@inheritDoc} */ - @Override protected void afterTestsStopped() throws Exception { - stopAllGrids(); - } - - /** - * Loads labeled dataset from file with .txt extension. - * - * @param rsrcPath path to dataset. - * @return null if path is incorrect. - */ - LabeledDataset loadDatasetFromTxt(String rsrcPath, boolean isFallOnBadData) { - try { - Path path = Paths.get(this.getClass().getClassLoader().getResource(rsrcPath).toURI()); - try { - return LabeledDatasetLoader.loadFromTxtFile(path, SEPARATOR, false, isFallOnBadData); - } - catch (IOException e) { - e.printStackTrace(); - } - } - catch (URISyntaxException e) { - e.printStackTrace(); - return null; - } - return null; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java index b27fcba..0877fc0 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java @@ -17,31 +17,35 @@ package org.apache.ignite.ml.knn; -import org.apache.ignite.internal.util.IgniteUtils; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import 
org.junit.Assert; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.knn.classification.KNNClassificationModel; import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer; import org.apache.ignite.ml.knn.classification.KNNStrategy; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; +import org.junit.Test; /** Tests behaviour of KNNClassificationTest. */ -public class KNNClassificationTest extends BaseKNNTest { +public class KNNClassificationTest { + /** Precision in test checks. */ + private static final double PRECISION = 1e-2; + /** */ - public void testBinaryClassificationTest() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + @Test + public void binaryClassificationTest() { Map<Integer, double[]> data = new HashMap<>(); - data.put(0, new double[] {1.0, 1.0, 1.0}); - data.put(1, new double[] {1.0, 2.0, 1.0}); - data.put(2, new double[] {2.0, 1.0, 1.0}); - data.put(3, new double[] {-1.0, -1.0, 2.0}); - data.put(4, new double[] {-1.0, -2.0, 2.0}); - data.put(5, new double[] {-2.0, -1.0, 2.0}); + data.put(0, new double[]{1.0, 1.0, 1.0}); + data.put(1, new double[]{1.0, 2.0, 1.0}); + data.put(2, new double[]{2.0, 1.0, 1.0}); + data.put(3, new double[]{-1.0, -1.0, 2.0}); + data.put(4, new double[]{-1.0, -2.0, 2.0}); + data.put(5, new double[]{-2.0, -1.0, 2.0}); KNNClassificationTrainer trainer = new KNNClassificationTrainer(); @@ -54,23 +58,23 @@ public class KNNClassificationTest extends BaseKNNTest { .withDistanceMeasure(new EuclideanDistance()) .withStrategy(KNNStrategy.SIMPLE); - Vector firstVector = new DenseLocalOnHeapVector(new double[] {2.0, 2.0}); - assertEquals(knnMdl.apply(firstVector), 1.0); - Vector secondVector = new DenseLocalOnHeapVector(new double[] {-2.0, 
-2.0}); - assertEquals(knnMdl.apply(secondVector), 2.0); + Vector firstVector = new DenseLocalOnHeapVector(new double[]{2.0, 2.0}); + Assert.assertEquals(knnMdl.apply(firstVector), 1.0, PRECISION); + Vector secondVector = new DenseLocalOnHeapVector(new double[]{-2.0, -2.0}); + Assert.assertEquals(knnMdl.apply(secondVector), 2.0, PRECISION); } /** */ - public void testBinaryClassificationWithSmallestKTest() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - + @Test + public void binaryClassificationWithSmallestKTest() { Map<Integer, double[]> data = new HashMap<>(); - data.put(0, new double[] {1.0, 1.0, 1.0}); - data.put(1, new double[] {1.0, 2.0, 1.0}); - data.put(2, new double[] {2.0, 1.0, 1.0}); - data.put(3, new double[] {-1.0, -1.0, 2.0}); - data.put(4, new double[] {-1.0, -2.0, 2.0}); - data.put(5, new double[] {-2.0, -1.0, 2.0}); + + data.put(0, new double[]{1.0, 1.0, 1.0}); + data.put(1, new double[]{1.0, 2.0, 1.0}); + data.put(2, new double[]{2.0, 1.0, 1.0}); + data.put(3, new double[]{-1.0, -1.0, 2.0}); + data.put(4, new double[]{-1.0, -2.0, 2.0}); + data.put(5, new double[]{-2.0, -1.0, 2.0}); KNNClassificationTrainer trainer = new KNNClassificationTrainer(); @@ -83,23 +87,23 @@ public class KNNClassificationTest extends BaseKNNTest { .withDistanceMeasure(new EuclideanDistance()) .withStrategy(KNNStrategy.SIMPLE); - Vector firstVector = new DenseLocalOnHeapVector(new double[] {2.0, 2.0}); - assertEquals(knnMdl.apply(firstVector), 1.0); - Vector secondVector = new DenseLocalOnHeapVector(new double[] {-2.0, -2.0}); - assertEquals(knnMdl.apply(secondVector), 2.0); + Vector firstVector = new DenseLocalOnHeapVector(new double[]{2.0, 2.0}); + Assert.assertEquals(knnMdl.apply(firstVector), 1.0, PRECISION); + Vector secondVector = new DenseLocalOnHeapVector(new double[]{-2.0, -2.0}); + Assert.assertEquals(knnMdl.apply(secondVector), 2.0, PRECISION); } /** */ - public void testBinaryClassificationFarPointsWithSimpleStrategy() { 
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - + @Test + public void binaryClassificationFarPointsWithSimpleStrategy() { Map<Integer, double[]> data = new HashMap<>(); - data.put(0, new double[] {10.0, 10.0, 1.0}); - data.put(1, new double[] {10.0, 20.0, 1.0}); - data.put(2, new double[] {-1, -1, 1.0}); - data.put(3, new double[] {-2, -2, 2.0}); - data.put(4, new double[] {-1.0, -2.0, 2.0}); - data.put(5, new double[] {-2.0, -1.0, 2.0}); + + data.put(0, new double[]{10.0, 10.0, 1.0}); + data.put(1, new double[]{10.0, 20.0, 1.0}); + data.put(2, new double[]{-1, -1, 1.0}); + data.put(3, new double[]{-2, -2, 2.0}); + data.put(4, new double[]{-1.0, -2.0, 2.0}); + data.put(5, new double[]{-2.0, -1.0, 2.0}); KNNClassificationTrainer trainer = new KNNClassificationTrainer(); @@ -112,21 +116,21 @@ public class KNNClassificationTest extends BaseKNNTest { .withDistanceMeasure(new EuclideanDistance()) .withStrategy(KNNStrategy.SIMPLE); - Vector vector = new DenseLocalOnHeapVector(new double[] {-1.01, -1.01}); - assertEquals(knnMdl.apply(vector), 2.0); + Vector vector = new DenseLocalOnHeapVector(new double[]{-1.01, -1.01}); + Assert.assertEquals(knnMdl.apply(vector), 2.0, PRECISION); } /** */ - public void testBinaryClassificationFarPointsWithWeightedStrategy() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - + @Test + public void binaryClassificationFarPointsWithWeightedStrategy() { Map<Integer, double[]> data = new HashMap<>(); - data.put(0, new double[] {10.0, 10.0, 1.0}); - data.put(1, new double[] {10.0, 20.0, 1.0}); - data.put(2, new double[] {-1, -1, 1.0}); - data.put(3, new double[] {-2, -2, 2.0}); - data.put(4, new double[] {-1.0, -2.0, 2.0}); - data.put(5, new double[] {-2.0, -1.0, 2.0}); + + data.put(0, new double[]{10.0, 10.0, 1.0}); + data.put(1, new double[]{10.0, 20.0, 1.0}); + data.put(2, new double[]{-1, -1, 1.0}); + data.put(3, new double[]{-2, -2, 2.0}); + data.put(4, new 
double[]{-1.0, -2.0, 2.0}); + data.put(5, new double[]{-2.0, -1.0, 2.0}); KNNClassificationTrainer trainer = new KNNClassificationTrainer(); @@ -139,7 +143,7 @@ public class KNNClassificationTest extends BaseKNNTest { .withDistanceMeasure(new EuclideanDistance()) .withStrategy(KNNStrategy.WEIGHTED); - Vector vector = new DenseLocalOnHeapVector(new double[] {-1.01, -1.01}); - assertEquals(knnMdl.apply(vector), 1.0); + Vector vector = new DenseLocalOnHeapVector(new double[]{-1.01, -1.01}); + Assert.assertEquals(knnMdl.apply(vector), 1.0, PRECISION); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java index 66dbca9..ce9cae5 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java @@ -17,7 +17,6 @@ package org.apache.ignite.ml.knn; -import org.apache.ignite.internal.util.IgniteUtils; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.knn.classification.KNNStrategy; import org.apache.ignite.ml.knn.regression.KNNRegressionModel; @@ -30,28 +29,23 @@ import org.junit.Assert; import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import org.junit.Test; /** * Tests for {@link KNNRegressionTrainer}. 
*/ -public class KNNRegressionTest extends BaseKNNTest { +public class KNNRegressionTest { /** */ - private double[] y; - - /** */ - private double[][] x; - - /** */ - public void testSimpleRegressionWithOneNeighbour() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - + @Test + public void simpleRegressionWithOneNeighbour() { Map<Integer, double[]> data = new HashMap<>(); - data.put(0, new double[] {11.0, 0, 0, 0, 0, 0}); - data.put(1, new double[] {12.0, 2.0, 0, 0, 0, 0}); - data.put(2, new double[] {13.0, 0, 3.0, 0, 0, 0}); - data.put(3, new double[] {14.0, 0, 0, 4.0, 0, 0}); - data.put(4, new double[] {15.0, 0, 0, 0, 5.0, 0}); - data.put(5, new double[] {16.0, 0, 0, 0, 0, 6.0}); + + data.put(0, new double[]{11.0, 0, 0, 0, 0, 0}); + data.put(1, new double[]{12.0, 2.0, 0, 0, 0, 0}); + data.put(2, new double[]{13.0, 0, 3.0, 0, 0, 0}); + data.put(3, new double[]{14.0, 0, 0, 4.0, 0, 0}); + data.put(4, new double[]{15.0, 0, 0, 0, 5.0, 0}); + data.put(5, new double[]{16.0, 0, 0, 0, 0, 6.0}); KNNRegressionTrainer trainer = new KNNRegressionTrainer(); @@ -63,32 +57,31 @@ public class KNNRegressionTest extends BaseKNNTest { .withDistanceMeasure(new EuclideanDistance()) .withStrategy(KNNStrategy.SIMPLE); - Vector vector = new DenseLocalOnHeapVector(new double[] {0, 0, 0, 5.0, 0.0}); + Vector vector = new DenseLocalOnHeapVector(new double[]{0, 0, 0, 5.0, 0.0}); System.out.println(knnMdl.apply(vector)); Assert.assertEquals(15, knnMdl.apply(vector), 1E-12); } /** */ - public void testLongly() { - - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - + @Test + public void longly() { Map<Integer, double[]> data = new HashMap<>(); - data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608, 1947}); - data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632, 1948}); - data.put(2, new double[] {60171, 88.2, 258054, 3682, 1616, 109773, 1949}); - data.put(3, new double[] {61187, 89.5, 284599, 
3351, 1650, 110929, 1950}); - data.put(4, new double[] {63221, 96.2, 328975, 2099, 3099, 112075, 1951}); - data.put(5, new double[] {63639, 98.1, 346999, 1932, 3594, 113270, 1952}); - data.put(6, new double[] {64989, 99.0, 365385, 1870, 3547, 115094, 1953}); - data.put(7, new double[] {63761, 100.0, 363112, 3578, 3350, 116219, 1954}); - data.put(8, new double[] {66019, 101.2, 397469, 2904, 3048, 117388, 1955}); - data.put(9, new double[] {68169, 108.4, 442769, 2936, 2798, 120445, 1957}); - data.put(10, new double[] {66513, 110.8, 444546, 4681, 2637, 121950, 1958}); - data.put(11, new double[] {68655, 112.6, 482704, 3813, 2552, 123366, 1959}); - data.put(12, new double[] {69564, 114.2, 502601, 3931, 2514, 125368, 1960}); - data.put(13, new double[] {69331, 115.7, 518173, 4806, 2572, 127852, 1961}); - data.put(14, new double[] {70551, 116.9, 554894, 4007, 2827, 130081, 1962}); + + data.put(0, new double[]{60323, 83.0, 234289, 2356, 1590, 107608, 1947}); + data.put(1, new double[]{61122, 88.5, 259426, 2325, 1456, 108632, 1948}); + data.put(2, new double[]{60171, 88.2, 258054, 3682, 1616, 109773, 1949}); + data.put(3, new double[]{61187, 89.5, 284599, 3351, 1650, 110929, 1950}); + data.put(4, new double[]{63221, 96.2, 328975, 2099, 3099, 112075, 1951}); + data.put(5, new double[]{63639, 98.1, 346999, 1932, 3594, 113270, 1952}); + data.put(6, new double[]{64989, 99.0, 365385, 1870, 3547, 115094, 1953}); + data.put(7, new double[]{63761, 100.0, 363112, 3578, 3350, 116219, 1954}); + data.put(8, new double[]{66019, 101.2, 397469, 2904, 3048, 117388, 1955}); + data.put(9, new double[]{68169, 108.4, 442769, 2936, 2798, 120445, 1957}); + data.put(10, new double[]{66513, 110.8, 444546, 4681, 2637, 121950, 1958}); + data.put(11, new double[]{68655, 112.6, 482704, 3813, 2552, 123366, 1959}); + data.put(12, new double[]{69564, 114.2, 502601, 3931, 2514, 125368, 1960}); + data.put(13, new double[]{69331, 115.7, 518173, 4806, 2572, 127852, 1961}); + data.put(14, new double[]{70551, 
116.9, 554894, 4007, 2827, 130081, 1962}); KNNRegressionTrainer trainer = new KNNRegressionTrainer(); @@ -100,31 +93,30 @@ public class KNNRegressionTest extends BaseKNNTest { .withDistanceMeasure(new EuclideanDistance()) .withStrategy(KNNStrategy.SIMPLE); - Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956}); + Vector vector = new DenseLocalOnHeapVector(new double[]{104.6, 419180, 2822, 2857, 118734, 1956}); System.out.println(knnMdl.apply(vector)); Assert.assertEquals(67857, knnMdl.apply(vector), 2000); } /** */ public void testLonglyWithWeightedStrategy() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - Map<Integer, double[]> data = new HashMap<>(); - data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608, 1947}); - data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632, 1948}); - data.put(2, new double[] {60171, 88.2, 258054, 3682, 1616, 109773, 1949}); - data.put(3, new double[] {61187, 89.5, 284599, 3351, 1650, 110929, 1950}); - data.put(4, new double[] {63221, 96.2, 328975, 2099, 3099, 112075, 1951}); - data.put(5, new double[] {63639, 98.1, 346999, 1932, 3594, 113270, 1952}); - data.put(6, new double[] {64989, 99.0, 365385, 1870, 3547, 115094, 1953}); - data.put(7, new double[] {63761, 100.0, 363112, 3578, 3350, 116219, 1954}); - data.put(8, new double[] {66019, 101.2, 397469, 2904, 3048, 117388, 1955}); - data.put(9, new double[] {68169, 108.4, 442769, 2936, 2798, 120445, 1957}); - data.put(10, new double[] {66513, 110.8, 444546, 4681, 2637, 121950, 1958}); - data.put(11, new double[] {68655, 112.6, 482704, 3813, 2552, 123366, 1959}); - data.put(12, new double[] {69564, 114.2, 502601, 3931, 2514, 125368, 1960}); - data.put(13, new double[] {69331, 115.7, 518173, 4806, 2572, 127852, 1961}); - data.put(14, new double[] {70551, 116.9, 554894, 4007, 2827, 130081, 1962}); + + data.put(0, new double[]{60323, 83.0, 234289, 2356, 1590, 107608, 1947}); + 
data.put(1, new double[]{61122, 88.5, 259426, 2325, 1456, 108632, 1948}); + data.put(2, new double[]{60171, 88.2, 258054, 3682, 1616, 109773, 1949}); + data.put(3, new double[]{61187, 89.5, 284599, 3351, 1650, 110929, 1950}); + data.put(4, new double[]{63221, 96.2, 328975, 2099, 3099, 112075, 1951}); + data.put(5, new double[]{63639, 98.1, 346999, 1932, 3594, 113270, 1952}); + data.put(6, new double[]{64989, 99.0, 365385, 1870, 3547, 115094, 1953}); + data.put(7, new double[]{63761, 100.0, 363112, 3578, 3350, 116219, 1954}); + data.put(8, new double[]{66019, 101.2, 397469, 2904, 3048, 117388, 1955}); + data.put(9, new double[]{68169, 108.4, 442769, 2936, 2798, 120445, 1957}); + data.put(10, new double[]{66513, 110.8, 444546, 4681, 2637, 121950, 1958}); + data.put(11, new double[]{68655, 112.6, 482704, 3813, 2552, 123366, 1959}); + data.put(12, new double[]{69564, 114.2, 502601, 3931, 2514, 125368, 1960}); + data.put(13, new double[]{69331, 115.7, 518173, 4806, 2572, 127852, 1961}); + data.put(14, new double[]{70551, 116.9, 554894, 4007, 2827, 130081, 1962}); KNNRegressionTrainer trainer = new KNNRegressionTrainer(); @@ -136,7 +128,7 @@ public class KNNRegressionTest extends BaseKNNTest { .withDistanceMeasure(new EuclideanDistance()) .withStrategy(KNNStrategy.SIMPLE); - Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956}); + Vector vector = new DenseLocalOnHeapVector(new double[]{104.6, 419180, 2822, 2857, 118734, 1956}); System.out.println(knnMdl.apply(vector)); Assert.assertEquals(67857, knnMdl.apply(vector), 2000); } http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java new file mode 100644 index 
0000000..a25b303 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.knn; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Path; +import java.nio.file.Paths; +import org.apache.ignite.Ignite; +import org.apache.ignite.ml.structures.LabeledDataset; +import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader; +import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; + +/** + * Base class for grid-backed tests that work with labeled datasets loaded from text resources. + */ +public class LabeledDatasetHelper extends GridCommonAbstractTest { + /** Count of nodes. */ + private static final int NODE_COUNT = 4; + + /** Separator. */ + private static final String SEPARATOR = "\t"; + + /** Grid instance. */ + protected Ignite ignite; + + /** + * Default constructor. 
+ */ + public LabeledDatasetHelper() { + super(false); + } + + /** + * {@inheritDoc} + */ + @Override protected void beforeTest() throws Exception { + ignite = grid(NODE_COUNT); + } + + /** {@inheritDoc} */ + @Override protected void beforeTestsStarted() throws Exception { + for (int i = 1; i <= NODE_COUNT; i++) + startGrid(i); + } + + /** {@inheritDoc} */ + @Override protected void afterTestsStopped() throws Exception { + stopAllGrids(); + } + + /** + * Loads labeled dataset from file with .txt extension. + * + * @param rsrcPath path to dataset. + * @return null if path is incorrect. + */ + LabeledDataset loadDatasetFromTxt(String rsrcPath, boolean isFallOnBadData) { + try { + Path path = Paths.get(this.getClass().getClassLoader().getResource(rsrcPath).toURI()); + try { + return LabeledDatasetLoader.loadFromTxtFile(path, SEPARATOR, false, isFallOnBadData); + } catch (IOException e) { + e.printStackTrace(); + } + } catch (URISyntaxException e) { + e.printStackTrace(); + return null; + } + return null; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java index cdd5dc4..77d40a6 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java @@ -34,7 +34,7 @@ import org.apache.ignite.ml.structures.LabeledVector; import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader; /** Tests behaviour of KNNClassificationTest. 
*/ -public class LabeledDatasetTest extends BaseKNNTest implements ExternalizableTest<LabeledDataset> { +public class LabeledDatasetTest extends LabeledDatasetHelper implements ExternalizableTest<LabeledDataset> { /** */ private static final String KNN_IRIS_TXT = "datasets/knn/iris.txt";