IGNITE-9412: [ML] GDB convergence by error support. this closes #4670
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/6225c56e Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/6225c56e Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/6225c56e Branch: refs/heads/master Commit: 6225c56ea70e0af26eee51c7e5e7e53af93386ca Parents: ed6bf5a Author: Alexey Platonov <[email protected]> Authored: Thu Sep 6 12:08:36 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Thu Sep 6 12:08:36 2018 +0300 ---------------------------------------------------------------------- .../GDBOnTreesClassificationTrainerExample.java | 10 +- .../GDBOnTreesRegressionTrainerExample.java | 4 +- .../boosting/GDBBinaryClassifierTrainer.java | 57 +++--- .../boosting/GDBLearningStrategy.java | 115 ++++++++++-- .../boosting/GDBRegressionTrainer.java | 8 +- .../ml/composition/boosting/GDBTrainer.java | 122 ++++++++---- .../LossGradientPerPredictionFunctions.java | 33 ---- .../convergence/ConvergenceChecker.java | 140 ++++++++++++++ .../convergence/ConvergenceCheckerFactory.java | 58 ++++++ .../mean/MeanAbsValueConvergenceChecker.java | 116 ++++++++++++ .../MeanAbsValueConvergenceCheckerFactory.java | 47 +++++ .../boosting/convergence/mean/package-info.java | 22 +++ .../MedianOfMedianConvergenceChecker.java | 126 +++++++++++++ ...MedianOfMedianConvergenceCheckerFactory.java | 47 +++++ .../convergence/median/package-info.java | 22 +++ .../boosting/convergence/package-info.java | 24 +++ .../simple/ConvergenceCheckerStub.java | 79 ++++++++ .../simple/ConvergenceCheckerStubFactory.java | 48 +++++ .../convergence/simple/package-info.java | 24 +++ .../ml/composition/boosting/loss/LogLoss.java | 36 ++++ .../ml/composition/boosting/loss/Loss.java | 45 +++++ .../composition/boosting/loss/SquaredError.java | 36 ++++ .../composition/boosting/loss/package-info.java | 22 +++ .../WeightedPredictionsAggregator.java | 10 + .../FeatureMatrixWithLabelsOnHeapData.java | 57 ++++++ 
...eatureMatrixWithLabelsOnHeapDataBuilder.java | 76 ++++++++ .../boosting/GDBOnTreesLearningStrategy.java | 24 ++- .../ignite/ml/tree/data/DecisionTreeData.java | 36 ++-- .../ml/composition/boosting/GDBTrainerTest.java | 81 +++++++- .../convergence/ConvergenceCheckerTest.java | 82 ++++++++ .../MeanAbsValueConvergenceCheckerTest.java | 73 ++++++++ .../MedianOfMedianConvergenceCheckerTest.java | 57 ++++++ .../ml/environment/LearningEnvironmentTest.java | 187 +++---------------- .../ignite/ml/knn/ANNClassificationTest.java | 12 +- 34 files changed, 1608 insertions(+), 328 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java index 075eab2..e092e5c 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java @@ -22,9 +22,8 @@ import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.ml.Model; import org.apache.ignite.ml.composition.ModelsComposition; -import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import 
org.apache.ignite.ml.trainers.DatasetTrainer; import org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer; @@ -59,10 +58,11 @@ public class GDBOnTreesClassificationTrainerExample { IgniteCache<Integer, double[]> trainingSet = fillTrainingData(ignite, trainingSetCfg); // Create regression trainer. - DatasetTrainer<ModelsComposition, Double> trainer = new GDBBinaryClassifierOnTreesTrainer(1.0, 300, 2, 0.); + DatasetTrainer<ModelsComposition, Double> trainer = new GDBBinaryClassifierOnTreesTrainer(1.0, 300, 2, 0.) + .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.1)); // Train decision tree model. - Model<Vector, Double> mdl = trainer.fit( + ModelsComposition mdl = trainer.fit( ignite, trainingSet, (k, v) -> VectorUtils.of(v[0]), @@ -81,6 +81,8 @@ public class GDBOnTreesClassificationTrainerExample { } System.out.println(">>> ---------------------------------"); + System.out.println(">>> Count of trees = " + mdl.getModels().size()); + System.out.println(">>> ---------------------------------"); System.out.println(">>> GDB classification trainer example completed."); }); http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java index b2b08d0..3662973 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java @@ -24,6 +24,7 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import 
org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.ml.Model; import org.apache.ignite.ml.composition.ModelsComposition; +import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.trainers.DatasetTrainer; @@ -59,7 +60,8 @@ public class GDBOnTreesRegressionTrainerExample { IgniteCache<Integer, double[]> trainingSet = fillTrainingData(ignite, trainingSetCfg); // Create regression trainer. - DatasetTrainer<ModelsComposition, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.); + DatasetTrainer<ModelsComposition, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.) + .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.001)); // Train decision tree model. Model<Vector, Double> mdl = trainer.fit( http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java index 3701557..f6ddfed 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java @@ -19,24 +19,23 @@ package org.apache.ignite.ml.composition.boosting; import java.util.ArrayList; import java.util.Arrays; -import java.util.List; import java.util.Set; import java.util.stream.Collectors; -import org.apache.ignite.internal.util.typedef.internal.A; +import org.apache.ignite.ml.composition.boosting.loss.LogLoss; +import 
org.apache.ignite.ml.composition.boosting.loss.Loss; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.math.functions.IgniteTriFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.structures.LabeledVector; import org.apache.ignite.ml.structures.LabeledVectorSet; import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap; /** - * Trainer for binary classifier using Gradient Boosting. - * As preparing stage this algorithm learn labels in dataset and create mapping dataset labels to 0 and 1. - * This algorithm uses gradient of Logarithmic Loss metric [LogLoss] by default in each step of learning. + * Trainer for binary classifier using Gradient Boosting. As preparing stage this algorithm learn labels in dataset and + * create mapping dataset labels to 0 and 1. This algorithm uses gradient of Logarithmic Loss metric [LogLoss] by + * default in each step of learning. */ public abstract class GDBBinaryClassifierTrainer extends GDBTrainer { /** External representation of first class. */ @@ -51,9 +50,7 @@ public abstract class GDBBinaryClassifierTrainer extends GDBTrainer { * @param cntOfIterations Count of learning iterations. */ public GDBBinaryClassifierTrainer(double gradStepSize, Integer cntOfIterations) { - super(gradStepSize, - cntOfIterations, - LossGradientPerPredictionFunctions.LOG_LOSS); + super(gradStepSize, cntOfIterations, new LogLoss()); } /** @@ -61,35 +58,37 @@ public abstract class GDBBinaryClassifierTrainer extends GDBTrainer { * * @param gradStepSize Grad step size. * @param cntOfIterations Count of learning iterations. - * @param lossGradient Gradient of loss function. 
First argument is sample size, second argument is valid answer, third argument is current model prediction. + * @param loss Loss function. */ - public GDBBinaryClassifierTrainer(double gradStepSize, - Integer cntOfIterations, - IgniteTriFunction<Long, Double, Double, Double> lossGradient) { - - super(gradStepSize, cntOfIterations, lossGradient); + public GDBBinaryClassifierTrainer(double gradStepSize, Integer cntOfIterations, Loss loss) { + super(gradStepSize, cntOfIterations, loss); } /** {@inheritDoc} */ - @Override protected <V, K> void learnLabels(DatasetBuilder<K, V> builder, IgniteBiFunction<K, V, Vector> featureExtractor, + @Override protected <V, K> boolean learnLabels(DatasetBuilder<K, V> builder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lExtractor) { - List<Double> uniqLabels = new ArrayList<Double>( - builder.build(new EmptyContextBuilder<>(), new LabeledDatasetPartitionDataBuilderOnHeap<>(featureExtractor, lExtractor)) - .compute((IgniteFunction<LabeledVectorSet<Double,LabeledVector>, Set<Double>>) x -> + Set<Double> uniqLabels = builder.build(new EmptyContextBuilder<>(), new LabeledDatasetPartitionDataBuilderOnHeap<>(featureExtractor, lExtractor)) + .compute((IgniteFunction<LabeledVectorSet<Double, LabeledVector>, Set<Double>>)x -> Arrays.stream(x.labels()).boxed().collect(Collectors.toSet()), (a, b) -> { - if (a == null) - return b; - if (b == null) - return a; - a.addAll(b); + if (a == null) + return b; + if (b == null) return a; - } - )); + a.addAll(b); + return a; + } + ); - A.ensure(uniqLabels.size() == 2, "Binary classifier expects two types of labels in learning dataset"); - externalFirstCls = uniqLabels.get(0); - externalSecondCls = uniqLabels.get(1); + if (uniqLabels != null && uniqLabels.size() == 2) { + ArrayList<Double> lblsArray = new ArrayList<>(uniqLabels); + externalFirstCls = lblsArray.get(0); + externalSecondCls = lblsArray.get(1); + return true; + } else { + return false; + } } /** 
{@inheritDoc} */ http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java index 375748a..737495e 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java @@ -22,6 +22,10 @@ import java.util.Arrays; import java.util.List; import org.apache.ignite.ml.Model; import org.apache.ignite.ml.composition.ModelsComposition; +import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker; +import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory; +import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory; +import org.apache.ignite.ml.composition.boosting.loss.Loss; import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.environment.LearningEnvironment; @@ -29,9 +33,9 @@ import org.apache.ignite.ml.environment.logging.MLLogger; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.math.functions.IgniteSupplier; -import org.apache.ignite.ml.math.functions.IgniteTriFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.trainers.DatasetTrainer; +import org.jetbrains.annotations.NotNull; /** * Learning strategy for gradient boosting. @@ -44,7 +48,7 @@ public class GDBLearningStrategy { protected int cntOfIterations; /** Loss of gradient. 
*/ - protected IgniteTriFunction<Long, Double, Double, Double> lossGradient; + protected Loss loss; /** External label to internal mapping. */ protected IgniteFunction<Double, Double> externalLbToInternalMapping; @@ -61,9 +65,15 @@ public class GDBLearningStrategy { /** Composition weights. */ protected double[] compositionWeights; + /** Check convergence strategy factory. */ + protected ConvergenceCheckerFactory checkConvergenceStgyFactory = new MeanAbsValueConvergenceCheckerFactory(0.001); + + /** Default gradient step size. */ + private double defaultGradStepSize; + /** - * Implementation of gradient boosting iterations. At each step of iterations this algorithm - * build a regression model based on gradient of loss-function for current models composition. + * Implementation of gradient boosting iterations. At each step of iterations this algorithm build a regression + * model based on gradient of loss-function for current models composition. * * @param datasetBuilder Dataset builder. * @param featureExtractor Feature extractor. @@ -73,18 +83,43 @@ public class GDBLearningStrategy { public <K, V> List<Model<Vector, Double>> learnModels(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { - List<Model<Vector, Double>> models = new ArrayList<>(); + return update(null, datasetBuilder, featureExtractor, lbExtractor); + } + + /** + * Gets state of model in arguments, compare it with training parameters of trainer and if they are fit then + * trainer updates model in according to new data and return new model. In other case trains new model. + * + * @param mdlToUpdate Learned model. + * @param datasetBuilder Dataset builder. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + * @return Updated models list. 
+ */ + public <K,V> List<Model<Vector, Double>> update(GDBTrainer.GDBModel mdlToUpdate, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + + List<Model<Vector, Double>> models = initLearningState(mdlToUpdate); + + ConvergenceChecker<K, V> convCheck = checkConvergenceStgyFactory.create(sampleSize, + externalLbToInternalMapping, loss, datasetBuilder, featureExtractor, lbExtractor); + DatasetTrainer<? extends Model<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get(); for (int i = 0; i < cntOfIterations; i++) { - double[] weights = Arrays.copyOf(compositionWeights, i); + double[] weights = Arrays.copyOf(compositionWeights, models.size()); WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLabelValue); - Model<Vector, Double> currComposition = new ModelsComposition(models, aggregator); + ModelsComposition currComposition = new ModelsComposition(models, aggregator); + if (convCheck.isConverged(datasetBuilder, currComposition)) + break; IgniteBiFunction<K, V, Double> lbExtractorWrap = (k, v) -> { Double realAnswer = externalLbToInternalMapping.apply(lbExtractor.apply(k, v)); Double mdlAnswer = currComposition.apply(featureExtractor.apply(k, v)); - return -lossGradient.apply(sampleSize, realAnswer, mdlAnswer); + return -loss.gradient(sampleSize, realAnswer, mdlAnswer); }; long startTs = System.currentTimeMillis(); @@ -97,6 +132,29 @@ public class GDBLearningStrategy { } /** + * Restores state of already learned model if can and sets learning parameters according to this state. + * + * @param mdlToUpdate Model to update. + * @return list of already learned models. 
+ */ + @NotNull protected List<Model<Vector, Double>> initLearningState(GDBTrainer.GDBModel mdlToUpdate) { + List<Model<Vector, Double>> models = new ArrayList<>(); + if(mdlToUpdate != null) { + models.addAll(mdlToUpdate.getModels()); + WeightedPredictionsAggregator aggregator = (WeightedPredictionsAggregator) mdlToUpdate.getPredictionsAggregator(); + meanLabelValue = aggregator.getBias(); + compositionWeights = new double[models.size() + cntOfIterations]; + for(int i = 0; i < models.size(); i++) + compositionWeights[i] = aggregator.getWeights()[i]; + } else { + compositionWeights = new double[cntOfIterations]; + } + + Arrays.fill(compositionWeights, models.size(), compositionWeights.length, defaultGradStepSize); + return models; + } + + /** * Sets learning environment. * * @param environment Learning Environment. @@ -117,12 +175,12 @@ public class GDBLearningStrategy { } /** - * Sets gradient of loss function. + * Loss function. * - * @param lossGradient Loss gradient. + * @param loss Loss function. */ - public GDBLearningStrategy withLossGradient(IgniteTriFunction<Long, Double, Double, Double> lossGradient) { - this.lossGradient = lossGradient; + public GDBLearningStrategy withLossGradient(Loss loss) { + this.loss = loss; return this; } @@ -141,7 +199,8 @@ public class GDBLearningStrategy { * * @param buildBaseMdlTrainer Build base model trainer. */ - public GDBLearningStrategy withBaseModelTrainerBuilder(IgniteSupplier<DatasetTrainer<? extends Model<Vector, Double>, Double>> buildBaseMdlTrainer) { + public GDBLearningStrategy withBaseModelTrainerBuilder( + IgniteSupplier<DatasetTrainer<? extends Model<Vector, Double>, Double>> buildBaseMdlTrainer) { this.baseMdlTrainerBuilder = buildBaseMdlTrainer; return this; } @@ -175,4 +234,34 @@ public class GDBLearningStrategy { this.compositionWeights = compositionWeights; return this; } + + /** + * Sets CheckConvergenceStgyFactory. + * + * @param factory Factory. 
+ */ + public GDBLearningStrategy withCheckConvergenceStgyFactory(ConvergenceCheckerFactory factory) { + this.checkConvergenceStgyFactory = factory; + return this; + } + + /** + * Sets default gradient step size. + * + * @param defaultGradStepSize Default gradient step size. + */ + public GDBLearningStrategy withDefaultGradStepSize(double defaultGradStepSize) { + this.defaultGradStepSize = defaultGradStepSize; + return this; + } + + /** */ + public double[] getCompositionWeights() { + return compositionWeights; + } + + /** */ + public double getMeanValue() { + return meanLabelValue; + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBRegressionTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBRegressionTrainer.java index 201586e..8c1afd7 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBRegressionTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBRegressionTrainer.java @@ -17,6 +17,7 @@ package org.apache.ignite.ml.composition.boosting; +import org.apache.ignite.ml.composition.boosting.loss.SquaredError; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; @@ -33,15 +34,14 @@ public abstract class GDBRegressionTrainer extends GDBTrainer { * @param cntOfIterations Count of learning iterations. 
*/ public GDBRegressionTrainer(double gradStepSize, Integer cntOfIterations) { - super(gradStepSize, - cntOfIterations, - LossGradientPerPredictionFunctions.MSE); + super(gradStepSize, cntOfIterations, new SquaredError()); } /** {@inheritDoc} */ - @Override protected <V, K> void learnLabels(DatasetBuilder<K, V> builder, IgniteBiFunction<K, V, Vector> featureExtractor, + @Override protected <V, K> boolean learnLabels(DatasetBuilder<K, V> builder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lExtractor) { + return true; } /** {@inheritDoc} */ http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java index c7f21dd..85af798 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java @@ -22,6 +22,9 @@ import java.util.List; import org.apache.ignite.lang.IgniteBiTuple; import org.apache.ignite.ml.Model; import org.apache.ignite.ml.composition.ModelsComposition; +import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory; +import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory; +import org.apache.ignite.ml.composition.boosting.loss.Loss; import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; @@ -30,7 +33,7 @@ import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; import org.apache.ignite.ml.environment.logging.MLLogger; import 
org.apache.ignite.ml.knn.regression.KNNRegressionTrainer; import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.math.functions.IgniteTriFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer; import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer; @@ -60,24 +63,25 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl private final int cntOfIterations; /** - * Gradient of loss function. First argument is sample size, second argument is valid answer, third argument is - * current model prediction. + * Loss function. */ - protected final IgniteTriFunction<Long, Double, Double, Double> lossGradient; + protected final Loss loss; + + /** Check convergence strategy factory. */ + protected ConvergenceCheckerFactory checkConvergenceStgyFactory = new MeanAbsValueConvergenceCheckerFactory(0.001); /** * Constructs GDBTrainer instance. * * @param gradStepSize Grad step size. * @param cntOfIterations Count of learning iterations. - * @param lossGradient Gradient of loss function. First argument is sample size, second argument is valid answer + * @param loss Gradient of loss function. First argument is sample size, second argument is valid answer * third argument is current model prediction. 
*/ - public GDBTrainer(double gradStepSize, Integer cntOfIterations, - IgniteTriFunction<Long, Double, Double, Double> lossGradient) { + public GDBTrainer(double gradStepSize, Integer cntOfIterations, Loss loss) { gradientStep = gradStepSize; this.cntOfIterations = cntOfIterations; - this.lossGradient = lossGradient; + this.loss = loss; } /** {@inheritDoc} */ @@ -85,53 +89,55 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { - learnLabels(datasetBuilder, featureExtractor, lbExtractor); + return updateModel(null, datasetBuilder, featureExtractor, lbExtractor); + } + + /** {@inheritDoc} */ + @Override protected <K, V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + + if (!learnLabels(datasetBuilder, featureExtractor, lbExtractor)) + return getLastTrainedModelOrThrowEmptyDatasetException(mdl); + + IgniteBiTuple<Double, Long> initAndSampleSize = computeInitialValue(datasetBuilder, featureExtractor, lbExtractor); + if(initAndSampleSize == null) + return getLastTrainedModelOrThrowEmptyDatasetException(mdl); - IgniteBiTuple<Double, Long> initAndSampleSize = computeInitialValue(datasetBuilder, - featureExtractor, lbExtractor); Double mean = initAndSampleSize.get1(); Long sampleSize = initAndSampleSize.get2(); - double[] compositionWeights = new double[cntOfIterations]; - Arrays.fill(compositionWeights, gradientStep); - WeightedPredictionsAggregator resAggregator = new WeightedPredictionsAggregator(compositionWeights, mean); - long learningStartTs = System.currentTimeMillis(); - List<Model<Vector, Double>> models = getLearningStrategy() + GDBLearningStrategy stgy = getLearningStrategy() .withBaseModelTrainerBuilder(this::buildBaseModelTrainer) .withExternalLabelToInternal(this::externalLabelToInternal) 
.withCntOfIterations(cntOfIterations) - .withCompositionWeights(compositionWeights) .withEnvironment(environment) - .withLossGradient(lossGradient) + .withLossGradient(loss) .withSampleSize(sampleSize) .withMeanLabelValue(mean) - .learnModels(datasetBuilder, featureExtractor, lbExtractor); + .withDefaultGradStepSize(gradientStep) + .withCheckConvergenceStgyFactory(checkConvergenceStgyFactory); + + List<Model<Vector, Double>> models; + if (mdl != null) + models = stgy.update((GDBModel)mdl, datasetBuilder, featureExtractor, lbExtractor); + else + models = stgy.learnModels(datasetBuilder, featureExtractor, lbExtractor); double learningTime = (double)(System.currentTimeMillis() - learningStartTs) / 1000.0; environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs", learningTime); - return new ModelsComposition(models, resAggregator) { - @Override public Double apply(Vector features) { - return internalLabelToExternal(super.apply(features)); - } - }; - } - - - //TODO: This method will be implemented in IGNITE-9412 - /** {@inheritDoc} */ - @Override public <K, V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { - - throw new UnsupportedOperationException(); + WeightedPredictionsAggregator resAggregator = new WeightedPredictionsAggregator( + stgy.getCompositionWeights(), + stgy.getMeanValue() + ); + return new GDBModel(models, resAggregator, this::internalLabelToExternal); } - //TODO: This method will be implemented in IGNITE-9412 /** {@inheritDoc} */ @Override protected boolean checkState(ModelsComposition mdl) { - throw new UnsupportedOperationException(); + return mdl instanceof GDBModel; } /** @@ -140,8 +146,9 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl * @param builder Dataset builder. * @param featureExtractor Feature extractor. 
* @param lExtractor Labels extractor. + * @return true if labels learning was successful. */ - protected abstract <V, K> void learnLabels(DatasetBuilder<K, V> builder, + protected abstract <V, K> boolean learnLabels(DatasetBuilder<K, V> builder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lExtractor); /** @@ -196,7 +203,8 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl } ); - meanTuple.set1(meanTuple.get1() / meanTuple.get2()); + if (meanTuple != null) + meanTuple.set1(meanTuple.get1() / meanTuple.get2()); return meanTuple; } catch (Exception e) { @@ -205,6 +213,17 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl } /** + * Sets CheckConvergenceStgyFactory. + * + * @param factory + * @return trainer. + */ + public GDBTrainer withCheckConvergenceStgyFactory(ConvergenceCheckerFactory factory) { + this.checkConvergenceStgyFactory = factory; + return this; + } + + /** * Returns learning strategy. * * @return learning strategy. @@ -212,4 +231,33 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl protected GDBLearningStrategy getLearningStrategy() { return new GDBLearningStrategy(); } + + /** */ + public static class GDBModel extends ModelsComposition { + /** Serial version uid. */ + private static final long serialVersionUID = 3476661240155508004L; + + /** Internal to external lbl mapping. */ + private final IgniteFunction<Double, Double> internalToExternalLblMapping; + + /** + * Creates an instance of GDBModel. + * + * @param models Models. + * @param predictionsAggregator Predictions aggregator. + * @param internalToExternalLblMapping Internal to external lbl mapping. + */ + public GDBModel(List<? 
extends Model<Vector, Double>> models, + WeightedPredictionsAggregator predictionsAggregator, + IgniteFunction<Double, Double> internalToExternalLblMapping) { + + super(models, predictionsAggregator); + this.internalToExternalLblMapping = internalToExternalLblMapping; + } + + /** {@inheritDoc} */ + @Override public Double apply(Vector features) { + return internalToExternalLblMapping.apply(super.apply(features)); + } + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/LossGradientPerPredictionFunctions.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/LossGradientPerPredictionFunctions.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/LossGradientPerPredictionFunctions.java deleted file mode 100644 index 488c0e3..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/LossGradientPerPredictionFunctions.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.ml.composition.boosting; - -import org.apache.ignite.ml.math.functions.IgniteTriFunction; - -/** - * Contains implementations of per-prediction loss functions for gradient boosting algorithm. - */ -public class LossGradientPerPredictionFunctions { - /** Mean squared error loss for regression. */ - public static IgniteTriFunction<Long, Double, Double, Double> MSE = - (sampleSize, answer, prediction) -> (2.0 / sampleSize) * (prediction - answer); - - /** Logarithmic loss for binary classification. */ - public static IgniteTriFunction<Long, Double, Double, Double> LOG_LOSS = - (sampleSize, answer, prediction) -> (prediction - answer) / (prediction * (1.0 - prediction)); -} http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceChecker.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceChecker.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceChecker.java new file mode 100644 index 0000000..3f6e8ca --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceChecker.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.composition.boosting.convergence; + +import java.io.Serializable; +import org.apache.ignite.ml.composition.ModelsComposition; +import org.apache.ignite.ml.composition.boosting.loss.Loss; +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData; +import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapDataBuilder; +import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; + +/** + * Contains logic of error computing and convergence checking for Gradient Boosting algorithms. + * + * @param <K> Type of a key in upstream data. + * @param <V> Type of a value in upstream data. + */ +public abstract class ConvergenceChecker<K, V> implements Serializable { + /** Serial version uid. */ + private static final long serialVersionUID = 710762134746674105L; + + /** Sample size. */ + private long sampleSize; + + /** External label to internal mapping. */ + private IgniteFunction<Double, Double> externalLbToInternalMapping; + + /** Loss function. */ + private Loss loss; + + /** Feature extractor. */ + private IgniteBiFunction<K, V, Vector> featureExtractor; + + /** Label extractor. 
*/ + private IgniteBiFunction<K, V, Double> lbExtractor; + + /** Precision of convergence check. */ + private double precision; + + /** + * Constructs an instance of ConvergenceChecker. + * + * @param sampleSize Sample size. + * @param externalLbToInternalMapping External label to internal mapping. + * @param loss Loss gradient. + * @param datasetBuilder Dataset builder. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param precision + */ + public ConvergenceChecker(long sampleSize, + IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor, + double precision) { + + assert precision < 1 && precision >= 0; + + this.sampleSize = sampleSize; + this.externalLbToInternalMapping = externalLbToInternalMapping; + this.loss = loss; + this.featureExtractor = featureExtractor; + this.lbExtractor = lbExtractor; + this.precision = precision; + } + + /** + * Checks convergency on dataset. + * + * @param currMdl Current model. + * @return true if GDB is converged. + */ + public boolean isConverged(DatasetBuilder<K, V> datasetBuilder, ModelsComposition currMdl) { + try (Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build( + new EmptyContextBuilder<>(), + new FeatureMatrixWithLabelsOnHeapDataBuilder<>(featureExtractor, lbExtractor) + )) { + return isConverged(dataset, currMdl); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** + * Checks convergency on dataset. + * + * @param dataset Dataset. + * @param currMdl Current model. + * @return true if GDB is converged. + */ + public boolean isConverged(Dataset<EmptyContext, ? 
extends FeatureMatrixWithLabelsOnHeapData> dataset, ModelsComposition currMdl) { + Double error = computeMeanErrorOnDataset(dataset, currMdl); + return error < precision || error.isNaN(); + } + + /** + * Compute error for given model on learning dataset. + * + * @param dataset Learning dataset. + * @param mdl Model. + * @return error mean value. + */ + public abstract Double computeMeanErrorOnDataset( + Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset, + ModelsComposition mdl); + + /** + * Compute error for the specific vector of dataset. + * + * @param currMdl Current model. + * @return error. + */ + public double computeError(Vector features, Double answer, ModelsComposition currMdl) { + Double realAnswer = externalLbToInternalMapping.apply(answer); + Double mdlAnswer = currMdl.apply(features); + return -loss.gradient(sampleSize, realAnswer, mdlAnswer); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerFactory.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerFactory.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerFactory.java new file mode 100644 index 0000000..7592f50 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerFactory.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.composition.boosting.convergence; + +import org.apache.ignite.ml.composition.boosting.loss.Loss; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; + +/** + * Factory for ConvergenceChecker. + */ +public abstract class ConvergenceCheckerFactory { + /** Precision of error checking. If error <= precision then it is equated to 0.0*/ + protected double precision; + + /** + * Creates an instance of ConvergenceCheckerFactory. + * + * @param precision Precision [0 <= precision < 1]. + */ + public ConvergenceCheckerFactory(double precision) { + this.precision = precision; + } + + /** + * Create an instance of ConvergenceChecker. + * + * @param sampleSize Sample size. + * @param externalLbToInternalMapping External label to internal mapping. + * @param loss Loss function. + * @param datasetBuilder Dataset builder. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @return ConvergenceCheckerFactory instance. 
+ */ + public abstract <K,V> ConvergenceChecker<K,V> create(long sampleSize, + IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss, + DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor); +} http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceChecker.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceChecker.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceChecker.java new file mode 100644 index 0000000..7340bfa --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceChecker.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.composition.boosting.convergence.mean; + +import org.apache.ignite.lang.IgniteBiTuple; +import org.apache.ignite.ml.composition.ModelsComposition; +import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker; +import org.apache.ignite.ml.composition.boosting.loss.Loss; +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; + +/** + * Use mean value of errors for estimating error on dataset. + * + * @param <K> Type of a key in upstream data. + * @param <V> Type of a value in upstream data. + */ +public class MeanAbsValueConvergenceChecker<K,V> extends ConvergenceChecker<K,V> { + /** Serial version uid. */ + private static final long serialVersionUID = 8534776439755210864L; + + /** + * Creates an instance of MeanAbsValueConvergenceChecker. + * + * @param sampleSize Sample size. + * @param externalLbToInternalMapping External label to internal mapping. + * @param loss Loss. + * @param datasetBuilder Dataset builder. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor.
+ */ + public MeanAbsValueConvergenceChecker(long sampleSize, IgniteFunction<Double, Double> externalLbToInternalMapping, + Loss loss, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor, + double precision) { + + super(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, featureExtractor, lbExtractor, precision); + } + + /** {@inheritDoc} */ + @Override public Double computeMeanErrorOnDataset(Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset, + ModelsComposition mdl) { + + IgniteBiTuple<Double, Long> sumAndCnt = dataset.compute( + partition -> computeStatisticOnPartition(mdl, partition), + this::reduce + ); + + if(sumAndCnt == null || sumAndCnt.getValue() == 0) + return Double.NaN; + return sumAndCnt.getKey() / sumAndCnt.getValue(); + } + + /** + * Compute sum of absolute value of errors and count of rows in partition. + * + * @param mdl Model. + * @param part Partition. + * @return Tuple (sum of errors, count of rows) + */ + private IgniteBiTuple<Double, Long> computeStatisticOnPartition(ModelsComposition mdl, FeatureMatrixWithLabelsOnHeapData part) { + Double sum = 0.0; + + for(int i = 0; i < part.getFeatures().length; i++) { + double error = computeError(VectorUtils.of(part.getFeatures()[i]), part.getLabels()[i], mdl); + sum += Math.abs(error); + } + + return new IgniteBiTuple<>(sum, (long) part.getLabels().length); + } + + /** + * Merge left and right statistics from partitions. + * + * @param left Left. + * @param right Right. + * @return merged value. 
+ */ + private IgniteBiTuple<Double, Long> reduce(IgniteBiTuple<Double, Long> left, IgniteBiTuple<Double, Long> right) { + if (left == null) { + if (right != null) + return right; + else + return new IgniteBiTuple<>(0.0, 0L); + } + + if (right == null) + return left; + + return new IgniteBiTuple<>( + left.getKey() + right.getKey(), + right.getValue() + left.getValue() + ); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerFactory.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerFactory.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerFactory.java new file mode 100644 index 0000000..f02a606 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerFactory.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.composition.boosting.convergence.mean; + +import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker; +import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory; +import org.apache.ignite.ml.composition.boosting.loss.Loss; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; + +/** + * Factory for {@link MeanAbsValueConvergenceChecker}. + */ +public class MeanAbsValueConvergenceCheckerFactory extends ConvergenceCheckerFactory { + /** + * @param precision Precision. + */ + public MeanAbsValueConvergenceCheckerFactory(double precision) { + super(precision); + } + + /** {@inheritDoc} */ + @Override public <K, V> ConvergenceChecker<K, V> create(long sampleSize, + IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + + return new MeanAbsValueConvergenceChecker<>(sampleSize, externalLbToInternalMapping, loss, + datasetBuilder, featureExtractor, lbExtractor, precision); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/package-info.java new file mode 100644 index 0000000..1ab6e66 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains implementation of convergence checking computer by mean of absolute value of errors in dataset. + */ +package org.apache.ignite.ml.composition.boosting.convergence.mean; http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceChecker.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceChecker.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceChecker.java new file mode 100644 index 0000000..7e66a9c --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceChecker.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.composition.boosting.convergence.median; + +import java.util.Arrays; +import org.apache.ignite.ml.composition.ModelsComposition; +import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker; +import org.apache.ignite.ml.composition.boosting.loss.Loss; +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; + +/** + * Use median of median on partitions value of errors for estimating error on dataset. This algorithm may be less + * sensitive to anomalies (e.g. outliers) in the learning sample than a mean-of-errors estimate. + * + * @param <K> Type of a key in upstream data. + * @param <V> Type of a value in upstream data. + */ +public class MedianOfMedianConvergenceChecker<K, V> extends ConvergenceChecker<K, V> { + /** Serial version uid. */ + private static final long serialVersionUID = 4902502002933415287L; + + /** + * Creates an instance of MedianOfMedianConvergenceChecker. + * + * @param sampleSize Sample size. + * @param lblMapping External label to internal mapping.
+ * @param loss Loss function. + * @param datasetBuilder Dataset builder. + * @param fExtr Feature extractor. + * @param lbExtr Label extractor. + * @param precision Precision. + */ + public MedianOfMedianConvergenceChecker(long sampleSize, IgniteFunction<Double, Double> lblMapping, Loss loss, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> fExtr, + IgniteBiFunction<K, V, Double> lbExtr, double precision) { + + super(sampleSize, lblMapping, loss, datasetBuilder, fExtr, lbExtr, precision); + } + + /** {@inheritDoc} */ + @Override public Double computeMeanErrorOnDataset(Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset, + ModelsComposition mdl) { + + double[] medians = dataset.compute( + data -> computeMedian(mdl, data), + this::reduce + ); + + if(medians == null) + return Double.POSITIVE_INFINITY; + return getMedian(medians); + } + + /** + * Compute median value on data partition. + * + * @param mdl Model. + * @param data Data. + * @return median value. + */ + private double[] computeMedian(ModelsComposition mdl, FeatureMatrixWithLabelsOnHeapData data) { + double[] errors = new double[data.getLabels().length]; + for (int i = 0; i < errors.length; i++) + errors[i] = Math.abs(computeError(VectorUtils.of(data.getFeatures()[i]), data.getLabels()[i], mdl)); + return new double[] {getMedian(errors)}; + } + + /** + * Compute median value on array of errors. + * + * @param errors Error values. + * @return median value of errors. + */ + private double getMedian(double[] errors) { + if(errors.length == 0) + return Double.POSITIVE_INFINITY; + + Arrays.sort(errors); + final int middleIdx = (errors.length - 1) / 2; + if (errors.length % 2 == 1) + return errors[middleIdx]; + else + return (errors[middleIdx + 1] + errors[middleIdx]) / 2; + } + + /** + * Merge median values among partitions. + * + * @param left Left partition. + * @param right Right partition. + * @return merged median values. 
+ */ + private double[] reduce(double[] left, double[] right) { + if (left == null) + return right; + if(right == null) + return left; + + double[] res = new double[left.length + right.length]; + System.arraycopy(left, 0, res, 0, left.length); + System.arraycopy(right, 0, res, left.length, right.length); + return res; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerFactory.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerFactory.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerFactory.java new file mode 100644 index 0000000..a1affe0 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerFactory.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.composition.boosting.convergence.median; + +import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker; +import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory; +import org.apache.ignite.ml.composition.boosting.loss.Loss; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; + +/** + * Factory for {@link MedianOfMedianConvergenceChecker}. + */ +public class MedianOfMedianConvergenceCheckerFactory extends ConvergenceCheckerFactory { + /** + * @param precision Precision. + */ + public MedianOfMedianConvergenceCheckerFactory(double precision) { + super(precision); + } + + /** {@inheritDoc} */ + @Override public <K, V> ConvergenceChecker<K, V> create(long sampleSize, + IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + + return new MedianOfMedianConvergenceChecker<>(sampleSize, externalLbToInternalMapping, loss, + datasetBuilder, featureExtractor, lbExtractor, precision); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/package-info.java new file mode 100644 index 0000000..3798ef9 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains implementation of convergence checking computer by median of medians of errors in dataset. + */ +package org.apache.ignite.ml.composition.boosting.convergence.median; http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/package-info.java new file mode 100644 index 0000000..6d42c62 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/package-info.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Package contains implementation of convergence checking algorithms for gradient boosting. + * These algorithms may stop training of gradient boosting if it achieves an error on the dataset less than the precision + * specified by the user. + */ +package org.apache.ignite.ml.composition.boosting.convergence; http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStub.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStub.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStub.java new file mode 100644 index 0000000..716d04e --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStub.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.composition.boosting.convergence.simple; + +import org.apache.ignite.ml.composition.ModelsComposition; +import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker; +import org.apache.ignite.ml.composition.boosting.loss.Loss; +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; + +/** + * This strategy skips the estimating-error-on-dataset step. + * According to this strategy, training will stop after reaching the maximum number of iterations. + * + * @param <K> Type of a key in upstream data. + * @param <V> Type of a value in upstream data. + */ +public class ConvergenceCheckerStub<K,V> extends ConvergenceChecker<K,V> { + /** Serial version uid. */ + private static final long serialVersionUID = 8534776439755210864L; + + /** + * Creates an instance of ConvergenceCheckerStub. + * + * @param sampleSize Sample size. + * @param externalLbToInternalMapping External label to internal mapping. + * @param loss Loss function. + * @param datasetBuilder Dataset builder. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor.
+ */ + public ConvergenceCheckerStub(long sampleSize, + IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss, + DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + + super(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, + featureExtractor, lbExtractor, 0.0); + } + + /** {@inheritDoc} */ + @Override public boolean isConverged(DatasetBuilder<K, V> datasetBuilder, ModelsComposition currMdl) { + return false; + } + + /** {@inheritDoc} */ + @Override public boolean isConverged(Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset, + ModelsComposition currMdl) { + return false; + } + + /** {@inheritDoc} */ + @Override public Double computeMeanErrorOnDataset(Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset, + ModelsComposition mdl) { + + throw new UnsupportedOperationException(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStubFactory.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStubFactory.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStubFactory.java new file mode 100644 index 0000000..a0f0d5c --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStubFactory.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.composition.boosting.convergence.simple; + +import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker; +import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory; +import org.apache.ignite.ml.composition.boosting.loss.Loss; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; + +/** + * Factory for {@link ConvergenceCheckerStub}. + */ +public class ConvergenceCheckerStubFactory extends ConvergenceCheckerFactory { + /** + * Create an instance of ConvergenceCheckerStubFactory. 
+ */ + public ConvergenceCheckerStubFactory() { + super(0.0); + } + + /** {@inheritDoc} */ + @Override public <K, V> ConvergenceChecker<K, V> create(long sampleSize, + IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + + return new ConvergenceCheckerStub<>(sampleSize, externalLbToInternalMapping, loss, + datasetBuilder, featureExtractor, lbExtractor); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/package-info.java new file mode 100644 index 0000000..915903a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/package-info.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** + * <!-- Package description. --> + * Contains implementation of Stub for convergence checking. + * By this implementation gradient boosting will train new submodels until count of models achieving max value [count + * of iterations parameter]. + */ +package org.apache.ignite.ml.composition.boosting.convergence.simple; http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/LogLoss.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/LogLoss.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/LogLoss.java new file mode 100644 index 0000000..19ef70b --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/LogLoss.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.composition.boosting.loss; + +/** + * Logistic regression loss function. + */ +public class LogLoss implements Loss { + /** Serial version uid. 
*/ + private static final long serialVersionUID = 2251384437214194977L; + + /** {@inheritDoc} */ + @Override public double error(long sampleSize, double answer, double prediction) { + return -(answer * Math.log(prediction) + (1 - answer) * Math.log(1 - prediction)); + } + + /** {@inheritDoc} */ + @Override public double gradient(long sampleSize, double answer, double prediction) { + return (prediction - answer) / (prediction * (1.0 - prediction)); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/Loss.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/Loss.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/Loss.java new file mode 100644 index 0000000..72fff30 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/Loss.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.ignite.ml.composition.boosting.loss;

import java.io.Serializable;

/**
 * Loss interface of computing error or gradient of error on specific row in dataset.
 */
public interface Loss extends Serializable {
    /**
     * Error value for model answer.
     *
     * @param sampleSize Sample size.
     * @param lb Label.
     * @param mdlAnswer Model answer.
     * @return error value.
     */
    public double error(long sampleSize, double lb, double mdlAnswer);

    /**
     * Error gradient value for model answer.
     *
     * @param sampleSize Sample size.
     * @param lb Label.
     * @param mdlAnswer Model answer.
     * @return error gradient value.
     */
    public double gradient(long sampleSize, double lb, double mdlAnswer);
}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.composition.boosting.loss; + +/** + * Represent error function as E(label, modelAnswer) = 1/N * (label - prediction)^2 + */ +public class SquaredError implements Loss { + /** Serial version uid. */ + private static final long serialVersionUID = 564886150646352157L; + + /** {@inheritDoc} */ + @Override public double error(long sampleSize, double lb, double prediction) { + return Math.pow(lb - prediction, 2) / sampleSize; + } + + /** {@inheritDoc} */ + @Override public double gradient(long sampleSize, double lb, double prediction) { + return (2.0 / sampleSize) * (prediction - lb); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/package-info.java new file mode 100644 index 0000000..83a5e39 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains loss functions for Gradient Boosting algorithms. + */ +package org.apache.ignite.ml.composition.boosting.loss; http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregator.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregator.java index 8a369ad..5e0f7f1 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregator.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregator.java @@ -86,4 +86,14 @@ public class WeightedPredictionsAggregator implements PredictionsAggregator { return builder.append(bias > 0 ? 
/**
 * A partition {@code data} of the containing matrix of features and vector of labels stored in heap.
 *
 * <p>The passed-in arrays are held by reference (no defensive copy) and exposed directly by the
 * getters — callers must not mutate them while the partition is in use.
 */
public class FeatureMatrixWithLabelsOnHeapData implements AutoCloseable {
    /** Matrix with features. */
    private final double[][] features;

    /** Vector with labels. */
    private final double[] labels;

    /**
     * Constructs an instance of FeatureMatrixWithLabelsOnHeapData.
     *
     * @param features Features.
     * @param labels Labels.
     * @throws IllegalArgumentException If the row count of {@code features} differs from the length
     *      of {@code labels}. An explicit check is used instead of {@code assert} so the invariant
     *      holds even when assertions are disabled at runtime (the default).
     */
    public FeatureMatrixWithLabelsOnHeapData(double[][] features, double[] labels) {
        if (features.length != labels.length)
            throw new IllegalArgumentException("Features and labels have to be the same length " +
                "[featuresRows=" + features.length + ", labels=" + labels.length + "]");

        this.features = features;
        this.labels = labels;
    }

    /** @return Matrix with features (internal array, not a copy). */
    public double[][] getFeatures() {
        return features;
    }

    /** @return Vector with labels (internal array, not a copy). */
    public double[] getLabels() {
        return labels;
    }

    /** {@inheritDoc} */
    @Override public void close() {
        // Do nothing, GC will clean up.
    }
}
