IGNITE-9747: [ML] Add Bernoulli Naive Bayes classifier This closes #5204
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/00af5e62 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/00af5e62 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/00af5e62 Branch: refs/heads/ignite-10639 Commit: 00af5e62ac7d35ada714ef81d2b7a19d43bb36af Parents: 32f564c Author: Ravil Galeyev <[email protected]> Authored: Tue Dec 18 15:49:06 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Tue Dec 18 15:49:06 2018 +0300 ---------------------------------------------------------------------- .../DiscreteNaiveBayesTrainerExample.java | 115 +++++++++ .../discrete/DiscreteNaiveBayesModel.java | 127 ++++++++++ .../discrete/DiscreteNaiveBayesSumsHolder.java | 54 +++++ .../discrete/DiscreteNaiveBayesTrainer.java | 233 +++++++++++++++++++ .../ml/naivebayes/discrete/package-info.java | 22 ++ .../ignite/ml/util/MLSandboxDatasets.java | 5 +- .../english_vs_scottish_binary_dataset.csv | 14 ++ .../discrete/DiscreteNaiveBayesModelTest.java | 45 ++++ .../discrete/DiscreteNaiveBayesTest.java | 67 ++++++ .../discrete/DiscreteNaiveBayesTrainerTest.java | 185 +++++++++++++++ 10 files changed, 866 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/00af5e62/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java new file mode 100644 index 0000000..5af3f69 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one 
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.examples.ml.naivebayes;

import java.io.FileNotFoundException;
import java.util.Arrays;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;

/**
 * Run naive Bayes classification model based on <a href="https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes">
 * naive Bayes classifier</a> algorithm ({@link DiscreteNaiveBayesTrainer}) over distributed cache.
 * <p>
 * Code in this example launches Ignite grid and fills the cache with test data points.
 * </p>
 * <p>
 * After that it trains the Discrete naive Bayes classification model based on the specified data.</p>
 * <p>
 * Finally, this example loops over the test set of data points, applies the trained model to predict the target value,
 * compares prediction to expected outcome (ground truth), and builds
 * <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a>.</p>
 * <p>
 * You can change the test data used in this example and re-run it to explore this algorithm further.</p>
 */
public class DiscreteNaiveBayesTrainerExample {
    /** Run example. */
    public static void main(String[] args) throws FileNotFoundException {
        System.out.println();
        System.out.println(">>> Discrete naive Bayes classification model over partitioned dataset usage example started.");
        // Start ignite grid.
        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
            System.out.println(">>> Ignite grid started.");

            IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
                .fillCacheWith(MLSandboxDatasets.ENGLISH_VS_SCOTTISH);

            // One threshold (0.5) per feature: each binary feature is split into two buckets.
            double[][] thresholds = new double[][] {{.5}, {.5}, {.5}, {.5}, {.5}};
            System.out.println(">>> Create new Discrete naive Bayes classification trainer object.");
            DiscreteNaiveBayesTrainer trainer = new DiscreteNaiveBayesTrainer()
                .setBucketThresholds(thresholds);

            System.out.println(">>> Perform the training to get the model.");
            DiscreteNaiveBayesModel mdl = trainer.fit(
                ignite,
                dataCache,
                (k, v) -> v.copyOfRange(1, v.size()), // Features are columns 1..n.
                (k, v) -> v.get(0)                    // Label is column 0.
            );

            System.out.println(">>> Discrete Naive Bayes model: " + mdl);

            int amountOfErrors = 0;
            int totalAmount = 0;

            // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
            int[][] confusionMtx = {{0, 0}, {0, 0}};

            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
                for (Cache.Entry<Integer, Vector> observation : observations) {
                    Vector val = observation.getValue();
                    Vector inputs = val.copyOfRange(1, val.size());
                    double groundTruth = val.get(0);

                    double prediction = mdl.apply(inputs);

                    totalAmount++;
                    if (groundTruth != prediction)
                        amountOfErrors++;

                    int idx1 = (int)prediction;
                    int idx2 = (int)groundTruth;

                    confusionMtx[idx1][idx2]++;

                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
                }

                System.out.println(">>> ---------------------------------");

                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
            }

            System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
            System.out.println(">>> ---------------------------------");

            System.out.println(">>> Discrete Naive bayes model over partitioned dataset usage example completed.");
        }
    }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.naivebayes.discrete;

import java.io.Serializable;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.math.primitives.vector.Vector;

/**
 * Discrete naive Bayes model which predicts result value {@code y} belongs to a class {@code C_k, k in [0..K]} as
 * {@code p(C_k,y) = x_1*p_k1^x *...* x_i*p_ki^x_i}. Where {@code x_i} is a discrete feature and {@code p_ki} is a
 * prior probability of class {@code p(x|C_k)}. Returns the number of the most possible class.
 */
public class DiscreteNaiveBayesModel implements Model<Vector, Double>, Exportable<DiscreteNaiveBayesModel>, Serializable {
    /** */
    private static final long serialVersionUID = -127386523291350345L;

    /**
     * Probabilities of features for all classes for each label. {@code probabilities[c][f][b]} contains a probability
     * for class {@code c} for feature {@code f} for bucket {@code b}.
     */
    private final double[][][] probabilities;

    /** Prior probabilities of each class. */
    private final double[] clsProbabilities;

    /** Labels. */
    private final double[] labels;

    /**
     * The bucket thresholds to convert features to discrete values. {@code bucketThresholds[f][b]} contains the right
     * border for feature {@code f} for bucket {@code b}. Everything which is above the last threshold goes to the next
     * bucket.
     */
    private final double[][] bucketThresholds;

    /** Amount of values in each bucket for each feature per label. */
    private final DiscreteNaiveBayesSumsHolder sumsHolder;

    /**
     * @param probabilities Probabilities of features for classes.
     * @param clsProbabilities Prior probabilities for classes.
     * @param labels Labels.
     * @param bucketThresholds The thresholds to convert a feature to a discrete value.
     * @param sumsHolder Amount of values which are above the thresholds per label.
     */
    public DiscreteNaiveBayesModel(double[][][] probabilities, double[] clsProbabilities, double[] labels,
        double[][] bucketThresholds, DiscreteNaiveBayesSumsHolder sumsHolder) {
        this.probabilities = probabilities;
        this.clsProbabilities = clsProbabilities;
        this.labels = labels;
        this.bucketThresholds = bucketThresholds;
        this.sumsHolder = sumsHolder;
    }

    /** {@inheritDoc} */
    @Override public <P> void saveModel(Exporter<DiscreteNaiveBayesModel, P> exporter, P path) {
        exporter.save(this, path);
    }

    /**
     * @param vector Features vector.
     * @return A label with max probability.
     */
    @Override public Double apply(Vector vector) {
        // Log-space accumulation avoids underflow when multiplying many small probabilities.
        double maxProbabilityPower = -Double.MAX_VALUE;
        int maxLabelIndex = -1;

        for (int i = 0; i < clsProbabilities.length; i++) {
            double probabilityPower = Math.log(clsProbabilities[i]);

            for (int j = 0; j < probabilities[0].length; j++) {
                int x = toBucketNumber(vector.get(j), bucketThresholds[j]);
                double p = probabilities[i][j][x];
                // NOTE(review): a zero-probability bucket contributes 0 instead of -inf, i.e. it is skipped
                // rather than vetoing the class — confirm this smoothing-free behavior is intended.
                probabilityPower += (p > 0 ? Math.log(p) : .0);
            }

            if (probabilityPower > maxProbabilityPower) {
                maxLabelIndex = i;
                maxProbabilityPower = probabilityPower;
            }
        }
        return labels[maxLabelIndex];
    }

    /** */
    public double[][][] getProbabilities() {
        return probabilities;
    }

    /** */
    public double[] getClsProbabilities() {
        return clsProbabilities;
    }

    /** */
    public double[] getLabels() {
        return labels;
    }

    /** */
    public double[][] getBucketThresholds() {
        return bucketThresholds;
    }

    /** */
    public DiscreteNaiveBayesSumsHolder getSumsHolder() {
        return sumsHolder;
    }

    /** Returns a bucket number to which the {@code val} corresponds. */
    private int toBucketNumber(double val, double[] thresholds) {
        for (int i = 0; i < thresholds.length; i++) {
            if (val < thresholds[i])
                return i;
        }

        return thresholds.length;
    }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.ignite.ml.naivebayes.discrete;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import org.apache.ignite.ml.math.util.MapUtil;

/** Service class which accumulates, per label, how many feature values fall into each bucket. */
public class DiscreteNaiveBayesSumsHolder implements AutoCloseable, Serializable {
    /** Serial version uid. */
    private static final long serialVersionUID = -2059362365851744206L;

    /** Per label: counts of values falling into each bucket, for every feature. */
    Map<Double, long[][]> valuesInBucketPerLbl = new HashMap<>();

    /** Per label: number of processed rows. */
    Map<Double, Integer> featureCountersPerLbl = new HashMap<>();

    /** Merges the counters of {@code other} into this holder and returns {@code this}. */
    DiscreteNaiveBayesSumsHolder merge(DiscreteNaiveBayesSumsHolder other) {
        valuesInBucketPerLbl = MapUtil.mergeMaps(valuesInBucketPerLbl, other.valuesInBucketPerLbl, this::sum, HashMap::new);
        featureCountersPerLbl = MapUtil.mergeMaps(featureCountersPerLbl, other.featureCountersPerLbl, Integer::sum, HashMap::new);
        return this;
    }

    /** In-place operation. Adds {@code inc} into {@code acc} element by element and returns {@code acc}. */
    private long[][] sum(long[][] acc, long[][] inc) {
        for (int row = 0; row < acc.length; row++) {
            for (int col = 0; col < acc[row].length; col++)
                acc[row][col] += inc[row][col];
        }

        return acc;
    }

    /** {@inheritDoc} */
    @Override public void close() {
        // Nothing to release; GC reclaims the maps.
    }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.naivebayes.discrete;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;

/**
 * Trainer for the Discrete naive Bayes classification model. The trainer calculates prior probabilities from the input
 * dataset. Prior probabilities can be also set by {@code setPriorProbabilities} or {@code withEquiprobableClasses}. If
 * {@code equiprobableClasses} is set, the probabilities of all classes will be {@code 1/k}, where {@code k} is classes
 * count. Also, the trainer converts features to discrete values by using {@code bucketThresholds}.
 */
public class DiscreteNaiveBayesTrainer extends SingleLabelDatasetTrainer<DiscreteNaiveBayesModel> {
    /** Precision to compare bucketThresholds. */
    private static final double PRECISION = 1e-10;

    /** Preset prior probabilities. */
    private double[] priorProbabilities;

    /** Sets equivalent probability for all classes. */
    private boolean equiprobableClasses;

    /** The thresholds to convert a feature to a discrete value. */
    private double[][] bucketThresholds;

    /**
     * Trains model based on the specified data.
     *
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @return Model.
     */
    @Override public <K, V> DiscreteNaiveBayesModel fit(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
    }

    /** {@inheritDoc} */
    @Override protected boolean checkState(DiscreteNaiveBayesModel mdl) {
        if (mdl.getBucketThresholds().length != bucketThresholds.length)
            return false;

        for (int i = 0; i < bucketThresholds.length; i++) {
            // Guard against out-of-bounds access when per-feature bucket counts differ.
            if (mdl.getBucketThresholds()[i].length != bucketThresholds[i].length)
                return false;

            // BUG FIX: the inner loop previously tested and incremented the outer index `i`
            // instead of `j`, so thresholds were compared against the wrong entries.
            for (int j = 0; j < bucketThresholds[i].length; j++) {
                if (Math.abs(mdl.getBucketThresholds()[i][j] - bucketThresholds[i][j]) > PRECISION)
                    return false;
            }
        }

        return true;
    }

    /** {@inheritDoc} */
    @Override protected <K, V> DiscreteNaiveBayesModel updateModel(DiscreteNaiveBayesModel mdl,
        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor) {

        try (Dataset<EmptyContext, DiscreteNaiveBayesSumsHolder> dataset = datasetBuilder.build(
            envBuilder,
            (env, upstream, upstreamSize) -> new EmptyContext(),
            (env, upstream, upstreamSize, ctx) -> {
                // Per-partition pass: count, for every label, how many values of each feature
                // fall into each bucket.
                DiscreteNaiveBayesSumsHolder res = new DiscreteNaiveBayesSumsHolder();
                while (upstream.hasNext()) {
                    UpstreamEntry<K, V> entity = upstream.next();

                    Vector features = featureExtractor.apply(entity.getKey(), entity.getValue());
                    Double lb = lbExtractor.apply(entity.getKey(), entity.getValue());

                    long[][] valuesInBucket;

                    int size = features.size();
                    if (!res.valuesInBucketPerLbl.containsKey(lb)) {
                        // Lazily allocate counters for a previously unseen label; the extra
                        // "+1" bucket holds values above the last threshold.
                        valuesInBucket = new long[size][];
                        for (int i = 0; i < size; i++) {
                            valuesInBucket[i] = new long[bucketThresholds[i].length + 1];
                            Arrays.fill(valuesInBucket[i], 0L);
                        }
                        res.valuesInBucketPerLbl.put(lb, valuesInBucket);
                    }

                    if (!res.featureCountersPerLbl.containsKey(lb))
                        res.featureCountersPerLbl.put(lb, 0);

                    res.featureCountersPerLbl.put(lb, res.featureCountersPerLbl.get(lb) + 1);

                    valuesInBucket = res.valuesInBucketPerLbl.get(lb);

                    for (int j = 0; j < size; j++) {
                        double x = features.get(j);
                        int bucketNum = toBucketNumber(x, bucketThresholds[j]);
                        valuesInBucket[j][bucketNum] += 1;
                    }
                }
                return res;
            })) {
            DiscreteNaiveBayesSumsHolder sumsHolder = dataset.compute(t -> t, (a, b) -> {
                if (a == null)
                    return b == null ? new DiscreteNaiveBayesSumsHolder() : b;
                if (b == null)
                    return a;
                return a.merge(b);
            });

            // Incremental update: fold the previous model's sums in if it is compatible.
            if (mdl != null && checkState(mdl)) {
                if (checkSumsHolder(sumsHolder, mdl.getSumsHolder()))
                    sumsHolder = sumsHolder.merge(mdl.getSumsHolder());
            }

            List<Double> sortedLabels = new ArrayList<>(sumsHolder.featureCountersPerLbl.keySet());
            sortedLabels.sort(Double::compareTo);
            assert !sortedLabels.isEmpty() : "The dataset should contain at least one feature";

            int lbCnt = sortedLabels.size();
            int featureCnt = sumsHolder.valuesInBucketPerLbl.get(sortedLabels.get(0)).length;

            double[][][] probabilities = new double[lbCnt][featureCnt][];
            double[] classProbabilities = new double[lbCnt];
            double[] labels = new double[lbCnt];
            long datasetSize = sumsHolder.featureCountersPerLbl.values().stream().mapToInt(i -> i).sum();

            int lbl = 0;

            for (Double label : sortedLabels) {
                int cnt = sumsHolder.featureCountersPerLbl.get(label);
                long[][] sum = sumsHolder.valuesInBucketPerLbl.get(label);

                for (int i = 0; i < featureCnt; i++) {

                    int bucketsCnt = sum[i].length;
                    probabilities[lbl][i] = new double[bucketsCnt];

                    // Empirical bucket probability: bucket count over rows of this label.
                    for (int j = 0; j < bucketsCnt; j++)
                        probabilities[lbl][i][j] = (double)sum[i][j] / cnt;
                }

                if (equiprobableClasses)
                    classProbabilities[lbl] = 1. / lbCnt;
                else if (priorProbabilities != null) {
                    assert classProbabilities.length == priorProbabilities.length;
                    classProbabilities[lbl] = priorProbabilities[lbl];
                }
                else
                    classProbabilities[lbl] = (double)cnt / datasetSize;

                labels[lbl] = label;
                ++lbl;
            }
            return new DiscreteNaiveBayesModel(probabilities, classProbabilities, labels, bucketThresholds, sumsHolder);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Checks that two {@code DiscreteNaiveBayesSumsHolder} contain the same lengths of feature vectors. */
    private boolean checkSumsHolder(DiscreteNaiveBayesSumsHolder holder1, DiscreteNaiveBayesSumsHolder holder2) {
        if (holder1 == null || holder2 == null)
            return false;

        Optional<long[][]> optionalFirst = holder1.valuesInBucketPerLbl.values().stream().findFirst();
        Optional<long[][]> optionalSecond = holder2.valuesInBucketPerLbl.values().stream().findFirst();

        if (optionalFirst.isPresent()) {
            if (optionalSecond.isPresent())
                return optionalFirst.get().length == optionalSecond.get().length;
            else
                return false;
        }
        else
            return !optionalSecond.isPresent();
    }

    /** Sets equal probability for all classes. */
    public DiscreteNaiveBayesTrainer withEquiprobableClasses() {
        resetProbabilitiesSettings();
        equiprobableClasses = true;
        return this;
    }

    /** Sets prior probabilities. */
    public DiscreteNaiveBayesTrainer setPriorProbabilities(double[] priorProbabilities) {
        resetProbabilitiesSettings();
        this.priorProbabilities = priorProbabilities.clone();
        return this;
    }

    /** Sets bucket borders. */
    public DiscreteNaiveBayesTrainer setBucketThresholds(double[][] bucketThresholds) {
        this.bucketThresholds = bucketThresholds;
        return this;
    }

    /** Sets default settings: {@code equiprobableClasses} to {@code false} and removes priorProbabilities. */
    public DiscreteNaiveBayesTrainer resetProbabilitiesSettings() {
        equiprobableClasses = false;
        priorProbabilities = null;
        return this;
    }

    /** Returns a bucket number to which the {@code val} corresponds. */
    private int toBucketNumber(double val, double[] thresholds) {
        for (int i = 0; i < thresholds.length; i++) {
            if (val < thresholds[i])
                return i;
        }

        return thresholds.length;
    }
}
+ */ +package org.apache.ignite.ml.naivebayes.discrete; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/00af5e62/modules/ml/src/main/java/org/apache/ignite/ml/util/MLSandboxDatasets.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/MLSandboxDatasets.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/MLSandboxDatasets.java index 84ae8f6..b8d5fb4 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/util/MLSandboxDatasets.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/MLSandboxDatasets.java @@ -50,7 +50,10 @@ public enum MLSandboxDatasets { WINE_RECOGNITION("modules/ml/src/main/resources/datasets/wine.txt", false, ","), /** The Boston house-prices dataset. Could be found <a href="https://archive.ics.uci.edu/ml/machine-learning-databases/housing/">here</a>. */ - BOSTON_HOUSE_PRICES("modules/ml/src/main/resources/datasets/boston_housing_dataset.txt", false, ","); + BOSTON_HOUSE_PRICES("modules/ml/src/main/resources/datasets/boston_housing_dataset.txt", false, ","), + + /** Example from book Barber D. Bayesian reasoning and machine learning. Chapter 10. */ + ENGLISH_VS_SCOTTISH("modules/ml/src/main/resources/datasets/english_vs_scottish_binary_dataset.csv", true, ","); /** Filename. 
*/ private final String filename; http://git-wip-us.apache.org/repos/asf/ignite/blob/00af5e62/modules/ml/src/main/resources/datasets/english_vs_scottish_binary_dataset.csv ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/resources/datasets/english_vs_scottish_binary_dataset.csv b/modules/ml/src/main/resources/datasets/english_vs_scottish_binary_dataset.csv new file mode 100644 index 0000000..0fa0f8b --- /dev/null +++ b/modules/ml/src/main/resources/datasets/english_vs_scottish_binary_dataset.csv @@ -0,0 +1,14 @@ +english, shortbread, lager, whiskey, porridge, football +1, 0, 0, 1, 1, 1 +1, 1, 0, 1, 1, 0 +1, 1, 1, 0, 0, 1 +1, 1, 1, 0, 0, 0 +1, 0, 1, 0, 0, 1 +1, 0, 0, 0, 1, 0 +0, 1, 0, 0, 1, 1 +0, 1, 1, 0, 0, 1 +0, 1, 1, 1, 1, 0 +0, 1, 1, 0, 1, 0 +0, 1, 1, 0, 1, 1 +0, 1, 0, 1, 1, 0 +0, 1, 0, 1, 0, 0 \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/00af5e62/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java new file mode 100644 index 0000000..f6b947b --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.naivebayes.discrete; + +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.junit.Assert; +import org.junit.Test; + +/** Tests for {@code DiscreteNaiveBayesModel} */ +public class DiscreteNaiveBayesModelTest { + /** */ + @Test + public void testPredictWithTwoClasses() { + double first = 1; + double second = 2; + double[][][] probabilities = new double[][][] { + {{.5, .5}, {.2, .3, .5}, {2. / 3., 1. / 3.}, {.4, .1, .5}, {.5, .5}}, + {{0, 1}, {1. / 7, 2. / 7, 4. / 7}, {4. / 7, 3. / 7}, {2. / 7, 3. / 7, 2. / 7}, {4. / 7, 3. / 7,}} + }; + + double[] classProbabilities = new double[] {6. / 13, 7. 
/ 13}; + double[][] thresholds = new double[][] {{.5}, {.2, .7}, {.5}, {.5, 1.5}, {.5}}; + DiscreteNaiveBayesModel mdl = new DiscreteNaiveBayesModel(probabilities, classProbabilities, new double[] {first, second}, thresholds, new DiscreteNaiveBayesSumsHolder()); + Vector observation = VectorUtils.of(2, 0, 1, 2, 0); + + Assert.assertEquals(second, mdl.apply(observation), 0.0001); + } + +} http://git-wip-us.apache.org/repos/asf/ignite/blob/00af5e62/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTest.java new file mode 100644 index 0000000..25fb37b --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTest.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.naivebayes.discrete;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.junit.Assert;
import org.junit.Test;

/**
 * Integration tests for the discrete naive Bayes algorithm with binary (Bernoulli-style) datasets.
 */
public class DiscreteNaiveBayesTest {
    /** Precision in test checks. */
    private static final double PRECISION = 1e-2;

    /**
     * Trains on the nationality example from Barber D., "Bayesian Reasoning and Machine Learning",
     * chapter 10, then checks the prediction for a held-out observation.
     */
    @Test
    public void testLearnsAndPredictCorrently() {
        double english = 1.;
        double scottish = 2.;

        // Rows: five binary features followed by the label in the last column.
        Map<Integer, double[]> data = new HashMap<>();
        data.put(0, new double[] {0, 0, 1, 1, 1, english});
        data.put(1, new double[] {1, 0, 1, 1, 0, english});
        data.put(2, new double[] {1, 1, 0, 0, 1, english});
        data.put(3, new double[] {1, 1, 0, 0, 0, english});
        data.put(4, new double[] {0, 1, 0, 0, 1, english});
        data.put(5, new double[] {0, 0, 0, 1, 0, english});
        data.put(6, new double[] {1, 0, 0, 1, 1, scottish});
        data.put(7, new double[] {1, 1, 0, 0, 1, scottish});
        data.put(8, new double[] {1, 1, 1, 1, 0, scottish});
        data.put(9, new double[] {1, 1, 0, 1, 0, scottish});
        data.put(10, new double[] {1, 1, 0, 1, 1, scottish});
        data.put(11, new double[] {1, 0, 1, 1, 0, scottish});
        data.put(12, new double[] {1, 0, 1, 0, 0, scottish});

        // Single 0.5 threshold per feature turns each feature into two buckets.
        double[][] thresholds = new double[][] {{.5}, {.5}, {.5}, {.5}, {.5}};
        DiscreteNaiveBayesTrainer trainer = new DiscreteNaiveBayesTrainer().setBucketThresholds(thresholds);

        DiscreteNaiveBayesModel model = trainer.fit(
            new LocalDatasetBuilder<>(data, 2),
            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
            (k, v) -> v[v.length - 1]
        );

        Vector observation = VectorUtils.of(1, 0, 1, 1, 0);

        Assert.assertEquals(scottish, model.apply(observation), PRECISION);
    }
}
+ */ +package org.apache.ignite.ml.naivebayes.discrete; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import org.apache.ignite.ml.common.TrainerTest; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +/** Test for {@link DiscreteNaiveBayesTrainer} */ +public class DiscreteNaiveBayesTrainerTest extends TrainerTest { + + /** Precision in test checks. */ + private static final double PRECISION = 1e-2; + /** */ + private static final double LABEL_1 = 1.; + /** */ + private static final double LABEL_2 = 2.; + + /** Binary data. */ + private static final Map<Integer, double[]> binarizedData = new HashMap<>(); + /** Data. */ + private static final Map<Integer, double[]> data = new HashMap<>(); + /** */ + private static final double[][] binarizedDatathresholds = new double[][] {{.5}, {.5}, {.5}, {.5}, {.5}}; + private static final double[][] thresholds = new double[][] {{4, 8}, {.5}, {.3, .4, .5}, {250, 500, 750}}; + + static { + binarizedData.put(0, new double[] {0, 0, 1, 1, 1, LABEL_1}); + binarizedData.put(1, new double[] {1, 0, 1, 1, 0, LABEL_1}); + binarizedData.put(2, new double[] {1, 1, 0, 0, 1, LABEL_1}); + binarizedData.put(3, new double[] {1, 1, 0, 0, 0, LABEL_1}); + binarizedData.put(4, new double[] {0, 1, 0, 0, 1, LABEL_1}); + binarizedData.put(5, new double[] {0, 0, 0, 1, 0, LABEL_1}); + + binarizedData.put(6, new double[] {1, 0, 0, 1, 1, LABEL_2}); + binarizedData.put(7, new double[] {1, 1, 0, 0, 1, LABEL_2}); + binarizedData.put(8, new double[] {1, 1, 1, 1, 0, LABEL_2}); + binarizedData.put(9, new double[] {1, 1, 0, 1, 0, LABEL_2}); + binarizedData.put(10, new double[] {1, 1, 0, 1, 1, LABEL_2}); + binarizedData.put(11, new double[] {1, 0, 1, 1, 0, LABEL_2}); + binarizedData.put(12, new double[] {1, 0, 1, 0, 0, LABEL_2}); + + data.put(0, new double[] {2, 0, .34, 123, 
LABEL_1}); + data.put(1, new double[] {8, 0, .37, 561, LABEL_1}); + data.put(2, new double[] {5, 1, .01, 678, LABEL_1}); + data.put(3, new double[] {2, 1, .32, 453, LABEL_1}); + data.put(4, new double[] {7, 1, .67, 980, LABEL_1}); + data.put(5, new double[] {2, 1, .69, 912, LABEL_1}); + data.put(6, new double[] {8, 0, .43, 453, LABEL_1}); + data.put(7, new double[] {2, 0, .45, 752, LABEL_1}); + data.put(8, new double[] {7, 1, .01, 132, LABEL_2}); + data.put(9, new double[] {2, 1, .68, 169, LABEL_2}); + data.put(10, new double[] {8, 0, .43, 453, LABEL_2}); + data.put(11, new double[] {2, 1, .45, 748, LABEL_2}); + } + + /** */ + private DiscreteNaiveBayesTrainer trainer; + + /** Initialization {@code DiscreteNaiveBayesTrainer}. */ + @Before + public void createTrainer() { + trainer = new DiscreteNaiveBayesTrainer().setBucketThresholds(binarizedDatathresholds); + } + + /** */ + @Test + public void testReturnsCorrectLabelProbalities() { + + DiscreteNaiveBayesModel model = trainer.fit( + new LocalDatasetBuilder<>(binarizedData, parts), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[v.length - 1] + ); + + double[] expectedProbabilities = {6. / binarizedData.size(), 7. 
/ binarizedData.size()}; + Assert.assertArrayEquals(expectedProbabilities, model.getClsProbabilities(), PRECISION); + } + + /** */ + @Test + public void testReturnsEquivalentProbalitiesWhenSetEquiprobableClasses_() { + DiscreteNaiveBayesTrainer trainer = new DiscreteNaiveBayesTrainer() + .setBucketThresholds(binarizedDatathresholds) + .withEquiprobableClasses(); + + DiscreteNaiveBayesModel model = trainer.fit( + new LocalDatasetBuilder<>(binarizedData, parts), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[v.length - 1] + ); + + Assert.assertArrayEquals(new double[] {.5, .5}, model.getClsProbabilities(), PRECISION); + } + + /** */ + @Test + public void testReturnsPresetProbalitiesWhenSetPriorProbabilities() { + double[] priorProbabilities = new double[] {.35, .65}; + DiscreteNaiveBayesTrainer trainer = new DiscreteNaiveBayesTrainer() + .setBucketThresholds(binarizedDatathresholds) + .setPriorProbabilities(priorProbabilities); + + DiscreteNaiveBayesModel model = trainer.fit( + new LocalDatasetBuilder<>(binarizedData, parts), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[v.length - 1] + ); + + Assert.assertArrayEquals(priorProbabilities, model.getClsProbabilities(), PRECISION); + } + + /** */ + @Test + public void testReturnsCorrectPriorProbabilities() { + double[][][] expectedPriorProbabilites = new double[][][] { + {{.5, .5}, {.5, .5}, {2. / 3., 1. / 3.}, {.5, .5}, {.5, .5}}, + {{0, 1}, {3. / 7, 4. / 7}, {4. / 7, 3. / 7}, {2. / 7, 5. / 7}, {4. / 7, 3. 
/ 7,}} + }; + + DiscreteNaiveBayesModel model = trainer.fit( + new LocalDatasetBuilder<>(binarizedData, parts), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[v.length - 1] + ); + + for (int i = 0; i < expectedPriorProbabilites.length; i++) { + for (int j = 0; j < expectedPriorProbabilites[i].length; j++) { + Assert.assertArrayEquals(expectedPriorProbabilites[i][j], model.getProbabilities()[i][j], PRECISION); + } + } + } + + /** */ + @Test + public void testReturnsCorrectPriorProbabilitiesWithDefferentThresholds() { + double[][][] expectedPriorProbabilites = new double[][][] { + { + {4. / 8, 2. / 8, 2. / 8}, + {.5, .5}, + {1. / 8, 3. / 8, 2. / 8, 2. / 8}, + {1. / 8, 2. / 8, 2. / 8, 3. / 8}}, + { + {2. / 4, 1. / 4, 1. / 4}, + {1. / 4, 3. / 4}, + {1. / 4, 0, 2. / 4, 1. / 4}, + {2. / 4, 1. / 4, 1. / 4, 0}} + }; + + DiscreteNaiveBayesModel model = trainer + .setBucketThresholds(thresholds) + .fit( + new LocalDatasetBuilder<>(data, parts), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[v.length - 1] + ); + + for (int i = 0; i < expectedPriorProbabilites.length; i++) { + for (int j = 0; j < expectedPriorProbabilites[i].length; j++) { + Assert.assertArrayEquals(expectedPriorProbabilites[i][j], model.getProbabilities()[i][j], PRECISION); + } + } + } + +}
