Repository: ignite Updated Branches: refs/heads/master c9368da76 -> 414f45e0a
IGNITE-9065: Gradient boosting optimization this closes #4486 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/414f45e0 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/414f45e0 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/414f45e0 Branch: refs/heads/master Commit: 414f45e0af39e1f7acf8304eedb113ca305e9a21 Parents: c9368da Author: Alexey Platonov <[email protected]> Authored: Wed Aug 8 13:22:26 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Wed Aug 8 13:22:26 2018 +0300 ---------------------------------------------------------------------- .../GDBOnTreesRegressionTrainerExample.java | 116 ++++++++++++ .../GRBOnTreesRegressionTrainerExample.java | 116 ------------ .../boosting/GDBLearningStrategy.java | 178 +++++++++++++++++++ .../ml/composition/boosting/GDBTrainer.java | 48 ++--- .../org/apache/ignite/ml/tree/DecisionTree.java | 8 +- .../tree/DecisionTreeClassificationTrainer.java | 2 +- .../ml/tree/DecisionTreeRegressionTrainer.java | 2 +- .../GDBBinaryClassifierOnTreesTrainer.java | 11 +- .../boosting/GDBOnTreesLearningStrategy.java | 97 ++++++++++ .../boosting/GDBRegressionOnTreesTrainer.java | 11 +- .../ignite/ml/tree/data/DecisionTreeData.java | 11 ++ .../impurity/ImpurityMeasureCalculator.java | 6 + 12 files changed, 457 insertions(+), 149 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java new file mode 100644 index 0000000..fa7a0d4 --- /dev/null +++ 
b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.tree.boosting; + +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.trainers.DatasetTrainer; +import org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer; +import org.apache.ignite.thread.IgniteThread; +import org.jetbrains.annotations.NotNull; + +/** + * Example represents a solution for the task of regression learning based on + * Gradient Boosting on trees implementation. It shows an initialization of {@link org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer}, + * initialization of Ignite Cache, learning step and comparing of predicted and real values. 
+ * + * In this example dataset is creating automatically by parabolic function f(x) = x^2. + */ +public class GDBOnTreesRegressionTrainerExample { + /** + * Run example. + * + * @param args Command line arguments, none required. + */ + public static void main(String... args) throws InterruptedException { + // Start ignite grid. + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + GDBOnTreesRegressionTrainerExample.class.getSimpleName(), () -> { + + // Create cache with training data. + CacheConfiguration<Integer, double[]> trainingSetCfg = createCacheConfiguration(); + IgniteCache<Integer, double[]> trainingSet = fillTrainingData(ignite, trainingSetCfg); + + // Create regression trainer. + DatasetTrainer<Model<Vector, Double>, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.); + + // Train decision tree model. + Model<Vector, Double> mdl = trainer.fit( + ignite, + trainingSet, + (k, v) -> VectorUtils.of(v[0]), + (k, v) -> v[1] + ); + + System.out.println(">>> ---------------------------------"); + System.out.println(">>> | Prediction\t| Valid answer \t|"); + System.out.println(">>> ---------------------------------"); + + // Calculate score. + for (int x = -5; x < 5; x++) { + double predicted = mdl.apply(VectorUtils.of(x)); + + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.pow(x, 2)); + } + + System.out.println(">>> ---------------------------------"); + + System.out.println(">>> GDB Regression trainer example completed."); + }); + + igniteThread.start(); + igniteThread.join(); + } + } + + /** + * Create cache configuration. 
+ */ + @NotNull private static CacheConfiguration<Integer, double[]> createCacheConfiguration() { + CacheConfiguration<Integer, double[]> trainingSetCfg = new CacheConfiguration<>(); + trainingSetCfg.setName("TRAINING_SET"); + trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10)); + return trainingSetCfg; + } + + /** + * Fill parabola training data. + * + * @param ignite Ignite. + * @param trainingSetCfg Training set config. + */ + @NotNull private static IgniteCache<Integer, double[]> fillTrainingData(Ignite ignite, + CacheConfiguration<Integer, double[]> trainingSetCfg) { + IgniteCache<Integer, double[]> trainingSet = ignite.createCache(trainingSetCfg); + for(int i = -50; i <= 50; i++) { + double x = ((double)i) / 10.0; + double y = Math.pow(x, 2); + trainingSet.put(i, new double[] {x, y}); + } + return trainingSet; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GRBOnTreesRegressionTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GRBOnTreesRegressionTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GRBOnTreesRegressionTrainerExample.java deleted file mode 100644 index 71d405a..0000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GRBOnTreesRegressionTrainerExample.java +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.examples.ml.tree.boosting; - -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.Ignition; -import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; -import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.ml.Model; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.math.primitives.vector.VectorUtils; -import org.apache.ignite.ml.trainers.DatasetTrainer; -import org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer; -import org.apache.ignite.thread.IgniteThread; -import org.jetbrains.annotations.NotNull; - -/** - * Example represents a solution for the task of regression learning based on - * Gradient Boosting on trees implementation. It shows an initialization of {@link org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer}, - * initialization of Ignite Cache, learning step and comparing of predicted and real values. - * - * In this example dataset is creating automatically by parabolic function f(x) = x^2. - */ -public class GRBOnTreesRegressionTrainerExample { - /** - * Run example. - * - * @param args Command line arguments, none required. - */ - public static void main(String... args) throws InterruptedException { - // Start ignite grid. 
- try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { - System.out.println(">>> Ignite grid started."); - - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - GRBOnTreesRegressionTrainerExample.class.getSimpleName(), () -> { - - // Create cache with training data. - CacheConfiguration<Integer, double[]> trainingSetCfg = createCacheConfiguration(); - IgniteCache<Integer, double[]> trainingSet = fillTrainingData(ignite, trainingSetCfg); - - // Create regression trainer. - DatasetTrainer<Model<Vector, Double>, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.); - - // Train decision tree model. - Model<Vector, Double> mdl = trainer.fit( - ignite, - trainingSet, - (k, v) -> VectorUtils.of(v[0]), - (k, v) -> v[1] - ); - - System.out.println(">>> ---------------------------------"); - System.out.println(">>> | Prediction\t| Valid answer \t|"); - System.out.println(">>> ---------------------------------"); - - // Calculate score. - for (int x = -5; x < 5; x++) { - double predicted = mdl.apply(VectorUtils.of(x)); - - System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.pow(x, 2)); - } - - System.out.println(">>> ---------------------------------"); - - System.out.println(">>> GDB Regression trainer example completed."); - }); - - igniteThread.start(); - igniteThread.join(); - } - } - - /** - * Create cache configuration. - */ - @NotNull private static CacheConfiguration<Integer, double[]> createCacheConfiguration() { - CacheConfiguration<Integer, double[]> trainingSetCfg = new CacheConfiguration<>(); - trainingSetCfg.setName("TRAINING_SET"); - trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10)); - return trainingSetCfg; - } - - /** - * Fill parabola training data. - * - * @param ignite Ignite. - * @param trainingSetCfg Training set config. 
- */ - @NotNull private static IgniteCache<Integer, double[]> fillTrainingData(Ignite ignite, - CacheConfiguration<Integer, double[]> trainingSetCfg) { - IgniteCache<Integer, double[]> trainingSet = ignite.createCache(trainingSetCfg); - for(int i = -50; i <= 50; i++) { - double x = ((double)i) / 10.0; - double y = Math.pow(x, 2); - trainingSet.put(i, new double[] {x, y}); - } - return trainingSet; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java new file mode 100644 index 0000000..375748a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.composition.boosting; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.composition.ModelsComposition; +import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.environment.LearningEnvironment; +import org.apache.ignite.ml.environment.logging.MLLogger; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.functions.IgniteSupplier; +import org.apache.ignite.ml.math.functions.IgniteTriFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.trainers.DatasetTrainer; + +/** + * Learning strategy for gradient boosting. + */ +public class GDBLearningStrategy { + /** Learning environment. */ + protected LearningEnvironment environment; + + /** Count of iterations. */ + protected int cntOfIterations; + + /** Loss of gradient. */ + protected IgniteTriFunction<Long, Double, Double, Double> lossGradient; + + /** External label to internal mapping. */ + protected IgniteFunction<Double, Double> externalLbToInternalMapping; + + /** Base model trainer builder. */ + protected IgniteSupplier<DatasetTrainer<? extends Model<Vector, Double>, Double>> baseMdlTrainerBuilder; + + /** Mean label value. */ + protected double meanLabelValue; + + /** Sample size. */ + protected long sampleSize; + + /** Composition weights. */ + protected double[] compositionWeights; + + /** + * Implementation of gradient boosting iterations. At each step of iterations this algorithm + * build a regression model based on gradient of loss-function for current models composition. + * + * @param datasetBuilder Dataset builder. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. 
+ * @return list of learned models. + */ + public <K, V> List<Model<Vector, Double>> learnModels(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + + List<Model<Vector, Double>> models = new ArrayList<>(); + DatasetTrainer<? extends Model<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get(); + for (int i = 0; i < cntOfIterations; i++) { + double[] weights = Arrays.copyOf(compositionWeights, i); + + WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLabelValue); + Model<Vector, Double> currComposition = new ModelsComposition(models, aggregator); + + IgniteBiFunction<K, V, Double> lbExtractorWrap = (k, v) -> { + Double realAnswer = externalLbToInternalMapping.apply(lbExtractor.apply(k, v)); + Double mdlAnswer = currComposition.apply(featureExtractor.apply(k, v)); + return -lossGradient.apply(sampleSize, realAnswer, mdlAnswer); + }; + + long startTs = System.currentTimeMillis(); + models.add(trainer.fit(datasetBuilder, featureExtractor, lbExtractorWrap)); + double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0; + environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime); + } + + return models; + } + + /** + * Sets learning environment. + * + * @param environment Learning Environment. + */ + public GDBLearningStrategy withEnvironment(LearningEnvironment environment) { + this.environment = environment; + return this; + } + + /** + * Sets count of iterations. + * + * @param cntOfIterations Count of iterations. + */ + public GDBLearningStrategy withCntOfIterations(int cntOfIterations) { + this.cntOfIterations = cntOfIterations; + return this; + } + + /** + * Sets gradient of loss function. + * + * @param lossGradient Loss gradient. 
+ */ + public GDBLearningStrategy withLossGradient(IgniteTriFunction<Long, Double, Double, Double> lossGradient) { + this.lossGradient = lossGradient; + return this; + } + + /** + * Sets external to internal label representation mapping. + * + * @param externalLbToInternal External label to internal. + */ + public GDBLearningStrategy withExternalLabelToInternal(IgniteFunction<Double, Double> externalLbToInternal) { + this.externalLbToInternalMapping = externalLbToInternal; + return this; + } + + /** + * Sets base model builder. + * + * @param buildBaseMdlTrainer Build base model trainer. + */ + public GDBLearningStrategy withBaseModelTrainerBuilder(IgniteSupplier<DatasetTrainer<? extends Model<Vector, Double>, Double>> buildBaseMdlTrainer) { + this.baseMdlTrainerBuilder = buildBaseMdlTrainer; + return this; + } + + /** + * Sets mean label value. + * + * @param meanLabelValue Mean label value. + */ + public GDBLearningStrategy withMeanLabelValue(double meanLabelValue) { + this.meanLabelValue = meanLabelValue; + return this; + } + + /** + * Sets sample size. + * + * @param sampleSize Sample size. + */ + public GDBLearningStrategy withSampleSize(long sampleSize) { + this.sampleSize = sampleSize; + return this; + } + + /** + * Sets composition weights vector. + * + * @param compositionWeights Composition weights. 
+ */ + public GDBLearningStrategy withCompositionWeights(double[] compositionWeights) { + this.compositionWeights = compositionWeights; + return this; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java index 8663d3d..5a0f52a 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java @@ -17,7 +17,6 @@ package org.apache.ignite.ml.composition.boosting; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.apache.ignite.lang.IgniteBiTuple; @@ -53,16 +52,18 @@ import org.jetbrains.annotations.NotNull; * * But in practice Decision Trees is most used regressors (see: {@link DecisionTreeRegressionTrainer}). */ -abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, Double> { +public abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, Double> { /** Gradient step. */ private final double gradientStep; + /** Count of iterations. */ private final int cntOfIterations; + /** * Gradient of loss function. First argument is sample size, second argument is valid answer, third argument is * current model prediction. */ - private final IgniteTriFunction<Long, Double, Double, Double> lossGradient; + protected final IgniteTriFunction<Long, Double, Double, Double> lossGradient; /** * Constructs GDBTrainer instance. 
@@ -91,28 +92,23 @@ abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, Double> Double mean = initAndSampleSize.get1(); Long sampleSize = initAndSampleSize.get2(); - List<Model<Vector, Double>> models = new ArrayList<>(); double[] compositionWeights = new double[cntOfIterations]; Arrays.fill(compositionWeights, gradientStep); WeightedPredictionsAggregator resAggregator = new WeightedPredictionsAggregator(compositionWeights, mean); long learningStartTs = System.currentTimeMillis(); - for (int i = 0; i < cntOfIterations; i++) { - double[] weights = Arrays.copyOf(compositionWeights, i); - WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, mean); - Model<Vector, Double> currComposition = new ModelsComposition(models, aggregator); - - IgniteBiFunction<K, V, Double> lbExtractorWrap = (k, v) -> { - Double realAnswer = externalLabelToInternal(lbExtractor.apply(k, v)); - Double mdlAnswer = currComposition.apply(featureExtractor.apply(k, v)); - return -lossGradient.apply(sampleSize, realAnswer, mdlAnswer); - }; - - long startTs = System.currentTimeMillis(); - models.add(buildBaseModelTrainer().fit(datasetBuilder, featureExtractor, lbExtractorWrap)); - double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0; - environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime); - } + + List<Model<Vector, Double>> models = getLearningStrategy() + .withBaseModelTrainerBuilder(this::buildBaseModelTrainer) + .withExternalLabelToInternal(this::externalLabelToInternal) + .withCntOfIterations(cntOfIterations) + .withCompositionWeights(compositionWeights) + .withEnvironment(environment) + .withLossGradient(lossGradient) + .withSampleSize(sampleSize) + .withMeanLabelValue(mean) + .learnModels(datasetBuilder, featureExtractor, lbExtractor); + double learningTime = (double)(System.currentTimeMillis() - learningStartTs) / 1000.0; 
environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs", learningTime); @@ -136,7 +132,8 @@ abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, Double> /** * Returns regressor model trainer for one step of GDB. */ - @NotNull protected abstract DatasetTrainer<? extends Model<Vector, Double>, Double> buildBaseModelTrainer(); + @NotNull + protected abstract DatasetTrainer<? extends Model<Vector, Double>, Double> buildBaseModelTrainer(); /** * Maps external representation of label to internal. @@ -191,4 +188,13 @@ abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, Double> throw new RuntimeException(e); } } + + /** + * Returns learning strategy. + * + * @return learning strategy. + */ + protected GDBLearningStrategy getLearningStrategy() { + return new GDBLearningStrategy(); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java index 270f14a..de8994a 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java @@ -79,20 +79,24 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset new EmptyContextBuilder<>(), new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, useIndex) )) { - return split(dataset, e -> true, 0, getImpurityMeasureCalculator(dataset)); + return fit(dataset); } catch (Exception e) { throw new RuntimeException(e); } } + public <K,V> DecisionTreeNode fit(Dataset<EmptyContext, DecisionTreeData> dataset) { + return split(dataset, e -> true, 0, getImpurityMeasureCalculator(dataset)); + } + /** * Returns impurity measure calculator. 
* * @param dataset Dataset. * @return Impurity measure calculator. */ - abstract ImpurityMeasureCalculator<T> getImpurityMeasureCalculator(Dataset<EmptyContext, DecisionTreeData> dataset); + protected abstract ImpurityMeasureCalculator<T> getImpurityMeasureCalculator(Dataset<EmptyContext, DecisionTreeData> dataset); /** * Splits the node specified by the given dataset and predicate and returns decision tree node. http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java index f371334..f8fc769 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java @@ -96,7 +96,7 @@ public class DecisionTreeClassificationTrainer extends DecisionTree<GiniImpurity } /** {@inheritDoc} */ - @Override ImpurityMeasureCalculator<GiniImpurityMeasure> getImpurityMeasureCalculator( + @Override protected ImpurityMeasureCalculator<GiniImpurityMeasure> getImpurityMeasureCalculator( Dataset<EmptyContext, DecisionTreeData> dataset) { Set<Double> labels = dataset.compute(part -> { http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java index 7446237..4c9aac9 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java +++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java @@ -64,7 +64,7 @@ public class DecisionTreeRegressionTrainer extends DecisionTree<MSEImpurityMeasu } /** {@inheritDoc} */ - @Override ImpurityMeasureCalculator<MSEImpurityMeasure> getImpurityMeasureCalculator( + @Override protected ImpurityMeasureCalculator<MSEImpurityMeasure> getImpurityMeasureCalculator( Dataset<EmptyContext, DecisionTreeData> dataset) { return new MSEImpurityMeasureCalculator(useIndex); http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java index 631e848..4d87b47 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java @@ -17,10 +17,8 @@ package org.apache.ignite.ml.tree.boosting; -import org.apache.ignite.ml.Model; import org.apache.ignite.ml.composition.boosting.GDBBinaryClassifierTrainer; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.trainers.DatasetTrainer; +import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy; import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer; import org.jetbrains.annotations.NotNull; @@ -54,7 +52,7 @@ public class GDBBinaryClassifierOnTreesTrainer extends GDBBinaryClassifierTraine } /** {@inheritDoc} */ - @NotNull @Override protected DatasetTrainer<? 
extends Model<Vector, Double>, Double> buildBaseModelTrainer() { + @NotNull @Override protected DecisionTreeRegressionTrainer buildBaseModelTrainer() { return new DecisionTreeRegressionTrainer(maxDepth, minImpurityDecrease).withUseIndex(useIndex); } @@ -68,4 +66,9 @@ public class GDBBinaryClassifierOnTreesTrainer extends GDBBinaryClassifierTraine this.useIndex = useIndex; return this; } + + /** {@inheritDoc} */ + @Override protected GDBLearningStrategy getLearningStrategy() { + return new GDBOnTreesLearningStrategy(useIndex); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java new file mode 100644 index 0000000..8589a79 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.tree.boosting; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.composition.ModelsComposition; +import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy; +import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator; +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.logging.MLLogger; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.trainers.DatasetTrainer; +import org.apache.ignite.ml.tree.DecisionTree; +import org.apache.ignite.ml.tree.data.DecisionTreeData; +import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder; + +/** + * Gradient boosting on trees specific learning strategy reusing learning dataset with index between + * several learning iterations. + */ +public class GDBOnTreesLearningStrategy extends GDBLearningStrategy { + private boolean useIndex; + + /** + * Create an instance of learning strategy. + * + * @param useIndex Use index. + */ + public GDBOnTreesLearningStrategy(boolean useIndex) { + this.useIndex = useIndex; + } + + /** {@inheritDoc} */ + @Override public <K, V> List<Model<Vector, Double>> learnModels(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + + DatasetTrainer<? 
extends Model<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get(); + assert trainer instanceof DecisionTree; + DecisionTree decisionTreeTrainer = (DecisionTree) trainer; + + List<Model<Vector, Double>> models = new ArrayList<>(); + try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build( + new EmptyContextBuilder<>(), + new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, useIndex) + )) { + for (int i = 0; i < cntOfIterations; i++) { + double[] weights = Arrays.copyOf(compositionWeights, i); + WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLabelValue); + Model<Vector, Double> currComposition = new ModelsComposition(models, aggregator); + + dataset.compute(part -> { + if(part.getCopyOfOriginalLabels() == null) + part.setCopyOfOriginalLabels(Arrays.copyOf(part.getLabels(), part.getLabels().length)); + + for(int j = 0; j < part.getLabels().length; j++) { + double mdlAnswer = currComposition.apply(VectorUtils.of(part.getFeatures()[j])); + double originalLbVal = externalLbToInternalMapping.apply(part.getCopyOfOriginalLabels()[j]); + part.getLabels()[j] = -lossGradient.apply(sampleSize, originalLbVal, mdlAnswer); + } + }); + + long startTs = System.currentTimeMillis(); + models.add(decisionTreeTrainer.fit(dataset)); + double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0; + environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime); + } + } + catch (Exception e) { + throw new RuntimeException(e); + } + + return models; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java index 450dae3..e2a183c 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java @@ -17,10 +17,8 @@ package org.apache.ignite.ml.tree.boosting; -import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy; import org.apache.ignite.ml.composition.boosting.GDBRegressionTrainer; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.trainers.DatasetTrainer; import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer; import org.jetbrains.annotations.NotNull; @@ -54,7 +52,7 @@ public class GDBRegressionOnTreesTrainer extends GDBRegressionTrainer { } /** {@inheritDoc} */ - @NotNull @Override protected DatasetTrainer<? extends Model<Vector, Double>, Double> buildBaseModelTrainer() { + @NotNull @Override protected DecisionTreeRegressionTrainer buildBaseModelTrainer() { /* Return type narrowed covariantly from DatasetTrainer to DecisionTreeRegressionTrainer. */ return new DecisionTreeRegressionTrainer(maxDepth, minImpurityDecrease).withUseIndex(useIndex); } @@ -68,4 +66,9 @@ public class GDBRegressionOnTreesTrainer extends GDBRegressionTrainer { this.useIndex = useIndex; return this; } + + /** {@inheritDoc} Returns a {@link GDBOnTreesLearningStrategy} configured with this trainer's {@code useIndex} flag. */ + @Override protected GDBLearningStrategy getLearningStrategy() { + return new GDBOnTreesLearningStrategy(useIndex); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java index c017e5c..d5750ea 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java +++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java @@ -31,6 +31,9 @@ public class DecisionTreeData implements AutoCloseable { /** Vector with labels. */ private final double[] labels; + /** Copy of vector with original labels, or {@code null} until set. Auxiliary for Gradient Boosting on Trees. */ + private double[] copyOfOriginalLabels; + /** Indexes cache. */ private final List<TreeDataIndex> indexesCache; @@ -137,6 +140,14 @@ public class DecisionTreeData implements AutoCloseable { return labels; } + /** Returns the copy of the original labels, or {@code null} if it has not been set yet. */ public double[] getCopyOfOriginalLabels() { + return copyOfOriginalLabels; + } + + /** Sets the copy of the original labels; the given array is stored by reference, not copied. */ public void setCopyOfOriginalLabels(double[] copyOfOriginalLabels) { + this.copyOfOriginalLabels = copyOfOriginalLabels; + } + /** {@inheritDoc} */ @Override public void close() { // Do nothing, GC will clean up. http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java index 709f68e..0c67535 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java @@ -18,6 +18,8 @@ package org.apache.ignite.ml.tree.impurity; import java.io.Serializable; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.tree.TreeFilter; import org.apache.ignite.ml.tree.data.DecisionTreeData; import org.apache.ignite.ml.tree.data.TreeDataIndex; @@ -98,4 +100,8 @@ public abstract class ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> im protected double getFeatureValue(DecisionTreeData data, TreeDataIndex idx, int 
featureId, int k) { return useIndex ? idx.featureInSortedOrder(k, featureId) : data.getFeatures()[k][featureId]; } + + /** Returns the feature row at position {@code k} as a vector: taken from the index (presumably ordered by feature {@code featureId}) when {@code useIndex} is set, otherwise row {@code k} of {@code data.getFeatures()}. */ protected Vector getFeatureValues(DecisionTreeData data, TreeDataIndex idx, int featureId, int k) { + return VectorUtils.of(useIndex ? idx.featuresInSortedOrder(k, featureId) : data.getFeatures()[k]); + } }
