IGNITE-9747: [ML] Add Bernoulli Naive Bayes classifier

This closes #5204


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/00af5e62
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/00af5e62
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/00af5e62

Branch: refs/heads/ignite-10639
Commit: 00af5e62ac7d35ada714ef81d2b7a19d43bb36af
Parents: 32f564c
Author: Ravil Galeyev <[email protected]>
Authored: Tue Dec 18 15:49:06 2018 +0300
Committer: Yury Babak <[email protected]>
Committed: Tue Dec 18 15:49:06 2018 +0300

----------------------------------------------------------------------
 .../DiscreteNaiveBayesTrainerExample.java       | 115 +++++++++
 .../discrete/DiscreteNaiveBayesModel.java       | 127 ++++++++++
 .../discrete/DiscreteNaiveBayesSumsHolder.java  |  54 +++++
 .../discrete/DiscreteNaiveBayesTrainer.java     | 233 +++++++++++++++++++
 .../ml/naivebayes/discrete/package-info.java    |  22 ++
 .../ignite/ml/util/MLSandboxDatasets.java       |   5 +-
 .../english_vs_scottish_binary_dataset.csv      |  14 ++
 .../discrete/DiscreteNaiveBayesModelTest.java   |  45 ++++
 .../discrete/DiscreteNaiveBayesTest.java        |  67 ++++++
 .../discrete/DiscreteNaiveBayesTrainerTest.java | 185 +++++++++++++++
 10 files changed, 866 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/00af5e62/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
----------------------------------------------------------------------
diff --git 
a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
 
b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
new file mode 100644
index 0000000..5af3f69
--- /dev/null
+++ 
b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.naivebayes;
+
+import java.io.FileNotFoundException;
+import java.util.Arrays;
+import javax.cache.Cache;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.query.QueryCursor;
+import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
+import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer;
+import org.apache.ignite.ml.util.MLSandboxDatasets;
+import org.apache.ignite.ml.util.SandboxMLCache;
+
+/**
+ * Run naive Bayes classification model based on <a 
href=https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes";>
+ * naive Bayes classifier</a> algorithm ({@link DiscreteNaiveBayesTrainer}) 
over distributed cache.
+ * <p>
+ * Code in this example launches Ignite grid and fills the cache with test 
data points.
+ * </p>
+ * <p>
+ * After that it trains the Discrete naive Bayes classification model based on 
the specified data.</p>
+ * <p>
+ * Finally, this example loops over the test set of data points, applies the 
trained model to predict the target value,
+ * compares prediction to expected outcome (ground truth), and builds
+ * <a href="https://en.wikipedia.org/wiki/Confusion_matrix";>confusion 
matrix</a>.</p>
+ * <p>
+ * You can change the test data used in this example and re-run it to explore 
this algorithm further.</p>
+ */
+public class DiscreteNaiveBayesTrainerExample {
+    /** Run example. */
+    public static void main(String[] args) throws FileNotFoundException {
+        System.out.println();
+        System.out.println(">>> Discrete naive Bayes classification model over 
partitioned dataset usage example started.");
+        // Start ignite grid.
+        try (Ignite ignite = 
Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+
+            IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
+                .fillCacheWith(MLSandboxDatasets.ENGLISH_VS_SCOTTISH);
+
+            double[][] thresholds = new double[][] {{.5}, {.5}, {.5}, {.5}, 
{.5}};
+            System.out.println(">>> Create new Discrete naive Bayes 
classification trainer object.");
+            DiscreteNaiveBayesTrainer trainer = new DiscreteNaiveBayesTrainer()
+                .setBucketThresholds(thresholds);
+
+            System.out.println(">>> Perform the training to get the model.");
+            DiscreteNaiveBayesModel mdl = trainer.fit(
+                ignite,
+                dataCache,
+                (k, v) -> v.copyOfRange(1, v.size()),
+                (k, v) -> v.get(0)
+            );
+
+            System.out.println(">>> Discrete Naive Bayes model: " + mdl);
+
+            int amountOfErrors = 0;
+            int totalAmount = 0;
+
+            // Build confusion matrix. See 
https://en.wikipedia.org/wiki/Confusion_matrix
+            int[][] confusionMtx = {{0, 0}, {0, 0}};
+
+            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = 
dataCache.query(new ScanQuery<>())) {
+                for (Cache.Entry<Integer, Vector> observation : observations) {
+                    Vector val = observation.getValue();
+                    Vector inputs = val.copyOfRange(1, val.size());
+                    double groundTruth = val.get(0);
+
+                    double prediction = mdl.apply(inputs);
+
+                    totalAmount++;
+                    if (groundTruth != prediction)
+                        amountOfErrors++;
+
+                    int idx1 = (int)prediction;
+                    int idx2 = (int)groundTruth;
+
+                    confusionMtx[idx1][idx2]++;
+
+                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", 
prediction, groundTruth);
+                }
+
+                System.out.println(">>> ---------------------------------");
+
+                System.out.println("\n>>> Absolute amount of errors " + 
amountOfErrors);
+                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / 
(double)totalAmount));
+            }
+
+            System.out.println("\n>>> Confusion matrix is " + 
Arrays.deepToString(confusionMtx));
+            System.out.println(">>> ---------------------------------");
+
+            System.out.println(">>> Discrete Naive bayes model over 
partitioned dataset usage example completed.");
+        }
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/00af5e62/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModel.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModel.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModel.java
new file mode 100644
index 0000000..7ab2957
--- /dev/null
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModel.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.discrete;
+
+import java.io.Serializable;
+import org.apache.ignite.ml.Exportable;
+import org.apache.ignite.ml.Exporter;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * Discrete naive Bayes model which predicts result value {@code y} belongs to 
a class {@code C_k, k in [0..K]} as
+ * {@code p(C_k,y) =x_1*p_k1^x *...*x_i*p_ki^x_i}. Where {@code x_i} is a 
discrete feature, {@code p_ki} is a prior
+ * probability probability of class {@code p(x|C_k)}. Returns the number of 
the most possible class.
+ */
+public class DiscreteNaiveBayesModel implements Model<Vector, Double>, 
Exportable<DiscreteNaiveBayesModel>, Serializable {
+    /** */
+    private static final long serialVersionUID = -127386523291350345L;
+    /**
+     * Probabilities of features for all classes for each label. {@code 
labels[c][f][b]} contains a probability for
+     * class {@code c} for feature {@code f} for bucket {@code b}.
+     */
+    private final double[][][] probabilities;
+    /** Prior probabilities of each class */
+    private final double[] clsProbabilities;
+    /** Labels. */
+    private final double[] labels;
+    /**
+     * The bucket thresholds to convert a features to discrete values. {@code 
bucketThresholds[f][b]} contains the right
+     * border for feature {@code f} for bucket {@code b}. Everything which is 
above the last thresdold goes to the next
+     * bucket.
+     */
+    private final double[][] bucketThresholds;
+    /** Amount values in each buckek for each feature per label. */
+    private final DiscreteNaiveBayesSumsHolder sumsHolder;
+
+    /**
+     * @param probabilities Probabilities of features for classes.
+     * @param clsProbabilities Prior probabilities for classes.
+     * @param bucketThresholds The threshold to convert a feature to a binary 
value.
+     * @param sumsHolder Amount values which are abouve the threshold per 
label.
+     * @param labels Labels.
+     */
+    public DiscreteNaiveBayesModel(double[][][] probabilities, double[] 
clsProbabilities, double[] labels,
+        double[][] bucketThresholds, DiscreteNaiveBayesSumsHolder sumsHolder) {
+        this.probabilities = probabilities;
+        this.clsProbabilities = clsProbabilities;
+        this.labels = labels;
+        this.bucketThresholds = bucketThresholds;
+        this.sumsHolder = sumsHolder;
+    }
+
+    /** {@inheritDoc} */
+    @Override public <P> void saveModel(Exporter<DiscreteNaiveBayesModel, P> 
exporter, P path) {
+        exporter.save(this, path);
+    }
+
+    /**
+     * @param vector features vector.
+     * @return a label with max probability.
+     */
+    @Override public Double apply(Vector vector) {
+        double maxProbapilityPower = -Double.MAX_VALUE;
+        int maxLabelIndex = -1;
+
+        for (int i = 0; i < clsProbabilities.length; i++) {
+            double probabilityPower = Math.log(clsProbabilities[i]);
+
+            for (int j = 0; j < probabilities[0].length; j++) {
+                int x = toBucketNumber(vector.get(j), bucketThresholds[j]);
+                double p = probabilities[i][j][x];
+                probabilityPower += (p > 0 ? Math.log(p) : .0);
+            }
+
+            if (probabilityPower > maxProbapilityPower) {
+                maxLabelIndex = i;
+                maxProbapilityPower = probabilityPower;
+            }
+        }
+        return labels[maxLabelIndex];
+    }
+
+    /** */
+    public double[][][] getProbabilities() {
+        return probabilities;
+    }
+
+    /** */
+    public double[] getClsProbabilities() {
+        return clsProbabilities;
+    }
+
+    /** */
+    public double[][] getBucketThresholds() {
+        return bucketThresholds;
+    }
+
+    /** */
+    public DiscreteNaiveBayesSumsHolder getSumsHolder() {
+        return sumsHolder;
+    }
+
+    /** Returs a bucket number to which the {@code value} corresponds. */
+    private int toBucketNumber(double val, double[] thresholds) {
+        for (int i = 0; i < thresholds.length; i++) {
+            if (val < thresholds[i])
+                return i;
+        }
+
+        return thresholds.length;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/00af5e62/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesSumsHolder.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesSumsHolder.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesSumsHolder.java
new file mode 100644
index 0000000..61ac692
--- /dev/null
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesSumsHolder.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ignite.ml.naivebayes.discrete;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.math.util.MapUtil;
+
+/** Service class is used to calculate amount of values which are below the 
threshold. */
+public class DiscreteNaiveBayesSumsHolder implements AutoCloseable, 
Serializable {
+    /** Serial version uid. */
+    private static final long serialVersionUID = -2059362365851744206L;
+    /** Sums of values correspones to a particular bucket for all features for 
each label */
+    Map<Double, long[][]> valuesInBucketPerLbl = new HashMap<>();
+    /** Rows count for each label */
+    Map<Double, Integer> featureCountersPerLbl = new HashMap<>();
+
+    /** Merge to current */
+    DiscreteNaiveBayesSumsHolder merge(DiscreteNaiveBayesSumsHolder other) {
+        valuesInBucketPerLbl = MapUtil.mergeMaps(valuesInBucketPerLbl, 
other.valuesInBucketPerLbl, this::sum, HashMap::new);
+        featureCountersPerLbl = MapUtil.mergeMaps(featureCountersPerLbl, 
other.featureCountersPerLbl, (i1, i2) -> i1 + i2, HashMap::new);
+        return this;
+    }
+
+    /** In-place operation. Sums {@code arr2} to {@code arr1} element to 
element. */
+    private long[][] sum(long[][] arr1, long[][] arr2) {
+        for (int i = 0; i < arr1.length; i++) {
+            for (int j = 0; j < arr1[i].length; j++)
+                arr1[i][j] += arr2[i][j];
+        }
+
+        return arr1;
+    }
+
+    /** */
+    @Override public void close() {
+        // Do nothing, GC will clean up.
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/00af5e62/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java
new file mode 100644
index 0000000..0779b84
--- /dev/null
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.discrete;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
+
+/**
+ * Trainer for the Discrete naive Bayes classification model. The trainer 
calculates prior probabilities from the input
+ * dataset. Prior probabilities can be also set by {@code 
setPriorProbabilities} or {@code withEquiprobableClasses}. If
+ * {@code equiprobableClasses} is set, the probalilities of all classes will 
be {@code 1/k}, where {@code k} is classes
+ * count. Also, the trainer converts feature to discrete values by using 
{@code bucketThresholds}.
+ */
+public class DiscreteNaiveBayesTrainer extends 
SingleLabelDatasetTrainer<DiscreteNaiveBayesModel> {
+    /** Precision to compare bucketThresholds. */
+    private static final double PRECISION = 1e-10;
+    /* Preset prior probabilities. */
+    private double[] priorProbabilities;
+    /* Sets equivalent probability for all classes. */
+    private boolean equiprobableClasses;
+    /** The threshold to convert a feature to a discrete value. */
+    private double[][] bucketThresholds;
+
+    /**
+     * Trains model based on the specified data.
+     *
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @return Model.
+     */
+    @Override public <K, V> DiscreteNaiveBayesModel fit(DatasetBuilder<K, V> 
datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
+        return updateModel(null, datasetBuilder, featureExtractor, 
lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(DiscreteNaiveBayesModel mdl) {
+        if (mdl.getBucketThresholds().length != bucketThresholds.length)
+            return false;
+
+        for (int i = 0; i < bucketThresholds.length; i++) {
+            for (int j = 0; i < bucketThresholds[i].length; i++) {
+                if (Math.abs(mdl.getBucketThresholds()[i][j] - 
bucketThresholds[i][j]) > PRECISION)
+                    return false;
+            }
+        }
+
+        return true;
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> DiscreteNaiveBayesModel 
updateModel(DiscreteNaiveBayesModel mdl,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> 
featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        try (Dataset<EmptyContext, DiscreteNaiveBayesSumsHolder> dataset = 
datasetBuilder.build(
+            envBuilder,
+            (env, upstream, upstreamSize) -> new EmptyContext(),
+            (env, upstream, upstreamSize, ctx) -> {
+                DiscreteNaiveBayesSumsHolder res = new 
DiscreteNaiveBayesSumsHolder();
+                while (upstream.hasNext()) {
+                    UpstreamEntry<K, V> entity = upstream.next();
+
+                    Vector features = featureExtractor.apply(entity.getKey(), 
entity.getValue());
+                    Double lb = lbExtractor.apply(entity.getKey(), 
entity.getValue());
+
+                    long[][] valuesInBucket;
+
+                    int size = features.size();
+                    if (!res.valuesInBucketPerLbl.containsKey(lb)) {
+                        valuesInBucket = new long[size][];
+                        for (int i = 0; i < size; i++) {
+                            valuesInBucket[i] = new 
long[bucketThresholds[i].length + 1];
+                            Arrays.fill(valuesInBucket[i], 0L);
+                        }
+                        res.valuesInBucketPerLbl.put(lb, valuesInBucket);
+                    }
+
+                    if (!res.featureCountersPerLbl.containsKey(lb))
+                        res.featureCountersPerLbl.put(lb, 0);
+
+                    res.featureCountersPerLbl.put(lb, 
res.featureCountersPerLbl.get(lb) + 1);
+
+                    valuesInBucket = res.valuesInBucketPerLbl.get(lb);
+
+                    for (int j = 0; j < size; j++) {
+                        double x = features.get(j);
+                        int bucketNum = toBucketNumber(x, bucketThresholds[j]);
+                        valuesInBucket[j][bucketNum] += 1;
+                    }
+                }
+                return res;
+            })) {
+                DiscreteNaiveBayesSumsHolder sumsHolder = dataset.compute(t -> 
t, (a, b) -> {
+                    if (a == null)
+                        return b == null ? new DiscreteNaiveBayesSumsHolder() 
: b;
+                    if (b == null)
+                        return a;
+                    return a.merge(b);
+                });
+
+                if (mdl != null && checkState(mdl)) {
+                    if (checkSumsHolder(sumsHolder, mdl.getSumsHolder()))
+                        sumsHolder = sumsHolder.merge(mdl.getSumsHolder());
+                }
+
+                List<Double> sortedLabels = new 
ArrayList<>(sumsHolder.featureCountersPerLbl.keySet());
+                sortedLabels.sort(Double::compareTo);
+                assert !sortedLabels.isEmpty() : "The dataset should contain 
at least one feature";
+
+                int lbCnt = sortedLabels.size();
+                int featureCnt = 
sumsHolder.valuesInBucketPerLbl.get(sortedLabels.get(0)).length;
+
+                double[][][] probabilities = new double[lbCnt][featureCnt][];
+                double[] classProbabilities = new double[lbCnt];
+                double[] labels = new double[lbCnt];
+                long datasetSize = 
sumsHolder.featureCountersPerLbl.values().stream().mapToInt(i -> i).sum();
+
+                int lbl = 0;
+
+                for (Double label : sortedLabels) {
+                    int cnt = sumsHolder.featureCountersPerLbl.get(label);
+                    long[][] sum = sumsHolder.valuesInBucketPerLbl.get(label);
+
+                    for (int i = 0; i < featureCnt; i++) {
+
+                        int bucketsCnt = sum[i].length;
+                        probabilities[lbl][i] = new double[bucketsCnt];
+
+                        for (int j = 0; j < bucketsCnt; j++)
+                            probabilities[lbl][i][j] = (double)sum[i][j] / cnt;
+                    }
+
+                    if (equiprobableClasses)
+                        classProbabilities[lbl] = 1. / lbCnt;
+                    else if (priorProbabilities != null) {
+                        assert classProbabilities.length == 
priorProbabilities.length;
+                        classProbabilities[lbl] = priorProbabilities[lbl];
+                    }
+                    else
+                        classProbabilities[lbl] = (double)cnt / datasetSize;
+
+                    labels[lbl] = label;
+                    ++lbl;
+                }
+            return new DiscreteNaiveBayesModel(probabilities, 
classProbabilities, labels, bucketThresholds, sumsHolder);
+        }
+        catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+
+    }
+
+    /** Checks that two {@code DiscreteNaiveBayesSumsHolder} contain the same 
lengths of future vectors. */
+    private boolean checkSumsHolder(DiscreteNaiveBayesSumsHolder holder1, 
DiscreteNaiveBayesSumsHolder holder2) {
+        if (holder1 == null || holder2 == null)
+            return false;
+
+        Optional<long[][]> optionalFirst = 
holder1.valuesInBucketPerLbl.values().stream().findFirst();
+        Optional<long[][]> optionalSecond = 
holder2.valuesInBucketPerLbl.values().stream().findFirst();
+
+        if (optionalFirst.isPresent()) {
+            if (optionalSecond.isPresent())
+                return optionalFirst.get().length == 
optionalSecond.get().length;
+            else
+                return false;
+        }
+        else
+            return !optionalSecond.isPresent();
+    }
+
+    /** Sets equal probability for all classes. */
+    public DiscreteNaiveBayesTrainer withEquiprobableClasses() {
+        resetProbabilitiesSettings();
+        equiprobableClasses = true;
+        return this;
+    }
+
+    /** Sets prior probabilities. */
+    public DiscreteNaiveBayesTrainer setPriorProbabilities(double[] 
priorProbabilities) {
+        resetProbabilitiesSettings();
+        this.priorProbabilities = priorProbabilities.clone();
+        return this;
+    }
+
+    /** Sets buckest borders. */
+    public DiscreteNaiveBayesTrainer setBucketThresholds(double[][] 
bucketThresholds) {
+        this.bucketThresholds = bucketThresholds;
+        return this;
+    }
+
+    /** Sets default settings {@code equiprobableClasses} to {@code false} and 
removes priorProbabilities. */
+    public DiscreteNaiveBayesTrainer resetProbabilitiesSettings() {
+        equiprobableClasses = false;
+        priorProbabilities = null;
+        return this;
+    }
+
+    /** Returs a bucket number to which the {@code value} corresponds. */
+    private int toBucketNumber(double val, double[] thresholds) {
+        for (int i = 0; i < thresholds.length; i++) {
+            if (val < thresholds[i])
+                return i;
+        }
+
+        return thresholds.length;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/00af5e62/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/package-info.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/package-info.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/package-info.java
new file mode 100644
index 0000000..092d22d
--- /dev/null
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
/**
 * <!-- Package description. -->
 * Contains the discrete naive Bayes classifier (the Bernoulli classifier is its special case with a single
 * threshold per feature).
 */
package org.apache.ignite.ml.naivebayes.discrete;

http://git-wip-us.apache.org/repos/asf/ignite/blob/00af5e62/modules/ml/src/main/java/org/apache/ignite/ml/util/MLSandboxDatasets.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/util/MLSandboxDatasets.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/util/MLSandboxDatasets.java
index 84ae8f6..b8d5fb4 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/util/MLSandboxDatasets.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/MLSandboxDatasets.java
@@ -50,7 +50,10 @@ public enum MLSandboxDatasets {
     WINE_RECOGNITION("modules/ml/src/main/resources/datasets/wine.txt", false, 
","),
 
     /** The Boston house-prices dataset. Could be found <a 
href="https://archive.ics.uci.edu/ml/machine-learning-databases/housing/";>here</a>.
 */
-    
BOSTON_HOUSE_PRICES("modules/ml/src/main/resources/datasets/boston_housing_dataset.txt",
 false, ",");
+    
BOSTON_HOUSE_PRICES("modules/ml/src/main/resources/datasets/boston_housing_dataset.txt",
 false, ","),
+
+    /** Example from book Barber D. Bayesian reasoning and machine learning. 
Chapter 10. */
+    
ENGLISH_VS_SCOTTISH("modules/ml/src/main/resources/datasets/english_vs_scottish_binary_dataset.csv",
 true, ",");
 
     /** Filename. */
     private final String filename;

http://git-wip-us.apache.org/repos/asf/ignite/blob/00af5e62/modules/ml/src/main/resources/datasets/english_vs_scottish_binary_dataset.csv
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/resources/datasets/english_vs_scottish_binary_dataset.csv 
b/modules/ml/src/main/resources/datasets/english_vs_scottish_binary_dataset.csv
new file mode 100644
index 0000000..0fa0f8b
--- /dev/null
+++ 
b/modules/ml/src/main/resources/datasets/english_vs_scottish_binary_dataset.csv
@@ -0,0 +1,14 @@
+english, shortbread, lager, whiskey, porridge, football
+1, 0, 0, 1, 1, 1
+1, 1, 0, 1, 1, 0
+1, 1, 1, 0, 0, 1
+1, 1, 1, 0, 0, 0
+1, 0, 1, 0, 0, 1
+1, 0, 0, 0, 1, 0
+0, 1, 0, 0, 1, 1
+0, 1, 1, 0, 0, 1
+0, 1, 1, 1, 1, 0
+0, 1, 1, 0, 1, 0
+0, 1, 1, 0, 1, 1
+0, 1, 0, 1, 1, 0
+0, 1, 0, 1, 0, 0
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/ignite/blob/00af5e62/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java
new file mode 100644
index 0000000..f6b947b
--- /dev/null
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.discrete;
+
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/** Tests for {@code DiscreteNaiveBayesModel} */
+public class DiscreteNaiveBayesModelTest {
+    /** */
+    @Test
+    public void testPredictWithTwoClasses() {
+        double first = 1;
+        double second = 2;
+        double[][][] probabilities = new double[][][] {
+            {{.5, .5}, {.2, .3, .5}, {2. / 3., 1. / 3.}, {.4, .1, .5}, {.5, 
.5}},
+            {{0, 1}, {1. / 7, 2. / 7, 4. / 7}, {4. / 7, 3. / 7}, {2. / 7, 3. / 
7, 2. / 7}, {4. / 7, 3. / 7,}}
+        };
+
+        double[] classProbabilities = new double[] {6. / 13, 7. / 13};
+        double[][] thresholds = new double[][] {{.5}, {.2, .7}, {.5}, {.5, 
1.5}, {.5}};
+        DiscreteNaiveBayesModel mdl = new 
DiscreteNaiveBayesModel(probabilities, classProbabilities, new double[] {first, 
second}, thresholds, new DiscreteNaiveBayesSumsHolder());
+        Vector observation = VectorUtils.of(2, 0, 1, 2, 0);
+
+        Assert.assertEquals(second, mdl.apply(observation), 0.0001);
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/00af5e62/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTest.java
new file mode 100644
index 0000000..25fb37b
--- /dev/null
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTest.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.discrete;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Integration tests for Bernoulli naive Bayes algorithm with different 
datasets.
+ */
+public class DiscreteNaiveBayesTest {
+    /** Precision in test checks. */
+    private static final double PRECISION = 1e-2;
+
+    /** Example from book Barber D. Bayesian reasoning and machine learning. 
Chapter 10. */
+    @Test
+    public void testLearnsAndPredictCorrently() {
+        double english = 1.;
+        double scottish = 2.;
+        Map<Integer, double[]> data = new HashMap<>();
+        data.put(0, new double[] {0, 0, 1, 1, 1, english});
+        data.put(1, new double[] {1, 0, 1, 1, 0, english});
+        data.put(2, new double[] {1, 1, 0, 0, 1, english});
+        data.put(3, new double[] {1, 1, 0, 0, 0, english});
+        data.put(4, new double[] {0, 1, 0, 0, 1, english});
+        data.put(5, new double[] {0, 0, 0, 1, 0, english});
+        data.put(6, new double[] {1, 0, 0, 1, 1, scottish});
+        data.put(7, new double[] {1, 1, 0, 0, 1, scottish});
+        data.put(8, new double[] {1, 1, 1, 1, 0, scottish});
+        data.put(9, new double[] {1, 1, 0, 1, 0, scottish});
+        data.put(10, new double[] {1, 1, 0, 1, 1, scottish});
+        data.put(11, new double[] {1, 0, 1, 1, 0, scottish});
+        data.put(12, new double[] {1, 0, 1, 0, 0, scottish});
+        double[][] thresholds = new double[][] {{.5}, {.5}, {.5}, {.5}, {.5}};
+        DiscreteNaiveBayesTrainer trainer = new 
DiscreteNaiveBayesTrainer().setBucketThresholds(thresholds);
+
+        DiscreteNaiveBayesModel model = trainer.fit(
+            new LocalDatasetBuilder<>(data, 2),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[v.length - 1]
+        );
+        Vector observation = VectorUtils.of(1, 0, 1, 1, 0);
+
+        Assert.assertEquals(scottish, model.apply(observation), PRECISION);
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/00af5e62/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainerTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainerTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainerTest.java
new file mode 100644
index 0000000..3ffd5cf
--- /dev/null
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainerTest.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ignite.ml.naivebayes.discrete;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+/** Test for {@link DiscreteNaiveBayesTrainer} */
+public class DiscreteNaiveBayesTrainerTest extends TrainerTest {
+
+    /** Precision in test checks. */
+    private static final double PRECISION = 1e-2;
+    /** */
+    private static final double LABEL_1 = 1.;
+    /** */
+    private static final double LABEL_2 = 2.;
+
+    /** Binary data. */
+    private static final Map<Integer, double[]> binarizedData = new 
HashMap<>();
+    /** Data. */
+    private static final Map<Integer, double[]> data = new HashMap<>();
+    /** */
+    private static final double[][] binarizedDatathresholds = new double[][] 
{{.5}, {.5}, {.5}, {.5}, {.5}};
+    private static final double[][] thresholds = new double[][] {{4, 8}, {.5}, 
{.3, .4, .5}, {250, 500, 750}};
+
+    static {
+        binarizedData.put(0, new double[] {0, 0, 1, 1, 1, LABEL_1});
+        binarizedData.put(1, new double[] {1, 0, 1, 1, 0, LABEL_1});
+        binarizedData.put(2, new double[] {1, 1, 0, 0, 1, LABEL_1});
+        binarizedData.put(3, new double[] {1, 1, 0, 0, 0, LABEL_1});
+        binarizedData.put(4, new double[] {0, 1, 0, 0, 1, LABEL_1});
+        binarizedData.put(5, new double[] {0, 0, 0, 1, 0, LABEL_1});
+
+        binarizedData.put(6, new double[] {1, 0, 0, 1, 1, LABEL_2});
+        binarizedData.put(7, new double[] {1, 1, 0, 0, 1, LABEL_2});
+        binarizedData.put(8, new double[] {1, 1, 1, 1, 0, LABEL_2});
+        binarizedData.put(9, new double[] {1, 1, 0, 1, 0, LABEL_2});
+        binarizedData.put(10, new double[] {1, 1, 0, 1, 1, LABEL_2});
+        binarizedData.put(11, new double[] {1, 0, 1, 1, 0, LABEL_2});
+        binarizedData.put(12, new double[] {1, 0, 1, 0, 0, LABEL_2});
+
+        data.put(0, new double[] {2, 0, .34, 123, LABEL_1});
+        data.put(1, new double[] {8, 0, .37, 561, LABEL_1});
+        data.put(2, new double[] {5, 1, .01, 678, LABEL_1});
+        data.put(3, new double[] {2, 1, .32, 453, LABEL_1});
+        data.put(4, new double[] {7, 1, .67, 980, LABEL_1});
+        data.put(5, new double[] {2, 1, .69, 912, LABEL_1});
+        data.put(6, new double[] {8, 0, .43, 453, LABEL_1});
+        data.put(7, new double[] {2, 0, .45, 752, LABEL_1});
+        data.put(8, new double[] {7, 1, .01, 132, LABEL_2});
+        data.put(9, new double[] {2, 1, .68, 169, LABEL_2});
+        data.put(10, new double[] {8, 0, .43, 453, LABEL_2});
+        data.put(11, new double[] {2, 1, .45, 748, LABEL_2});
+    }
+
+    /** */
+    private DiscreteNaiveBayesTrainer trainer;
+
+    /** Initialization {@code DiscreteNaiveBayesTrainer}. */
+    @Before
+    public void createTrainer() {
+        trainer = new 
DiscreteNaiveBayesTrainer().setBucketThresholds(binarizedDatathresholds);
+    }
+
+    /** */
+    @Test
+    public void testReturnsCorrectLabelProbalities() {
+
+        DiscreteNaiveBayesModel model = trainer.fit(
+            new LocalDatasetBuilder<>(binarizedData, parts),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[v.length - 1]
+        );
+
+        double[] expectedProbabilities = {6. / binarizedData.size(), 7. / 
binarizedData.size()};
+        Assert.assertArrayEquals(expectedProbabilities, 
model.getClsProbabilities(), PRECISION);
+    }
+
+    /** */
+    @Test
+    public void testReturnsEquivalentProbalitiesWhenSetEquiprobableClasses_() {
+        DiscreteNaiveBayesTrainer trainer = new DiscreteNaiveBayesTrainer()
+            .setBucketThresholds(binarizedDatathresholds)
+            .withEquiprobableClasses();
+
+        DiscreteNaiveBayesModel model = trainer.fit(
+            new LocalDatasetBuilder<>(binarizedData, parts),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[v.length - 1]
+        );
+
+        Assert.assertArrayEquals(new double[] {.5, .5}, 
model.getClsProbabilities(), PRECISION);
+    }
+
+    /** */
+    @Test
+    public void testReturnsPresetProbalitiesWhenSetPriorProbabilities() {
+        double[] priorProbabilities = new double[] {.35, .65};
+        DiscreteNaiveBayesTrainer trainer = new DiscreteNaiveBayesTrainer()
+            .setBucketThresholds(binarizedDatathresholds)
+            .setPriorProbabilities(priorProbabilities);
+
+        DiscreteNaiveBayesModel model = trainer.fit(
+            new LocalDatasetBuilder<>(binarizedData, parts),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[v.length - 1]
+        );
+
+        Assert.assertArrayEquals(priorProbabilities, 
model.getClsProbabilities(), PRECISION);
+    }
+
+    /** */
+    @Test
+    public void testReturnsCorrectPriorProbabilities() {
+        double[][][] expectedPriorProbabilites = new double[][][] {
+            {{.5, .5}, {.5, .5}, {2. / 3., 1. / 3.}, {.5, .5}, {.5, .5}},
+            {{0, 1}, {3. / 7, 4. / 7}, {4. / 7, 3. / 7}, {2. / 7, 5. / 7}, {4. 
/ 7, 3. / 7,}}
+        };
+
+        DiscreteNaiveBayesModel model = trainer.fit(
+            new LocalDatasetBuilder<>(binarizedData, parts),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[v.length - 1]
+        );
+
+        for (int i = 0; i < expectedPriorProbabilites.length; i++) {
+            for (int j = 0; j < expectedPriorProbabilites[i].length; j++) {
+                Assert.assertArrayEquals(expectedPriorProbabilites[i][j], 
model.getProbabilities()[i][j], PRECISION);
+            }
+        }
+    }
+
+    /** */
+    @Test
+    public void testReturnsCorrectPriorProbabilitiesWithDefferentThresholds() {
+        double[][][] expectedPriorProbabilites = new double[][][] {
+            {
+                {4. / 8, 2. / 8, 2. / 8},
+                {.5, .5},
+                {1. / 8, 3. / 8, 2. / 8, 2. / 8},
+                {1. / 8, 2. / 8, 2. / 8, 3. / 8}},
+            {
+                {2. / 4, 1. / 4, 1. / 4},
+                {1. / 4, 3. / 4},
+                {1. / 4, 0, 2. / 4, 1. / 4},
+                {2. / 4, 1. / 4, 1. / 4, 0}}
+        };
+
+        DiscreteNaiveBayesModel model = trainer
+            .setBucketThresholds(thresholds)
+            .fit(
+                new LocalDatasetBuilder<>(data, parts),
+                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 
1)),
+                (k, v) -> v[v.length - 1]
+            );
+
+        for (int i = 0; i < expectedPriorProbabilites.length; i++) {
+            for (int j = 0; j < expectedPriorProbabilites[i].length; j++) {
+                Assert.assertArrayEquals(expectedPriorProbabilites[i][j], 
model.getProbabilities()[i][j], PRECISION);
+            }
+        }
+    }
+
+}

Reply via email to