IGNITE-8924: [ML] Parameter Grid for tuning hyper-parameters in Cross-Validation process.
this closes #4425 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/5cddf920 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/5cddf920 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/5cddf920 Branch: refs/heads/ignite-gg-13195-cache-groups Commit: 5cddf9208faa873f972120bd15270a7f155986e4 Parents: 6263dbe Author: zaleslaw <[email protected]> Authored: Fri Jul 27 14:59:30 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Fri Jul 27 14:59:30 2018 +0300 ---------------------------------------------------------------------- .../ml/tutorial/Step_8_CV_with_Param_Grid.java | 173 +++++++++++++++ .../ignite/ml/selection/cv/CrossValidation.java | 212 +++++++++++++------ .../ml/selection/cv/CrossValidationResult.java | 134 ++++++++++++ .../ml/selection/paramgrid/ParamGrid.java | 58 +++++ .../paramgrid/ParameterSetGenerator.java | 91 ++++++++ .../ml/selection/paramgrid/package-info.java | 22 ++ .../org/apache/ignite/ml/tree/DecisionTree.java | 6 +- .../tree/DecisionTreeClassificationTrainer.java | 38 ++++ .../ignite/ml/selection/SelectionTestSuite.java | 2 + .../paramgrid/ParameterSetGeneratorTest.java | 56 +++++ 10 files changed, 723 insertions(+), 69 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/5cddf920/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java new file mode 100644 index 0000000..6104299 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation 
(ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.tutorial; + +import java.io.FileNotFoundException; +import java.util.Arrays; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer; +import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; +import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; +import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; +import org.apache.ignite.ml.selection.cv.CrossValidation; +import org.apache.ignite.ml.selection.cv.CrossValidationResult; +import org.apache.ignite.ml.selection.paramgrid.ParamGrid; +import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator; +import org.apache.ignite.ml.selection.scoring.metric.Accuracy; +import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter; +import org.apache.ignite.ml.selection.split.TrainTestSplit; +import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; +import org.apache.ignite.ml.tree.DecisionTreeNode; +import 
org.apache.ignite.thread.IgniteThread; + +/** + * To choose the best hyperparameters the cross-validation will be used in this example. + * + * The purpose of cross-validation is model checking, not model building. + * + * We train k different models. + * + * They differ in that 1/(k-1)th of the training data is exchanged against other cases. + * + * These models are sometimes called surrogate models because the (average) performance measured for these models + * is taken as a surrogate of the performance of the model trained on all cases. + * + * All scenarios are described there: https://sebastianraschka.com/faq/docs/evaluate-a-model.html + */ +public class Step_8_CV_with_Param_Grid { + /** Run example. */ + public static void main(String[] args) throws InterruptedException { + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + Step_8_CV_with_Param_Grid.class.getSimpleName(), () -> { + try { + IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite); + + // Defines first preprocessor that extracts features from an upstream data. 
+ // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare" + IgniteBiFunction<Integer, Object[], Object[]> featureExtractor + = (k, v) -> new Object[]{v[0], v[3], v[4], v[5], v[6], v[8], v[10]}; + + IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double) v[1]; + + TrainTestSplit<Integer, Object[]> split = new TrainTestDatasetSplitter<Integer, Object[]>() + .split(0.75); + + IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>() + .encodeFeature(1) + .encodeFeature(6) // <--- Changed index here + .fit(ignite, + dataCache, + featureExtractor + ); + + IgniteBiFunction<Integer, Object[], Vector> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>() + .fit(ignite, + dataCache, + strEncoderPreprocessor + ); + + IgniteBiFunction<Integer, Object[], Vector> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>() + .fit( + ignite, + dataCache, + imputingPreprocessor + ); + + IgniteBiFunction<Integer, Object[], Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>() + .withP(2) + .fit( + ignite, + dataCache, + minMaxScalerPreprocessor + ); + + // Tune hyperparams with K-fold Cross-Validation on the splitted training set. 
+ + DecisionTreeClassificationTrainer trainerCV = new DecisionTreeClassificationTrainer(); + + CrossValidation<DecisionTreeNode, Double, Integer, Object[]> scoreCalculator + = new CrossValidation<>(); + + ParamGrid paramGrid = new ParamGrid() + .addHyperParam("maxDeep", new Double[]{1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 10.0}) + .addHyperParam("minImpurityDecrease", new Double[]{0.0, 0.25, 0.5}); + + CrossValidationResult crossValidationRes = scoreCalculator.score( + trainerCV, + new Accuracy<>(), + ignite, + dataCache, + split.getTrainFilter(), + normalizationPreprocessor, + lbExtractor, + 3, + paramGrid + ); + + System.out.println("Train with maxDeep: " + crossValidationRes.getBest("maxDeep") + " and minImpurityDecrease: " + crossValidationRes.getBest("minImpurityDecrease")); + + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer() + .withMaxDeep(crossValidationRes.getBest("maxDeep")) + .withMinImpurityDecrease(crossValidationRes.getBest("minImpurityDecrease")); + + System.out.println(crossValidationRes); + + System.out.println("Best score: " + Arrays.toString(crossValidationRes.getBestScore())); + System.out.println("Best hyper params: " + crossValidationRes.getBestHyperParams()); + System.out.println("Best average score: " + crossValidationRes.getBestAvgScore()); + + crossValidationRes.getScoringBoard().forEach((hyperParams, score) -> { + System.out.println("Score " + Arrays.toString(score) + " for hyper params " + hyperParams); + }); + + // Train decision tree model. 
+ DecisionTreeNode bestMdl = trainer.fit( + ignite, + dataCache, + split.getTrainFilter(), + normalizationPreprocessor, + lbExtractor + ); + + double accuracy = Evaluator.evaluate( + dataCache, + split.getTestFilter(), + bestMdl, + normalizationPreprocessor, + lbExtractor, + new Accuracy<>() + ); + + System.out.println("\n>>> Accuracy " + accuracy); + System.out.println("\n>>> Test Error " + (1 - accuracy)); + } catch (FileNotFoundException e) { + e.printStackTrace(); + } + }); + + igniteThread.start(); + + igniteThread.join(); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/5cddf920/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java index f417fab..1ade876 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java @@ -17,6 +17,11 @@ package org.apache.ignite.ml.selection.cv; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.function.BiFunction; import java.util.function.Function; @@ -29,6 +34,8 @@ import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.selection.paramgrid.ParamGrid; +import org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator; import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursor; import 
org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor; import org.apache.ignite.ml.selection.scoring.cursor.LocalLabelPairCursor; @@ -42,9 +49,9 @@ import org.apache.ignite.ml.trainers.DatasetTrainer; * following way: the training set is split into k smaller sets. The following procedure is followed for each of the k * âfoldsâ: * <ul> - * <li>A model is trained using k-1 of the folds as training data;</li> - * <li>the resulting model is validated on the remaining part of the data (i.e., it is used as a test set to compute - * a performance measure such as accuracy).</li> + * <li>A model is trained using k-1 of the folds as training data;</li> + * <li>the resulting model is validated on the remaining part of the data (i.e., it is used as a test set to compute + * a performance measure such as accuracy).</li> * </ul> * * @param <M> Type of model. @@ -56,18 +63,18 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> { /** * Computes cross-validated metrics. * - * @param trainer Trainer of the model. - * @param scoreCalculator Score calculator. - * @param ignite Ignite instance. - * @param upstreamCache Ignite cache with {@code upstream} data. + * @param trainer Trainer of the model. + * @param scoreCalculator Score calculator. + * @param ignite Ignite instance. + * @param upstreamCache Ignite cache with {@code upstream} data. * @param featureExtractor Feature extractor. - * @param lbExtractor Label extractor. - * @param cv Number of folds. + * @param lbExtractor Label extractor. + * @param cv Number of folds. * @return Array of scores of the estimator for each run of the cross validation. 
*/ public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite, - IgniteCache<K, V> upstreamCache, IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, L> lbExtractor, int cv) { + IgniteCache<K, V> upstreamCache, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, L> lbExtractor, int cv) { return score(trainer, scoreCalculator, ignite, upstreamCache, (k, v) -> true, featureExtractor, lbExtractor, new SHA256UniformMapper<>(), cv); } @@ -75,35 +82,109 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> { /** * Computes cross-validated metrics. * - * @param trainer Trainer of the model. - * @param scoreCalculator Base score calculator. - * @param ignite Ignite instance. - * @param upstreamCache Ignite cache with {@code upstream} data. - * @param filter Base {@code upstream} data filter. + * @param trainer Trainer of the model. + * @param scoreCalculator Base score calculator. + * @param ignite Ignite instance. + * @param upstreamCache Ignite cache with {@code upstream} data. + * @param filter Base {@code upstream} data filter. * @param featureExtractor Feature extractor. - * @param lbExtractor Label extractor. - * @param cv Number of folds. + * @param lbExtractor Label extractor. + * @param cv Number of folds. * @return Array of scores of the estimator for each run of the cross validation. 
*/ public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite, - IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter, - IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) { + IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) { return score(trainer, scoreCalculator, ignite, upstreamCache, filter, featureExtractor, lbExtractor, new SHA256UniformMapper<>(), cv); } /** + * Computes cross-validated metrics with a passed parameter grid. + * + * The real cross-validation training will be called each time for each parameter set. + * + * @param trainer Trainer of the model. + * @param scoreCalculator Base score calculator. + * @param ignite Ignite instance. + * @param upstreamCache Ignite cache with {@code upstream} data. + * @param filter Base {@code upstream} data filter. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param cv Number of folds. + * @param paramGrid Parameter grid. + * @return Array of scores of the estimator for each run of the cross validation. 
+ */ + public CrossValidationResult score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite, + IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv, + ParamGrid paramGrid) { + + List<Double[]> paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIndex()).generate(); + + CrossValidationResult cvRes = new CrossValidationResult(); + + paramSets.forEach(paramSet -> { + Map<String, Double> paramMap = new HashMap<>(); + + for (int paramIdx = 0; paramIdx < paramSet.length; paramIdx++) { + String paramName = paramGrid.getParamNameByIndex(paramIdx); + Double paramVal = paramSet[paramIdx]; + + paramMap.put(paramName, paramVal); + + try { + final String mtdName = "with" + + paramName.substring(0, 1).toUpperCase() + + paramName.substring(1); + + Method trainerSetter = null; + + // We should iterate along all methods due to we have no info about signature and passed types. + for (Method method : trainer.getClass().getDeclaredMethods()) { + if (method.getName().equals(mtdName)) + trainerSetter = method; + } + + if (trainerSetter != null) + trainerSetter.invoke(trainer, paramVal); + else + throw new NoSuchMethodException(mtdName); + + } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + e.printStackTrace(); + } + } + + double[] locScores = score(trainer, scoreCalculator, ignite, upstreamCache, filter, featureExtractor, lbExtractor, + new SHA256UniformMapper<>(), cv); + + cvRes.addScores(locScores, paramMap); + + final double locAvgScore = Arrays.stream(locScores).average().orElse(Double.MIN_VALUE); + + if (locAvgScore > cvRes.getBestAvgScore()) { + cvRes.setBestScore(locScores); + cvRes.setBestHyperParams(paramMap); + System.out.println(paramMap.toString()); + } + }); + + return cvRes; + } + + /** * Computes cross-validated metrics. * - * @param trainer Trainer of the model. 
- * @param scoreCalculator Base score calculator. - * @param ignite Ignite instance. - * @param upstreamCache Ignite cache with {@code upstream} data. - * @param filter Base {@code upstream} data filter. + * @param trainer Trainer of the model. + * @param scoreCalculator Base score calculator. + * @param ignite Ignite instance. + * @param upstreamCache Ignite cache with {@code upstream} data. + * @param filter Base {@code upstream} data filter. * @param featureExtractor Feature extractor. - * @param lbExtractor Label extractor. - * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1). - * @param cv Number of folds. + * @param lbExtractor Label extractor. + * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1). + * @param cv Number of folds. * @return Array of scores of the estimator for each run of the cross validation. */ public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, @@ -136,17 +217,17 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> { /** * Computes cross-validated metrics. * - * @param trainer Trainer of the model. - * @param scoreCalculator Base score calculator. - * @param upstreamMap Map with {@code upstream} data. - * @param parts Number of partitions. + * @param trainer Trainer of the model. + * @param scoreCalculator Base score calculator. + * @param upstreamMap Map with {@code upstream} data. + * @param parts Number of partitions. * @param featureExtractor Feature extractor. - * @param lbExtractor Label extractor. - * @param cv Number of folds. + * @param lbExtractor Label extractor. + * @param cv Number of folds. * @return Array of scores of the estimator for each run of the cross validation. 
*/ public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Map<K, V> upstreamMap, - int parts, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) { + int parts, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) { return score(trainer, scoreCalculator, upstreamMap, (k, v) -> true, parts, featureExtractor, lbExtractor, new SHA256UniformMapper<>(), cv); } @@ -154,19 +235,19 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> { /** * Computes cross-validated metrics. * - * @param trainer Trainer of the model. - * @param scoreCalculator Base score calculator. - * @param upstreamMap Map with {@code upstream} data. - * @param filter Base {@code upstream} data filter. - * @param parts Number of partitions. + * @param trainer Trainer of the model. + * @param scoreCalculator Base score calculator. + * @param upstreamMap Map with {@code upstream} data. + * @param filter Base {@code upstream} data filter. + * @param parts Number of partitions. * @param featureExtractor Feature extractor. - * @param lbExtractor Label extractor. - * @param cv Number of folds. + * @param lbExtractor Label extractor. + * @param cv Number of folds. * @return Array of scores of the estimator for each run of the cross validation. 
*/ public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Map<K, V> upstreamMap, - IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, L> lbExtractor, int cv) { + IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, L> lbExtractor, int cv) { return score(trainer, scoreCalculator, upstreamMap, filter, parts, featureExtractor, lbExtractor, new SHA256UniformMapper<>(), cv); } @@ -174,20 +255,20 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> { /** * Computes cross-validated metrics. * - * @param trainer Trainer of the model. - * @param scoreCalculator Base score calculator. - * @param upstreamMap Map with {@code upstream} data. - * @param filter Base {@code upstream} data filter. - * @param parts Number of partitions. + * @param trainer Trainer of the model. + * @param scoreCalculator Base score calculator. + * @param upstreamMap Map with {@code upstream} data. + * @param filter Base {@code upstream} data filter. + * @param parts Number of partitions. * @param featureExtractor Feature extractor. - * @param lbExtractor Label extractor. - * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1). - * @param cv Number of folds. + * @param lbExtractor Label extractor. + * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1). + * @param cv Number of folds. * @return Array of scores of the estimator for each run of the cross validation. 
*/ public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Map<K, V> upstreamMap, - IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, L> lbExtractor, UniformMapper<K, V> mapper, int cv) { + IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, L> lbExtractor, UniformMapper<K, V> mapper, int cv) { return score( trainer, predicate -> new LocalDatasetBuilder<>( @@ -213,21 +294,21 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> { /** * Computes cross-validated metrics. * - * @param trainer Trainer of the model. + * @param trainer Trainer of the model. * @param datasetBuilderSupplier Dataset builder supplier. - * @param testDataIterSupplier Test data iterator supplier. - * @param featureExtractor Feature extractor. - * @param lbExtractor Label extractor. - * @param scoreCalculator Base score calculator. - * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1). - * @param cv Number of folds. + * @param testDataIterSupplier Test data iterator supplier. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param scoreCalculator Base score calculator. + * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1). + * @param cv Number of folds. * @return Array of scores of the estimator for each run of the cross validation. 
*/ private double[] score(DatasetTrainer<M, L> trainer, Function<IgniteBiPredicate<K, V>, DatasetBuilder<K, V>> datasetBuilderSupplier, - BiFunction<IgniteBiPredicate<K, V>, M, LabelPairCursor<L>> testDataIterSupplier, - IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, - Metric<L> scoreCalculator, UniformMapper<K, V> mapper, int cv) { + BiFunction<IgniteBiPredicate<K, V>, M, LabelPairCursor<L>> testDataIterSupplier, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, + Metric<L> scoreCalculator, UniformMapper<K, V> mapper, int cv) { double[] scores = new double[cv]; @@ -246,8 +327,7 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> { try (LabelPairCursor<L> cursor = testDataIterSupplier.apply(trainSetFilter, mdl)) { scores[i] = scoreCalculator.score(cursor.iterator()); - } - catch (Exception e) { + } catch (Exception e) { throw new RuntimeException(e); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/5cddf920/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationResult.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationResult.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationResult.java new file mode 100644 index 0000000..55c20be --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationResult.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.selection.cv; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +/** + * Represents the cross validation procedure result, + * wraps score and values of hyper parameters associated with these values. + */ +public class CrossValidationResult { + /** Best hyper params. */ + private Map<String, Double> bestHyperParams; + + /** Best score. */ + private double[] bestScore; + + /** + * Scoring board. + * The key is map of hyper parameters and its values, + * the value is score result associated with set of hyper paramters presented in the key. + */ + private Map<Map<String, Double>, double[]> scoringBoard = new HashMap<>(); + + /** + * Default constructor. + */ + CrossValidationResult() { + } + + /** + * Gets the best value for the specific hyper parameter. + * + * @param hyperParamName Hyper parameter name. + * @return The value. + */ + public double getBest(String hyperParamName) { + return bestHyperParams.get(hyperParamName); + } + + /** + * Gets the best score for the specific hyper parameter. + * + * @return The value. + */ + public double[] getBestScore() { + return bestScore; + } + + /** + * Adds local scores and associated parameter set to the scoring board. + * + * @param locScores The scores. + * @param paramMap The parameter set associated with the given scores. + */ + void addScores(double[] locScores, Map<String, Double> paramMap) { + scoringBoard.put(paramMap, locScores); + } + + /** + * Gets the the average value of best score array. + * + * @return The value. 
+ */ + public double getBestAvgScore() { + if (bestScore == null) + return Double.MIN_VALUE; + return Arrays.stream(bestScore).average().orElse(Double.MIN_VALUE); + } + + /** + * Helper method in cross-validation process. + * + * @param bestScore The best score. + */ + void setBestScore(double[] bestScore) { + this.bestScore = bestScore; + } + + /** + * Helper method in cross-validation process. + * + * @param bestHyperParams The best hyper parameters. + */ + void setBestHyperParams(Map<String, Double> bestHyperParams) { + this.bestHyperParams = bestHyperParams; + } + + /** + * Gets the Scoring Board. + * + * The key is map of hyper parameters and its values, + * the value is score result associated with set of hyper paramters presented in the key. + * + * @return The Scoring Board. + */ + public Map<Map<String, Double>, double[]> getScoringBoard() { + return scoringBoard; + } + + /** + * Gets the best hyper parameters set. + * + * @return The value. + */ + public Map<String, Double> getBestHyperParams() { + return bestHyperParams; + } + + /** {@inheritDoc} */ + @Override public String toString() { + return "CrossValidationResult{" + + "bestHyperParams=" + bestHyperParams + + ", bestScore=" + Arrays.toString(bestScore) + + '}'; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/5cddf920/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParamGrid.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParamGrid.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParamGrid.java new file mode 100644 index 0000000..3279d93 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParamGrid.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.selection.paramgrid; + +import java.util.HashMap; +import java.util.Map; + +/** + * Keeps the grid of parameters. + */ +public class ParamGrid { + /** Parameter values by parameter index. */ + private Map<Integer, Double[]> paramValuesByParamIndex = new HashMap<>(); + + /** Parameter names by parameter index. */ + private Map<Integer, String> paramNamesByParamIndex = new HashMap<>(); + + /** Parameter counter. */ + private int paramCntr; + + /** */ + public Map<Integer, Double[]> getParamValuesByParamIndex() { + return paramValuesByParamIndex; + } + + /** + * Adds a grid for the specific hyper parameter. + * @param paramName The parameter name. + * @param params The array of the given hyper parameter values. + * @return The updated ParamGrid. 
+ */ + public ParamGrid addHyperParam(String paramName, Double[] params) { + paramValuesByParamIndex.put(paramCntr, params); + paramNamesByParamIndex.put(paramCntr, paramName); + paramCntr++; + return this; + } + + /** */ + public String getParamNameByIndex(int idx) { + return paramNamesByParamIndex.get(idx); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/5cddf920/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParameterSetGenerator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParameterSetGenerator.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParameterSetGenerator.java new file mode 100644 index 0000000..b6c86a1 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParameterSetGenerator.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.selection.paramgrid; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +/** + * Generates tuples of hyper parameter values by given map. 
/**
 * Generates tuples of hyper parameter values (the Cartesian product) by the given map.
 *
 * In the given map keys are indices of hyper parameters
 * and values are arrays of values for the hyper parameter presented in the key.
 */
public class ParameterSetGenerator {
    /** Size of parameter vector: one slot per hyper parameter, fixed by the map size. */
    private final int sizeOfParamVector;

    /** Accumulator for the generated parameter tuples. */
    private final List<Double[]> params = new ArrayList<>();

    /** The given map of hyper parameters and its values. */
    private final Map<Integer, Double[]> map;

    /**
     * Creates an instance of the generator.
     *
     * @param map In the given map keys are indices of hyper parameters
     *            and values are arrays of values for the hyper parameter presented in the key.
     *            Must be non-null and non-empty (enforced with assertions; callers and tests
     *            rely on {@link AssertionError} being thrown when assertions are enabled).
     */
    public ParameterSetGenerator(Map<Integer, Double[]> map) {
        assert map != null;
        assert !map.isEmpty();

        this.map = map;
        this.sizeOfParamVector = map.size();
    }

    /**
     * Returns the list of tuples presented as arrays.
     * The amount of tuples equals the product of the lengths of all value arrays in the map.
     */
    public List<Double[]> generate() {
        Double[] nextPnt = new Double[sizeOfParamVector];

        traverseTree(map, nextPnt, -1);

        return params;
    }

    /**
     * Traverses the tree on the current level and starts the procedure of child traversing.
     *
     * @param map The current state of the data.
     * @param nextPnt Next point (partially filled tuple, one slot per dimension).
     * @param dimensionNum Dimension number.
     */
    private void traverseTree(Map<Integer, Double[]> map, Double[] nextPnt, int dimensionNum) {
        dimensionNum++;

        if (dimensionNum == sizeOfParamVector) {
            // Copy the completed tuple so later mutations of nextPnt don't corrupt stored results.
            params.add(Arrays.copyOf(nextPnt, sizeOfParamVector));
            return;
        }

        Double[] valuesOfCurrDimension = map.get(dimensionNum);

        for (Double specificValue : valuesOfCurrDimension) {
            nextPnt[dimensionNum] = specificValue;
            traverseTree(map, nextPnt, dimensionNum);
        }
    }
}
+ */ +package org.apache.ignite.ml.selection.paramgrid; http://git-wip-us.apache.org/repos/asf/ignite/blob/5cddf920/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java index c94d2dd..c1e3abf 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java @@ -42,13 +42,13 @@ import org.apache.ignite.ml.tree.leaf.DecisionTreeLeafBuilder; */ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends DatasetTrainer<DecisionTreeNode, Double> { /** Max tree deep. */ - private final int maxDeep; + int maxDeep; /** Min impurity decrease. */ - private final double minImpurityDecrease; + double minImpurityDecrease; /** Step function compressor. */ - private final StepFunctionCompressor<T> compressor; + StepFunctionCompressor<T> compressor; /** Decision tree leaf builder. 
*/ private final DecisionTreeLeafBuilder decisionTreeLeafBuilder; http://git-wip-us.apache.org/repos/asf/ignite/blob/5cddf920/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java index ce75190..71e387f 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java @@ -46,6 +46,14 @@ public class DecisionTreeClassificationTrainer extends DecisionTree<GiniImpurity } /** + * Constructs a new decision tree classifier with default impurity function compressor + * and default maxDeep = 5 and minImpurityDecrease = 0. + */ + public DecisionTreeClassificationTrainer() { + this(5, 0, null); + } + + /** * Constructs a new instance of decision tree classifier. * * @param maxDeep Max tree deep. @@ -56,6 +64,36 @@ public class DecisionTreeClassificationTrainer extends DecisionTree<GiniImpurity super(maxDeep, minImpurityDecrease, compressor, new MostCommonDecisionTreeLeafBuilder()); } + /** + * Set up the max deep of decision tree. + * @param maxDeep The parameter value. + * @return Trainer with new maxDeep parameter value. + */ + public DecisionTreeClassificationTrainer withMaxDeep(Double maxDeep){ + this.maxDeep = maxDeep.intValue(); + return this; + } + + /** + * Set up the min impurity decrease of decision tree. + * @param minImpurityDecrease The parameter value. + * @return Trainer with new minImpurityDecrease parameter value. 
+ */ + public DecisionTreeClassificationTrainer withMinImpurityDecrease(Double minImpurityDecrease){ + this.minImpurityDecrease = minImpurityDecrease; + return this; + } + + /** + * Set up the step function compressor of decision tree. + * @param compressor The parameter value. + * @return Trainer with new compressor parameter value. + */ + public DecisionTreeClassificationTrainer withCompressor(StepFunctionCompressor compressor){ + this.compressor = compressor; + return this; + } + /** {@inheritDoc} */ @Override ImpurityMeasureCalculator<GiniImpurityMeasure> getImpurityMeasureCalculator( Dataset<EmptyContext, DecisionTreeData> dataset) { http://git-wip-us.apache.org/repos/asf/ignite/blob/5cddf920/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java index cc69074..3adae79 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java @@ -18,6 +18,7 @@ package org.apache.ignite.ml.selection; import org.apache.ignite.ml.selection.cv.CrossValidationTest; +import org.apache.ignite.ml.selection.paramgrid.ParameterSetGeneratorTest; import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursorTest; import org.apache.ignite.ml.selection.scoring.cursor.LocalLabelPairCursorTest; import org.apache.ignite.ml.selection.scoring.metric.AccuracyTest; @@ -35,6 +36,7 @@ import org.junit.runners.Suite; @RunWith(Suite.class) @Suite.SuiteClasses({ CrossValidationTest.class, + ParameterSetGeneratorTest.class, CacheBasedLabelPairCursorTest.class, LocalLabelPairCursorTest.class, AccuracyTest.class, 
http://git-wip-us.apache.org/repos/asf/ignite/blob/5cddf920/modules/ml/src/test/java/org/apache/ignite/ml/selection/paramgrid/ParameterSetGeneratorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/paramgrid/ParameterSetGeneratorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/paramgrid/ParameterSetGeneratorTest.java new file mode 100644 index 0000000..8e36024 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/paramgrid/ParameterSetGeneratorTest.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.selection.paramgrid; + +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link ParameterSetGenerator}. 
+ */ +public class ParameterSetGeneratorTest { + /** */ + @Test + public void testParamSetGenerator() { + Map<Integer, Double[]> map = new TreeMap<>(); + map.put(0, new Double[]{1.1, 2.1}); + map.put(1, new Double[]{1.2, 2.2, 3.2, 4.2}); + map.put(2, new Double[]{1.3, 2.3}); + map.put(3, new Double[]{1.4}); + + List<Double[]> res = new ParameterSetGenerator(map).generate(); + assertEquals(res.size(), 16); + } + /** */ + @Test(expected = java.lang.AssertionError.class) + public void testParamSetGeneratorWithEmptyMap() { + Map<Integer, Double[]> map = new TreeMap<>(); + new ParameterSetGenerator(map).generate(); + + } + + /** */ + @Test(expected = java.lang.AssertionError.class) + public void testNullHandling() { + new ParameterSetGenerator(null).generate(); + } +}
