Repository: ignite Updated Branches: refs/heads/master 6557fe626 -> a4653b7c1
IGNITE-7830: Knn Lin Reg with new datasets this closes #3583 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/a4653b7c Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/a4653b7c Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/a4653b7c Branch: refs/heads/master Commit: a4653b7c1287a039206bf22e9d85125bb15bc412 Parents: 6557fe6 Author: zaleslaw <zaleslaw....@gmail.com> Authored: Wed Apr 11 12:31:48 2018 +0300 Committer: YuriBabak <y.ch...@gmail.com> Committed: Wed Apr 11 12:31:48 2018 +0300 ---------------------------------------------------------------------- .../java/org/apache/ignite/ml/knn/KNNUtils.java | 59 ++++++++ .../KNNClassificationTrainer.java | 23 +-- .../ml/knn/regression/KNNRegressionModel.java | 87 +++++++++++ .../ml/knn/regression/KNNRegressionTrainer.java | 40 ++++++ .../ignite/ml/knn/regression/package-info.java | 22 +++ .../apache/ignite/ml/knn/KNNRegressionTest.java | 143 +++++++++++++++++++ .../org/apache/ignite/ml/knn/KNNTestSuite.java | 1 + 7 files changed, 354 insertions(+), 21 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/a4653b7c/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java new file mode 100644 index 0000000..88fa70f --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.knn; + +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.PartitionDataBuilder; +import org.apache.ignite.ml.knn.partitions.KNNPartitionContext; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.structures.LabeledDataset; +import org.apache.ignite.ml.structures.LabeledVector; +import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap; +import org.jetbrains.annotations.Nullable; + +/** + * Helper class for KNNRegression. + */ +public class KNNUtils { + /** + * Builds dataset. + * + * @param datasetBuilder Dataset builder. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @return Dataset. + */ + @Nullable public static <K, V> Dataset<KNNPartitionContext, LabeledDataset<Double, LabeledVector>> buildDataset(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + PartitionDataBuilder<K, V, KNNPartitionContext, LabeledDataset<Double, LabeledVector>> partDataBuilder + = new LabeledDatasetPartitionDataBuilderOnHeap<>( + featureExtractor, + lbExtractor + ); + + Dataset<KNNPartitionContext, LabeledDataset<Double, LabeledVector>> dataset = null; + + if (datasetBuilder != null) { + dataset = datasetBuilder.build( + (upstream, upstreamSize) -> new KNNPartitionContext(), + partDataBuilder + ); + } + return dataset; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/a4653b7c/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java index 357047f..c0c8e65 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java @@ -17,14 +17,9 @@ package org.apache.ignite.ml.knn.classification; -import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; -import org.apache.ignite.ml.dataset.PartitionDataBuilder; -import org.apache.ignite.ml.knn.partitions.KNNPartitionContext; +import org.apache.ignite.ml.knn.KNNUtils; import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.structures.LabeledDataset; -import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap; -import org.apache.ignite.ml.structures.LabeledVector; import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; /** @@ -41,20 +36,6 @@ public class KNNClassificationTrainer implements SingleLabelDatasetTrainer<KNNCl */ @Override public <K, V> KNNClassificationModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { - PartitionDataBuilder<K, V, KNNPartitionContext, LabeledDataset<Double, LabeledVector>> partDataBuilder - = new LabeledDatasetPartitionDataBuilderOnHeap<>( - featureExtractor, - lbExtractor - ); - - Dataset<KNNPartitionContext, LabeledDataset<Double, LabeledVector>> dataset = null; - - if (datasetBuilder != null) { - dataset = datasetBuilder.build( - (upstream, upstreamSize) -> new KNNPartitionContext(), - partDataBuilder - ); - } - return new KNNClassificationModel<>(dataset); + return new KNNClassificationModel<>(KNNUtils.buildDataset(datasetBuilder, featureExtractor, lbExtractor)); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/a4653b7c/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionModel.java new file mode 100644 index 0000000..cabc143 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionModel.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ignite.ml.knn.regression; + +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.knn.classification.KNNClassificationModel; +import org.apache.ignite.ml.knn.partitions.KNNPartitionContext; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException; +import org.apache.ignite.ml.structures.LabeledDataset; +import org.apache.ignite.ml.structures.LabeledVector; + +import java.util.List; + +/** + * This class provides kNN Multiple Linear Regression or Locally [weighted] regression (Simple and Weighted versions). + * + * <p> This is an instance-based learning method. </p> + * + * <ul> + * <li>Local means using nearby points (i.e. a nearest neighbors approach).</li> + * <li>Weighted means we value points based upon how far away they are.</li> + * <li>Regression means approximating a function.</li> + * </ul> + */ +public class KNNRegressionModel<K,V> extends KNNClassificationModel<K,V> { + /** + * Builds the model via prepared dataset. + * @param dataset Specially prepared object to run algorithm over it. + */ + public KNNRegressionModel(Dataset<KNNPartitionContext, LabeledDataset<Double, LabeledVector>> dataset) { + super(dataset); + } + + /** {@inheritDoc} */ + @Override public Double apply(Vector v) { + List<LabeledVector> neighbors = findKNearestNeighbors(v); + + return predictYBasedOn(neighbors, v); + } + + /** */ + private double predictYBasedOn(List<LabeledVector> neighbors, Vector v) { + switch (stgy) { + case SIMPLE: + return simpleRegression(neighbors); + case WEIGHTED: + return weightedRegression(neighbors, v); + default: + throw new UnsupportedOperationException("Strategy " + stgy.name() + " is not supported"); + } + } + + /** */ + private double weightedRegression(List<LabeledVector> neighbors, Vector v) { + double sum = 0.0; + double div = 0.0; + for (LabeledVector<Vector, Double> neighbor : neighbors) { + double distance = distanceMeasure.compute(v, neighbor.features()); + sum += neighbor.label() * distance; + div += distance; + } + return sum / div; + } + + /** */ + private double simpleRegression(List<LabeledVector> neighbors) { + double sum = 0.0; + for (LabeledVector<Vector, Double> neighbor : neighbors) + sum += neighbor.label(); + return sum / (double)k; + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/a4653b7c/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java new file mode 100644 index 0000000..2d13cd5 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.knn.regression; + +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.knn.KNNUtils; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; + +/** + * kNN algorithm trainer to solve regression task. + */ +public class KNNRegressionTrainer{ + /** + * Trains model based on the specified data. + * + * @param datasetBuilder Dataset builder. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @return Model. + */ + public <K, V> KNNRegressionModel fit(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + return new KNNRegressionModel<>(KNNUtils.buildDataset(datasetBuilder, featureExtractor, lbExtractor)); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/a4653b7c/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/package-info.java new file mode 100644 index 0000000..82e7192 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains helper classes for kNN regression algorithms. + */ +package org.apache.ignite.ml.knn.regression; http://git-wip-us.apache.org/repos/asf/ignite/blob/a4653b7c/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java new file mode 100644 index 0000000..66dbca9 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.knn; + +import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.knn.classification.KNNStrategy; +import org.apache.ignite.ml.knn.regression.KNNRegressionModel; +import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.distances.EuclideanDistance; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.junit.Assert; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +/** + * Tests for {@link KNNRegressionTrainer}. + */ +public class KNNRegressionTest extends BaseKNNTest { + /** */ + private double[] y; + + /** */ + private double[][] x; + + /** */ + public void testSimpleRegressionWithOneNeighbour() { + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + + Map<Integer, double[]> data = new HashMap<>(); + data.put(0, new double[] {11.0, 0, 0, 0, 0, 0}); + data.put(1, new double[] {12.0, 2.0, 0, 0, 0, 0}); + data.put(2, new double[] {13.0, 0, 3.0, 0, 0, 0}); + data.put(3, new double[] {14.0, 0, 0, 4.0, 0, 0}); + data.put(4, new double[] {15.0, 0, 0, 0, 5.0, 0}); + data.put(5, new double[] {16.0, 0, 0, 0, 0, 6.0}); + + KNNRegressionTrainer trainer = new KNNRegressionTrainer(); + + KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit( + new LocalDatasetBuilder<>(data, 2), + (k, v) -> Arrays.copyOfRange(v, 1, v.length), + (k, v) -> v[0] + ).withK(1) + .withDistanceMeasure(new EuclideanDistance()) + .withStrategy(KNNStrategy.SIMPLE); + + Vector vector = new DenseLocalOnHeapVector(new double[] {0, 0, 0, 5.0, 0.0}); + System.out.println(knnMdl.apply(vector)); + Assert.assertEquals(15, knnMdl.apply(vector), 1E-12); + } + + /** */ + public void testLongly() { + + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + + Map<Integer, double[]> data = new HashMap<>(); + data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608, 1947}); + data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632, 1948}); + data.put(2, new double[] {60171, 88.2, 258054, 3682, 1616, 109773, 1949}); + data.put(3, new double[] {61187, 89.5, 284599, 3351, 1650, 110929, 1950}); + data.put(4, new double[] {63221, 96.2, 328975, 2099, 3099, 112075, 1951}); + data.put(5, new double[] {63639, 98.1, 346999, 1932, 3594, 113270, 1952}); + data.put(6, new double[] {64989, 99.0, 365385, 1870, 3547, 115094, 1953}); + data.put(7, new double[] {63761, 100.0, 363112, 3578, 3350, 116219, 1954}); + data.put(8, new double[] {66019, 101.2, 397469, 2904, 3048, 117388, 1955}); + data.put(9, new double[] {68169, 108.4, 442769, 2936, 2798, 120445, 1957}); + data.put(10, new double[] {66513, 110.8, 444546, 4681, 2637, 121950, 1958}); + data.put(11, new double[] {68655, 112.6, 482704, 3813, 2552, 123366, 1959}); + data.put(12, new double[] {69564, 114.2, 502601, 3931, 2514, 125368, 1960}); + data.put(13, new double[] {69331, 115.7, 518173, 4806, 2572, 127852, 1961}); + data.put(14, new double[] {70551, 116.9, 554894, 4007, 2827, 130081, 1962}); + + KNNRegressionTrainer trainer = new KNNRegressionTrainer(); + + KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit( + new LocalDatasetBuilder<>(data, 2), + (k, v) -> Arrays.copyOfRange(v, 1, v.length), + (k, v) -> v[0] + ).withK(3) + .withDistanceMeasure(new EuclideanDistance()) + .withStrategy(KNNStrategy.SIMPLE); + + Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956}); + System.out.println(knnMdl.apply(vector)); + Assert.assertEquals(67857, knnMdl.apply(vector), 2000); + } + + /** */ + public void testLonglyWithWeightedStrategy() { + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + + Map<Integer, double[]> data = new HashMap<>(); + data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608, 1947}); + data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632, 1948}); + data.put(2, new double[] {60171, 88.2, 258054, 3682, 1616, 109773, 1949}); + data.put(3, new double[] {61187, 89.5, 284599, 3351, 1650, 110929, 1950}); + data.put(4, new double[] {63221, 96.2, 328975, 2099, 3099, 112075, 1951}); + data.put(5, new double[] {63639, 98.1, 346999, 1932, 3594, 113270, 1952}); + data.put(6, new double[] {64989, 99.0, 365385, 1870, 3547, 115094, 1953}); + data.put(7, new double[] {63761, 100.0, 363112, 3578, 3350, 116219, 1954}); + data.put(8, new double[] {66019, 101.2, 397469, 2904, 3048, 117388, 1955}); + data.put(9, new double[] {68169, 108.4, 442769, 2936, 2798, 120445, 1957}); + data.put(10, new double[] {66513, 110.8, 444546, 4681, 2637, 121950, 1958}); + data.put(11, new double[] {68655, 112.6, 482704, 3813, 2552, 123366, 1959}); + data.put(12, new double[] {69564, 114.2, 502601, 3931, 2514, 125368, 1960}); + data.put(13, new double[] {69331, 115.7, 518173, 4806, 2572, 127852, 1961}); + data.put(14, new double[] {70551, 116.9, 554894, 4007, 2827, 130081, 1962}); + + KNNRegressionTrainer trainer = new KNNRegressionTrainer(); + + KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit( + new LocalDatasetBuilder<>(data, 2), + (k, v) -> Arrays.copyOfRange(v, 1, v.length), + (k, v) -> v[0] + ).withK(3) + .withDistanceMeasure(new EuclideanDistance()) + .withStrategy(KNNStrategy.SIMPLE); + + Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956}); + System.out.println(knnMdl.apply(vector)); + Assert.assertEquals(67857, knnMdl.apply(vector), 2000); + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/a4653b7c/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java index 95ebec5..55ef24e 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java @@ -26,6 +26,7 @@ import org.junit.runners.Suite; @RunWith(Suite.class) @Suite.SuiteClasses({ KNNClassificationTest.class, + KNNRegressionTest.class, LabeledDatasetTest.class }) public class KNNTestSuite {