Repository: ignite Updated Branches: refs/heads/master bcf987a8a -> c4d859c92
IGNITE-8403: [ML] Add Binary Logistic Regression based on partitioned datasets and MLP this closes #3924 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/c4d859c9 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/c4d859c9 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/c4d859c9 Branch: refs/heads/master Commit: c4d859c924e86755bca4b4cbea26eef8d8ff0547 Parents: bcf987a Author: zaleslaw <[email protected]> Authored: Fri Apr 27 13:11:58 2018 +0300 Committer: YuriBabak <[email protected]> Committed: Fri Apr 27 13:11:58 2018 +0300 ---------------------------------------------------------------------- .../LinearRegressionLSQRTrainerExample.java | 4 +- ...sionLSQRTrainerWithNormalizationExample.java | 9 +- .../LinearRegressionSGDTrainerExample.java | 6 +- .../LogisticRegressionSGDTrainerSample.java | 239 +++++++++++++++++++ .../ml/regression/logistic/package-info.java | 22 ++ .../ignite/ml/optimization/LossFunctions.java | 75 ++++++ .../binomial/LogisticRegressionModel.java | 200 ++++++++++++++++ .../binomial/LogisticRegressionSGDTrainer.java | 111 +++++++++ .../logistic/binomial/package-info.java | 22 ++ .../ml/regressions/logistic/package-info.java | 22 ++ .../ml/regressions/RegressionsTestSuite.java | 10 +- .../logistic/LogisticRegressionModelTest.java | 76 ++++++ .../LogisticRegressionSGDTrainerTest.java | 103 ++++++++ 13 files changed, 884 insertions(+), 15 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/c4d859c9/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java 
b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java index 276d43f..04d1778 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java @@ -34,7 +34,7 @@ import java.util.Arrays; import java.util.UUID; /** - * Run linear regression model over distributed matrix. + * Run linear regression model over cached dataset. * * @see LinearRegressionLSQRTrainer */ @@ -104,8 +104,6 @@ public class LinearRegressionLSQRTrainerExample { try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { System.out.println(">>> Ignite grid started."); - // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread - // because we create ignite cache internally. IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), LinearRegressionLSQRTrainerExample.class.getSimpleName(), () -> { IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); http://git-wip-us.apache.org/repos/asf/ignite/blob/c4d859c9/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithNormalizationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithNormalizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithNormalizationExample.java index 0358f44..6c9273c 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithNormalizationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithNormalizationExample.java @@ -24,7 +24,6 @@ 
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.cache.query.QueryCursor; import org.apache.ignite.cache.query.ScanQuery; import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.preprocessing.normalization.NormalizationPreprocessor; @@ -38,7 +37,7 @@ import java.util.Arrays; import java.util.UUID; /** - * Run linear regression model over distributed matrix. + * Run linear regression model over cached dataset. * * @see LinearRegressionLSQRTrainer * @see NormalizationTrainer @@ -105,15 +104,13 @@ public class LinearRegressionLSQRTrainerWithNormalizationExample { /** Run example. */ public static void main(String[] args) throws InterruptedException { System.out.println(); - System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started."); + System.out.println(">>> Linear regression model over cached dataset usage example started."); // Start ignite grid. try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { System.out.println(">>> Ignite grid started."); - // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread - // because we create ignite cache internally. 
IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - SparseDistributedMatrixExample.class.getSimpleName(), () -> { + LinearRegressionLSQRTrainerWithNormalizationExample.class.getSimpleName(), () -> { IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); System.out.println(">>> Create new normalization trainer object."); http://git-wip-us.apache.org/repos/asf/ignite/blob/c4d859c9/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java index ce6ad3b..da5f942 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java @@ -25,11 +25,11 @@ import org.apache.ignite.cache.query.QueryCursor; import org.apache.ignite.cache.query.ScanQuery; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer; -import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.thread.IgniteThread; import javax.cache.Cache; @@ -37,7 +37,7 @@ import java.util.Arrays; import java.util.UUID; /** - * Run linear regression model over distributed matrix. 
+ * Run linear regression model over cached dataset. * * @see LinearRegressionSGDTrainer */ @@ -106,8 +106,6 @@ public class LinearRegressionSGDTrainerExample { // Start ignite grid. try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { System.out.println(">>> Ignite grid started."); - // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread - // because we create ignite cache internally. IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), LinearRegressionSGDTrainerExample.class.getSimpleName(), () -> { http://git-wip-us.apache.org/repos/asf/ignite/blob/c4d859c9/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/LogisticRegressionSGDTrainerSample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/LogisticRegressionSGDTrainerSample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/LogisticRegressionSGDTrainerSample.java new file mode 100644 index 0000000..0505ddd --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/LogisticRegressionSGDTrainerSample.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.regression.logistic; + +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.cache.query.QueryCursor; +import org.apache.ignite.cache.query.ScanQuery; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.nn.UpdatesStrategy; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; +import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; +import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; +import org.apache.ignite.thread.IgniteThread; + +import javax.cache.Cache; +import java.util.Arrays; +import java.util.UUID; + +/** + * Run logistic regression model over distributed cache. + * + * @see LogisticRegressionSGDTrainer + */ +public class LogisticRegressionSGDTrainerSample { + /** Run example. */ + public static void main(String[] args) throws InterruptedException { + System.out.println(); + System.out.println(">>> Logistic regression model over partitioned dataset usage example started."); + // Start ignite grid. 
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + LogisticRegressionSGDTrainerSample.class.getSimpleName(), () -> { + + IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); + + System.out.println(">>> Create new logistic regression trainer object."); + LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>( + new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, + SimpleGDParameterUpdate::avg + ), 100000, 10, 100, 123L); + + System.out.println(">>> Perform the training to get the model."); + LogisticRegressionModel mdl = trainer.fit( + ignite, + dataCache, + (k, v) -> Arrays.copyOfRange(v, 1, v.length), + (k, v) -> v[0] + ).withRawLabels(true); + + System.out.println(">>> Logistic regression model: " + mdl); + + int amountOfErrors = 0; + int totalAmount = 0; + + // Build confusion matrix. 
See https://en.wikipedia.org/wiki/Confusion_matrix + int[][] confusionMtx = {{0, 0}, {0, 0}}; + + try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, double[]> observation : observations) { + double[] val = observation.getValue(); + double[] inputs = Arrays.copyOfRange(val, 1, val.length); + double groundTruth = val[0]; + + double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs)); + + totalAmount++; + if(groundTruth != prediction) + amountOfErrors++; + + int idx1 = (int)prediction; + int idx2 = (int)groundTruth; + + confusionMtx[idx1][idx2]++; + + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); + } + + System.out.println(">>> ---------------------------------"); + + System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); + System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); + } + + System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); + System.out.println(">>> ---------------------------------"); + }); + + igniteThread.start(); + + igniteThread.join(); + } + } + /** + * Fills cache with data and returns it. + * + * @param ignite Ignite instance. + * @return Filled Ignite Cache. + */ + private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) { + CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>(); + cacheConfiguration.setName("TEST_" + UUID.randomUUID()); + cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10)); + + IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration); + + for (int i = 0; i < data.length; i++) + cache.put(i, data[i]); + + return cache; + } + + + /** The 1st and 2nd classes from the Iris dataset. 
*/ + private static final double[][] data = { + {0, 5.1, 3.5, 1.4, 0.2}, + {0, 4.9, 3, 1.4, 0.2}, + {0, 4.7, 3.2, 1.3, 0.2}, + {0, 4.6, 3.1, 1.5, 0.2}, + {0, 5, 3.6, 1.4, 0.2}, + {0, 5.4, 3.9, 1.7, 0.4}, + {0, 4.6, 3.4, 1.4, 0.3}, + {0, 5, 3.4, 1.5, 0.2}, + {0, 4.4, 2.9, 1.4, 0.2}, + {0, 4.9, 3.1, 1.5, 0.1}, + {0, 5.4, 3.7, 1.5, 0.2}, + {0, 4.8, 3.4, 1.6, 0.2}, + {0, 4.8, 3, 1.4, 0.1}, + {0, 4.3, 3, 1.1, 0.1}, + {0, 5.8, 4, 1.2, 0.2}, + {0, 5.7, 4.4, 1.5, 0.4}, + {0, 5.4, 3.9, 1.3, 0.4}, + {0, 5.1, 3.5, 1.4, 0.3}, + {0, 5.7, 3.8, 1.7, 0.3}, + {0, 5.1, 3.8, 1.5, 0.3}, + {0, 5.4, 3.4, 1.7, 0.2}, + {0, 5.1, 3.7, 1.5, 0.4}, + {0, 4.6, 3.6, 1, 0.2}, + {0, 5.1, 3.3, 1.7, 0.5}, + {0, 4.8, 3.4, 1.9, 0.2}, + {0, 5, 3, 1.6, 0.2}, + {0, 5, 3.4, 1.6, 0.4}, + {0, 5.2, 3.5, 1.5, 0.2}, + {0, 5.2, 3.4, 1.4, 0.2}, + {0, 4.7, 3.2, 1.6, 0.2}, + {0, 4.8, 3.1, 1.6, 0.2}, + {0, 5.4, 3.4, 1.5, 0.4}, + {0, 5.2, 4.1, 1.5, 0.1}, + {0, 5.5, 4.2, 1.4, 0.2}, + {0, 4.9, 3.1, 1.5, 0.1}, + {0, 5, 3.2, 1.2, 0.2}, + {0, 5.5, 3.5, 1.3, 0.2}, + {0, 4.9, 3.1, 1.5, 0.1}, + {0, 4.4, 3, 1.3, 0.2}, + {0, 5.1, 3.4, 1.5, 0.2}, + {0, 5, 3.5, 1.3, 0.3}, + {0, 4.5, 2.3, 1.3, 0.3}, + {0, 4.4, 3.2, 1.3, 0.2}, + {0, 5, 3.5, 1.6, 0.6}, + {0, 5.1, 3.8, 1.9, 0.4}, + {0, 4.8, 3, 1.4, 0.3}, + {0, 5.1, 3.8, 1.6, 0.2}, + {0, 4.6, 3.2, 1.4, 0.2}, + {0, 5.3, 3.7, 1.5, 0.2}, + {0, 5, 3.3, 1.4, 0.2}, + {1, 7, 3.2, 4.7, 1.4}, + {1, 6.4, 3.2, 4.5, 1.5}, + {1, 6.9, 3.1, 4.9, 1.5}, + {1, 5.5, 2.3, 4, 1.3}, + {1, 6.5, 2.8, 4.6, 1.5}, + {1, 5.7, 2.8, 4.5, 1.3}, + {1, 6.3, 3.3, 4.7, 1.6}, + {1, 4.9, 2.4, 3.3, 1}, + {1, 6.6, 2.9, 4.6, 1.3}, + {1, 5.2, 2.7, 3.9, 1.4}, + {1, 5, 2, 3.5, 1}, + {1, 5.9, 3, 4.2, 1.5}, + {1, 6, 2.2, 4, 1}, + {1, 6.1, 2.9, 4.7, 1.4}, + {1, 5.6, 2.9, 3.6, 1.3}, + {1, 6.7, 3.1, 4.4, 1.4}, + {1, 5.6, 3, 4.5, 1.5}, + {1, 5.8, 2.7, 4.1, 1}, + {1, 6.2, 2.2, 4.5, 1.5}, + {1, 5.6, 2.5, 3.9, 1.1}, + {1, 5.9, 3.2, 4.8, 1.8}, + {1, 6.1, 2.8, 4, 1.3}, + {1, 6.3, 2.5, 4.9, 1.5}, + {1, 6.1, 2.8, 4.7, 1.2}, + {1, 6.4, 
2.9, 4.3, 1.3}, + {1, 6.6, 3, 4.4, 1.4}, + {1, 6.8, 2.8, 4.8, 1.4}, + {1, 6.7, 3, 5, 1.7}, + {1, 6, 2.9, 4.5, 1.5}, + {1, 5.7, 2.6, 3.5, 1}, + {1, 5.5, 2.4, 3.8, 1.1}, + {1, 5.5, 2.4, 3.7, 1}, + {1, 5.8, 2.7, 3.9, 1.2}, + {1, 6, 2.7, 5.1, 1.6}, + {1, 5.4, 3, 4.5, 1.5}, + {1, 6, 3.4, 4.5, 1.6}, + {1, 6.7, 3.1, 4.7, 1.5}, + {1, 6.3, 2.3, 4.4, 1.3}, + {1, 5.6, 3, 4.1, 1.3}, + {1, 5.5, 2.5, 4, 1.3}, + {1, 5.5, 2.6, 4.4, 1.2}, + {1, 6.1, 3, 4.6, 1.4}, + {1, 5.8, 2.6, 4, 1.2}, + {1, 5, 2.3, 3.3, 1}, + {1, 5.6, 2.7, 4.2, 1.3}, + {1, 5.7, 3, 4.2, 1.2}, + {1, 5.7, 2.9, 4.2, 1.3}, + {1, 6.2, 2.9, 4.3, 1.3}, + {1, 5.1, 2.5, 3, 1.1}, + {1, 5.7, 2.8, 4.1, 1.3}, + }; + +} http://git-wip-us.apache.org/repos/asf/ignite/blob/c4d859c9/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/package-info.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/package-info.java new file mode 100644 index 0000000..cf27a94 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * ML logistic regression examples. + */ +package org.apache.ignite.examples.ml.regression.logistic; http://git-wip-us.apache.org/repos/asf/ignite/blob/c4d859c9/modules/ml/src/main/java/org/apache/ignite/ml/optimization/LossFunctions.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/LossFunctions.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/LossFunctions.java index 13fcb60..a0e8c66 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/LossFunctions.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/LossFunctions.java @@ -44,4 +44,79 @@ public class LossFunctions { }).sum() / (vector.size()); } }; + /** + * Log loss function. + */ + public static IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> LOG = groundTruth -> + new IgniteDifferentiableVectorToDoubleFunction() { + /** {@inheritDoc} */ + @Override public Vector differential(Vector pnt) { + double multiplier = 2.0 / pnt.size(); + return pnt.minus(groundTruth).times(multiplier); + } + + /** {@inheritDoc} */ + @Override public Double apply(Vector vector) { + return groundTruth.copy().map(vector, + (a, b) -> a == 1 ? - Math.log(b) : -Math.log(1 - b) + ).sum(); + } + }; + + /** + * L2 loss function. 
+ */ + public static IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> L2 = groundTruth -> + new IgniteDifferentiableVectorToDoubleFunction() { + /** {@inheritDoc} */ + @Override public Vector differential(Vector pnt) { + double multiplier = 2.0 / pnt.size(); + return pnt.minus(groundTruth).times(multiplier); + } + + /** {@inheritDoc} */ + @Override public Double apply(Vector vector) { + return groundTruth.copy().map(vector, (a, b) -> { + double diff = a - b; + return diff * diff; + }).sum(); + } + }; + + /** + * L1 loss function. + */ + public static IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> L1 = groundTruth -> + new IgniteDifferentiableVectorToDoubleFunction() { + /** {@inheritDoc} */ + @Override public Vector differential(Vector pnt) { + double multiplier = 2.0 / pnt.size(); + return pnt.minus(groundTruth).times(multiplier); + } + + /** {@inheritDoc} */ + @Override public Double apply(Vector vector) { + return groundTruth.copy().map(vector, (a, b) -> { + double diff = a - b; + return Math.abs(diff); + }).sum(); + } + }; + + /** + * Hinge loss function. 
+ */ + public static IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> HINGE = groundTruth -> + new IgniteDifferentiableVectorToDoubleFunction() { + /** {@inheritDoc} */ + @Override public Vector differential(Vector pnt) { + double multiplier = 2.0 / pnt.size(); + return pnt.minus(groundTruth).times(multiplier); + } + + /** {@inheritDoc} */ + @Override public Double apply(Vector vector) { + return Math.max(0, 1 - groundTruth.dot(vector)); + } + }; } http://git-wip-us.apache.org/repos/asf/ignite/blob/c4d859c9/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionModel.java new file mode 100644 index 0000000..8ea1490 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionModel.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.regressions.logistic.binomial; + +import org.apache.ignite.ml.Exportable; +import org.apache.ignite.ml.Exporter; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.math.Vector; + +import java.io.Serializable; +import java.util.Objects; + +/** + * Logistic regression (logit model) is a generalized linear model used for binomial regression. + */ +public class LogisticRegressionModel implements Model<Vector, Double>, Exportable<LogisticRegressionModel>, Serializable { + /** */ + private static final long serialVersionUID = -133984600091550776L; + + /** Multiplier of the objects's vector required to make prediction. */ + private Vector weights; + + /** Intercept of the linear regression model. */ + private double intercept; + + /** Output label format. 0 and 1 for false value and raw sigmoid regression value otherwise. */ + private boolean isKeepingRawLabels = false; + + /** Threshold to assign '1' label to the observation if raw value more than this threshold. */ + private double threshold = 0.5; + + /** */ + public LogisticRegressionModel(Vector weights, double intercept) { + this.weights = weights; + this.intercept = intercept; + } + + /** + * Set up the output label format. + * + * @param isKeepingRawLabels The parameter value. + * @return Model with new isKeepingRawLabels parameter value. + */ + public LogisticRegressionModel withRawLabels(boolean isKeepingRawLabels) { + this.isKeepingRawLabels = isKeepingRawLabels; + return this; + } + + /** + * Set up the threshold. + * + * @param threshold The parameter value. + * @return Model with new threshold parameter value. + */ + public LogisticRegressionModel withThreshold(double threshold) { + this.threshold = threshold; + return this; + } + + /** + * Set up the weights. + * + * @param weights The parameter value. + * @return Model with new weights parameter value. 
+ */ + public LogisticRegressionModel withWeights(Vector weights) { + this.weights = weights; + return this; + } + + /** + * Set up the intercept. + * + * @param intercept The parameter value. + * @return Model with new intercept parameter value. + */ + public LogisticRegressionModel withIntercept(double intercept) { + this.intercept = intercept; + return this; + } + + /** + * Gets the output label format mode. + * + * @return The parameter value. + */ + public boolean isKeepingRawLabels() { + return isKeepingRawLabels; + } + + /** + * Gets the threshold. + * + * @return The parameter value. + */ + public double threshold() { + return threshold; + } + + /** + * Gets the weights. + * + * @return The parameter value. + */ + public Vector weights() { + return weights; + } + + /** + * Gets the intercept. + * + * @return The parameter value. + */ + public double intercept() { + return intercept; + } + + /** {@inheritDoc} */ + @Override public Double apply(Vector input) { + final double res = sigmoid(input.dot(weights) + intercept); + + if (isKeepingRawLabels) + return res; + else + return res - threshold > 0 ? 1.0 : 0; + } + + /** + * Sigmoid function. + * @param z The regression value. + * @return The result. 
+ */ + private static double sigmoid(double z) { + return 1.0 / (1.0 + Math.exp(-z)); + } + + /** {@inheritDoc} */ + @Override public <P> void saveModel(Exporter<LogisticRegressionModel, P> exporter, P path) { + exporter.save(this, path); + } + + /** {@inheritDoc} */ + @Override public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + LogisticRegressionModel mdl = (LogisticRegressionModel)o; + + return Double.compare(mdl.intercept, intercept) == 0 + && Double.compare(mdl.threshold, threshold) == 0 + && Boolean.compare(mdl.isKeepingRawLabels, isKeepingRawLabels) == 0 + && Objects.equals(weights, mdl.weights); + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + return Objects.hash(weights, intercept, isKeepingRawLabels, threshold); + } + + /** {@inheritDoc} */ + @Override public String toString() { + if (weights.size() < 20) { + StringBuilder builder = new StringBuilder(); + + for (int i = 0; i < weights.size(); i++) { + double nextItem = i == weights.size() - 1 ? intercept : weights.get(i + 1); + + builder.append(String.format("%.4f", Math.abs(weights.get(i)))) + .append("*x") + .append(i) + .append(nextItem > 0 ? 
" + " : " - "); + } + + builder.append(String.format("%.4f", Math.abs(intercept))); + return builder.toString(); + } + + return "LogisticRegressionModel{" + + "weights=" + weights + + ", intercept=" + intercept + + '}'; + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/c4d859c9/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java new file mode 100644 index 0000000..8fe57cf --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.regressions.logistic.binomial; + +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.nn.Activators; +import org.apache.ignite.ml.nn.MLPTrainer; +import org.apache.ignite.ml.nn.MultilayerPerceptron; +import org.apache.ignite.ml.nn.UpdatesStrategy; +import org.apache.ignite.ml.nn.architecture.MLPArchitecture; +import org.apache.ignite.ml.optimization.LossFunctions; +import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; + +import java.io.Serializable; +import java.util.Arrays; + +/** + * Trainer of the logistic regression model based on stochastic gradient descent algorithm. + */ +public class LogisticRegressionSGDTrainer<P extends Serializable> implements SingleLabelDatasetTrainer<LogisticRegressionModel> { + /** Update strategy. */ + private final UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy; + + /** Max number of iteration. */ + private final int maxIterations; + + /** Batch size. */ + private final int batchSize; + + /** Number of local iterations. */ + private final int locIterations; + + /** Seed for random generator. */ + private final long seed; + + /** + * Constructs a new instance of linear regression SGD trainer. + * + * @param updatesStgy Update strategy. + * @param maxIterations Max number of iteration. + * @param batchSize Batch size. + * @param locIterations Number of local iterations. + * @param seed Seed for random generator. + */ + public LogisticRegressionSGDTrainer(UpdatesStrategy<? 
super MultilayerPerceptron, P> updatesStgy, int maxIterations, + int batchSize, int locIterations, long seed) { + this.updatesStgy = updatesStgy; + this.maxIterations = maxIterations; + this.batchSize = batchSize; + this.locIterations = locIterations; + this.seed = seed; + } + + /** {@inheritDoc} */ + @Override public <K, V> LogisticRegressionModel fit(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + + IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> { + + int cols = dataset.compute(data -> { + if (data.getFeatures() == null) + return null; + return data.getFeatures().length / data.getRows(); + }, (a, b) -> a == null ? b : a); + + MLPArchitecture architecture = new MLPArchitecture(cols); + architecture = architecture.withAddedLayer(1, true, Activators.SIGMOID); + + return architecture; + }; + + MLPTrainer<?> trainer = new MLPTrainer<>( + archSupplier, + LossFunctions.L2, + updatesStgy, + maxIterations, + batchSize, + locIterations, + seed + ); + + MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, (k, v) -> new double[]{lbExtractor.apply(k, v)}); + + double[] params = mlp.parameters().getStorage().data(); + + return new LogisticRegressionModel(new DenseLocalOnHeapVector(Arrays.copyOf(params, params.length - 1)), + params[params.length - 1] + ); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/c4d859c9/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/package-info.java new file mode 100644 index 0000000..d32b1ee --- /dev/null +++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains binomial logistic regression. + */ +package org.apache.ignite.ml.regressions.logistic.binomial; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/c4d859c9/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/package-info.java new file mode 100644 index 0000000..b1f8331 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains various logistic regressions. + */ +package org.apache.ignite.ml.regressions.logistic; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/c4d859c9/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java index 5005ef2..2d21d3b 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java @@ -17,7 +17,11 @@ package org.apache.ignite.ml.regressions; -import org.apache.ignite.ml.regressions.linear.*; +import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainerTest; +import org.apache.ignite.ml.regressions.linear.LinearRegressionModelTest; +import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainerTest; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModelTest; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainerTest; import org.junit.runner.RunWith; import org.junit.runners.Suite; @@ -28,7 +32,9 @@ import org.junit.runners.Suite; 
@Suite.SuiteClasses({ LinearRegressionModelTest.class, LinearRegressionLSQRTrainerTest.class, - LinearRegressionSGDTrainerTest.class + LinearRegressionSGDTrainerTest.class, + LogisticRegressionModelTest.class, + LogisticRegressionSGDTrainerTest.class }) public class RegressionsTestSuite { // No-op. http://git-wip-us.apache.org/repos/asf/ignite/blob/c4d859c9/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java new file mode 100644 index 0000000..1268a7d --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.regressions.logistic; + +import org.apache.ignite.ml.TestUtils; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.exceptions.CardinalityException; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; +import org.junit.Test; + +/** + * Tests for {@link LogisticRegressionModel}. + */ +public class LogisticRegressionModelTest { + /** */ + private static final double PRECISION = 1e-6; + + /** */ + @Test + public void testPredict() { + Vector weights = new DenseLocalOnHeapVector(new double[]{2.0, 3.0}); + LogisticRegressionModel mdl = new LogisticRegressionModel(weights, 1.0).withRawLabels(true); + + Vector observation = new DenseLocalOnHeapVector(new double[]{1.0, 1.0}); + TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION); + + observation = new DenseLocalOnHeapVector(new double[]{2.0, 1.0}); + TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 2.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION); + + observation = new DenseLocalOnHeapVector(new double[]{1.0, 2.0}); + TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 + 3.0 * 2.0), mdl.apply(observation), PRECISION); + + observation = new DenseLocalOnHeapVector(new double[]{-2.0, 1.0}); + TestUtils.assertEquals(sigmoid(1.0 - 2.0 * 2.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION); + + observation = new DenseLocalOnHeapVector(new double[]{1.0, -2.0}); + TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 - 3.0 * 2.0), mdl.apply(observation), PRECISION); + } + + /** */ + @Test(expected = CardinalityException.class) + public void testPredictOnAnObservationWithWrongCardinality() { + Vector weights = new DenseLocalOnHeapVector(new double[]{2.0, 3.0}); + + LogisticRegressionModel mdl = new LogisticRegressionModel(weights, 1.0); + + Vector observation = new DenseLocalOnHeapVector(new double[]{1.0}); + + mdl.apply(observation); + } + 
+ /** + * Sigmoid function. + * @param z The regression value. + * @return The result. + */ + private static double sigmoid(double z) { + return 1.0 / (1.0 + Math.exp(-z)); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/c4d859c9/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java new file mode 100644 index 0000000..27d3a30e --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.ignite.ml.regressions.logistic;

import org.apache.ignite.ml.TestUtils;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;

/**
 * Tests for {@link LogisticRegressionSGDTrainer}.
 */
@RunWith(Parameterized.class)
public class LogisticRegressionSGDTrainerTest {
    /** Fixed size of Dataset. */
    private static final int AMOUNT_OF_OBSERVATIONS = 1000;

    /** Fixed number of feature columns in Dataset. */
    private static final int AMOUNT_OF_FEATURES = 2;

    /** Precision in test checks. */
    private static final double PRECISION = 1e-2;

    /** Parameters. */
    @Parameterized.Parameters(name = "Data divided on {0} partitions")
    public static Iterable<Integer[]> data() {
        return Arrays.asList(
            new Integer[] {1},
            new Integer[] {2},
            new Integer[] {3},
            new Integer[] {5},
            new Integer[] {7},
            new Integer[] {100}
        );
    }

    /** Number of partitions. */
    @Parameterized.Parameter
    public int parts;

    /**
     * Test trainer on the linearly separable classification problem y &gt; x
     * (label is 1 iff y - x &gt; 0).
     */
    @Test
    public void trainWithTheLinearlySeparableCase() {
        Map<Integer, double[]> data = new HashMap<>();

        // A single generator suffices: ThreadLocalRandom.current() always returns the calling
        // thread's instance, so the original rndX/rndY pair referenced the same object.
        ThreadLocalRandom rnd = ThreadLocalRandom.current();

        for (int i = 0; i < AMOUNT_OF_OBSERVATIONS; i++) {
            double x = rnd.nextDouble(-1000, 1000);
            double y = rnd.nextDouble(-1000, 1000);

            double[] vec = new double[AMOUNT_OF_FEATURES + 1];

            vec[0] = y - x > 0 ? 1 : 0; // Assign label: 1 above the diagonal, 0 below.
            vec[1] = x;
            vec[2] = y;

            data.put(i, vec);
        }

        LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
            new SimpleGDUpdateCalculator(0.2),
            SimpleGDParameterUpdate::sumLocal,
            SimpleGDParameterUpdate::avg
        ), 100000, 10, 100, 123L);

        LogisticRegressionModel mdl = trainer.fit(
            data,
            10,
            (k, v) -> Arrays.copyOfRange(v, 1, v.length),
            (k, v) -> v[0]
        );

        // Points far on either side of the diagonal must be classified with high confidence.
        TestUtils.assertEquals(0, mdl.apply(new DenseLocalOnHeapVector(new double[]{100, 10})), PRECISION);
        TestUtils.assertEquals(1, mdl.apply(new DenseLocalOnHeapVector(new double[]{10, 100})), PRECISION);
    }
}
