IGNITE-10405: [ML] Refactor GaussianNaiveBayesTrainerExample to read data sample from file
This closes #5551 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/c6a05f8f Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/c6a05f8f Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/c6a05f8f Branch: refs/heads/ignite-10044 Commit: c6a05f8f2d190b67c1e06ed1a2856ed414198132 Parents: 0cd303f Author: zaleslaw <[email protected]> Authored: Tue Dec 4 14:18:16 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Tue Dec 4 14:18:16 2018 +0300 ---------------------------------------------------------------------- .../GaussianNaiveBayesTrainerExample.java | 25 ++-- .../ignite/examples/ml/util/SandboxMLCache.java | 19 --- .../ignite/examples/util/IrisDataset.java | 129 ------------------- 3 files changed, 12 insertions(+), 161 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/c6a05f8f/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java index b5e36ea..7cbc6d1 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java @@ -17,6 +17,7 @@ package org.apache.ignite.examples.ml.naivebayes; +import java.io.FileNotFoundException; import java.util.Arrays; import javax.cache.Cache; import org.apache.ignite.Ignite; @@ -24,14 +25,12 @@ import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; import org.apache.ignite.cache.query.QueryCursor; import org.apache.ignite.cache.query.ScanQuery; +import org.apache.ignite.examples.ml.util.MLSandboxDatasets; import org.apache.ignite.examples.ml.util.SandboxMLCache; import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel; import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer; -import static org.apache.ignite.examples.util.IrisDataset.irisDatasetFirstAndSecondClasses; - /** * Run naive Bayes classification model based on <a href="https://en.wikipedia.org/wiki/Naive_Bayes_classifier"> naive * Bayes classifier</a> algorithm ({@link GaussianNaiveBayesTrainer}) over distributed cache. @@ -49,15 +48,15 @@ import static org.apache.ignite.examples.util.IrisDataset.irisDatasetFirstAndSec */ public class GaussianNaiveBayesTrainerExample { /** Run example. */ - public static void main(String[] args) throws InterruptedException { + public static void main(String[] args) throws FileNotFoundException { System.out.println(); System.out.println(">>> Naive Bayes classification model over partitioned dataset usage example started."); // Start ignite grid. try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { System.out.println(">>> Ignite grid started."); - IgniteCache<Integer, double[]> dataCache = new SandboxMLCache(ignite) - .fillCacheWith(irisDatasetFirstAndSecondClasses); + IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite) + .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS); System.out.println(">>> Create new naive Bayes classification trainer object."); GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer(); @@ -66,8 +65,8 @@ public class GaussianNaiveBayesTrainerExample { GaussianNaiveBayesModel mdl = trainer.fit( ignite, dataCache, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] + (k, v) -> v.copyOfRange(1, v.size()), + (k, v) -> v.get(0) ); System.out.println(">>> Naive Bayes model: " + mdl); @@ -78,11 +77,11 @@ public class GaussianNaiveBayesTrainerExample { // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix int[][] confusionMtx = {{0, 0}, {0, 0}}; - try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry<Integer, double[]> observation : observations) { - double[] val = observation.getValue(); - Vector inputs = VectorUtils.of(Arrays.copyOfRange(val, 1, val.length)); - double groundTruth = val[0]; + try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, Vector> observation : observations) { + Vector val = observation.getValue(); + Vector inputs = val.copyOfRange(1, val.size()); + double groundTruth = val.get(0); double prediction = mdl.apply(inputs); http://git-wip-us.apache.org/repos/asf/ignite/blob/c6a05f8f/examples/src/main/java/org/apache/ignite/examples/ml/util/SandboxMLCache.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/util/SandboxMLCache.java b/examples/src/main/java/org/apache/ignite/examples/ml/util/SandboxMLCache.java index a8431de..20b22ef 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/util/SandboxMLCache.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/util/SandboxMLCache.java @@ -69,25 +69,6 @@ public class SandboxMLCache { /** * Fills cache with data and returns it. * - * @param data Data to fill the cache with. - * @return Filled Ignite Cache. - */ - public IgniteCache<Integer, Vector> getVectors(double[][] data) { - CacheConfiguration<Integer, Vector> cacheConfiguration = new CacheConfiguration<>(); - cacheConfiguration.setName("TEST_" + UUID.randomUUID()); - cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10)); - - IgniteCache<Integer, Vector> cache = ignite.createCache(cacheConfiguration); - - for (int i = 0; i < data.length; i++) - cache.put(i, VectorUtils.of(data[i])); - - return cache; - } - - /** - * Fills cache with data and returns it. - * * @param dataset The chosen dataset. * @return Filled Ignite Cache. * @throws FileNotFoundException If file not found. http://git-wip-us.apache.org/repos/asf/ignite/blob/c6a05f8f/examples/src/main/java/org/apache/ignite/examples/util/IrisDataset.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/util/IrisDataset.java b/examples/src/main/java/org/apache/ignite/examples/util/IrisDataset.java deleted file mode 100644 index 53080e8..0000000 --- a/examples/src/main/java/org/apache/ignite/examples/util/IrisDataset.java +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.ignite.examples.util; - -/** Contains data from the <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set"></a>Iris dataset</a>. */ -public final class IrisDataset { - - /** The 1st and 2nd classes from the Iris dataset. */ - public static final double[][] irisDatasetFirstAndSecondClasses = { - {0, 5.1, 3.5, 1.4, 0.2}, - {0, 4.9, 3, 1.4, 0.2}, - {0, 4.7, 3.2, 1.3, 0.2}, - {0, 4.6, 3.1, 1.5, 0.2}, - {0, 5, 3.6, 1.4, 0.2}, - {0, 5.4, 3.9, 1.7, 0.4}, - {0, 4.6, 3.4, 1.4, 0.3}, - {0, 5, 3.4, 1.5, 0.2}, - {0, 4.4, 2.9, 1.4, 0.2}, - {0, 4.9, 3.1, 1.5, 0.1}, - {0, 5.4, 3.7, 1.5, 0.2}, - {0, 4.8, 3.4, 1.6, 0.2}, - {0, 4.8, 3, 1.4, 0.1}, - {0, 4.3, 3, 1.1, 0.1}, - {0, 5.8, 4, 1.2, 0.2}, - {0, 5.7, 4.4, 1.5, 0.4}, - {0, 5.4, 3.9, 1.3, 0.4}, - {0, 5.1, 3.5, 1.4, 0.3}, - {0, 5.7, 3.8, 1.7, 0.3}, - {0, 5.1, 3.8, 1.5, 0.3}, - {0, 5.4, 3.4, 1.7, 0.2}, - {0, 5.1, 3.7, 1.5, 0.4}, - {0, 4.6, 3.6, 1, 0.2}, - {0, 5.1, 3.3, 1.7, 0.5}, - {0, 4.8, 3.4, 1.9, 0.2}, - {0, 5, 3, 1.6, 0.2}, - {0, 5, 3.4, 1.6, 0.4}, - {0, 5.2, 3.5, 1.5, 0.2}, - {0, 5.2, 3.4, 1.4, 0.2}, - {0, 4.7, 3.2, 1.6, 0.2}, - {0, 4.8, 3.1, 1.6, 0.2}, - {0, 5.4, 3.4, 1.5, 0.4}, - {0, 5.2, 4.1, 1.5, 0.1}, - {0, 5.5, 4.2, 1.4, 0.2}, - {0, 4.9, 3.1, 1.5, 0.1}, - {0, 5, 3.2, 1.2, 0.2}, - {0, 5.5, 3.5, 1.3, 0.2}, - {0, 4.9, 3.1, 1.5, 0.1}, - {0, 4.4, 3, 1.3, 0.2}, - {0, 5.1, 3.4, 1.5, 0.2}, - {0, 5, 3.5, 1.3, 0.3}, - {0, 4.5, 2.3, 1.3, 0.3}, - {0, 4.4, 3.2, 1.3, 0.2}, - {0, 5, 3.5, 1.6, 0.6}, - {0, 5.1, 3.8, 1.9, 0.4}, - {0, 4.8, 3, 1.4, 0.3}, - {0, 5.1, 3.8, 1.6, 0.2}, - {0, 4.6, 3.2, 1.4, 0.2}, - {0, 5.3, 3.7, 1.5, 0.2}, - {0, 5, 3.3, 1.4, 0.2}, - {1, 7, 3.2, 4.7, 1.4}, - {1, 6.4, 3.2, 4.5, 1.5}, - {1, 6.9, 3.1, 4.9, 1.5}, - {1, 5.5, 2.3, 4, 1.3}, - {1, 6.5, 2.8, 4.6, 1.5}, - {1, 5.7, 2.8, 4.5, 1.3}, - {1, 6.3, 3.3, 4.7, 1.6}, - {1, 4.9, 2.4, 3.3, 1}, - {1, 6.6, 2.9, 4.6, 1.3}, - {1, 5.2, 2.7, 3.9, 1.4}, - {1, 5, 2, 3.5, 1}, - {1, 5.9, 3, 4.2, 1.5}, - {1, 6, 2.2, 4, 1}, - {1, 6.1, 2.9, 4.7, 1.4}, - {1, 5.6, 2.9, 3.6, 1.3}, - {1, 6.7, 3.1, 4.4, 1.4}, - {1, 5.6, 3, 4.5, 1.5}, - {1, 5.8, 2.7, 4.1, 1}, - {1, 6.2, 2.2, 4.5, 1.5}, - {1, 5.6, 2.5, 3.9, 1.1}, - {1, 5.9, 3.2, 4.8, 1.8}, - {1, 6.1, 2.8, 4, 1.3}, - {1, 6.3, 2.5, 4.9, 1.5}, - {1, 6.1, 2.8, 4.7, 1.2}, - {1, 6.4, 2.9, 4.3, 1.3}, - {1, 6.6, 3, 4.4, 1.4}, - {1, 6.8, 2.8, 4.8, 1.4}, - {1, 6.7, 3, 5, 1.7}, - {1, 6, 2.9, 4.5, 1.5}, - {1, 5.7, 2.6, 3.5, 1}, - {1, 5.5, 2.4, 3.8, 1.1}, - {1, 5.5, 2.4, 3.7, 1}, - {1, 5.8, 2.7, 3.9, 1.2}, - {1, 6, 2.7, 5.1, 1.6}, - {1, 5.4, 3, 4.5, 1.5}, - {1, 6, 3.4, 4.5, 1.6}, - {1, 6.7, 3.1, 4.7, 1.5}, - {1, 6.3, 2.3, 4.4, 1.3}, - {1, 5.6, 3, 4.1, 1.3}, - {1, 5.5, 2.5, 4, 1.3}, - {1, 5.5, 2.6, 4.4, 1.2}, - {1, 6.1, 3, 4.6, 1.4}, - {1, 5.8, 2.6, 4, 1.2}, - {1, 5, 2.3, 3.3, 1}, - {1, 5.6, 2.7, 4.2, 1.3}, - {1, 5.7, 3, 4.2, 1.2}, - {1, 5.7, 2.9, 4.2, 1.3}, - {1, 6.2, 2.9, 4.3, 1.3}, - {1, 5.1, 2.5, 3, 1.1}, - {1, 5.7, 2.8, 4.1, 1.3}, - }; - - /** */ - private IrisDataset() { - } -}
