IGNITE-8924: [ML] Parameter Grid for tuning hyper-parameters in
Cross-Validation process.

this closes #4425


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/5cddf920
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/5cddf920
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/5cddf920

Branch: refs/heads/ignite-gg-13195-cache-groups
Commit: 5cddf9208faa873f972120bd15270a7f155986e4
Parents: 6263dbe
Author: zaleslaw <[email protected]>
Authored: Fri Jul 27 14:59:30 2018 +0300
Committer: Yury Babak <[email protected]>
Committed: Fri Jul 27 14:59:30 2018 +0300

----------------------------------------------------------------------
 .../ml/tutorial/Step_8_CV_with_Param_Grid.java  | 173 +++++++++++++++
 .../ignite/ml/selection/cv/CrossValidation.java | 212 +++++++++++++------
 .../ml/selection/cv/CrossValidationResult.java  | 134 ++++++++++++
 .../ml/selection/paramgrid/ParamGrid.java       |  58 +++++
 .../paramgrid/ParameterSetGenerator.java        |  91 ++++++++
 .../ml/selection/paramgrid/package-info.java    |  22 ++
 .../org/apache/ignite/ml/tree/DecisionTree.java |   6 +-
 .../tree/DecisionTreeClassificationTrainer.java |  38 ++++
 .../ignite/ml/selection/SelectionTestSuite.java |   2 +
 .../paramgrid/ParameterSetGeneratorTest.java    |  56 +++++
 10 files changed, 723 insertions(+), 69 deletions(-)
----------------------------------------------------------------------
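In short, the patch lets a ParamGrid drive the CrossValidation facade over every combination of hyperparameter values. A condensed usage sketch (assuming an Ignite instance, a populated dataCache, and Vector-producing featureExtractor / lbExtractor functions set up as in the tutorial example below):

    DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer();

    // Grid of hyperparameter values to search over.
    ParamGrid paramGrid = new ParamGrid()
        .addHyperParam("maxDeep", new Double[]{1.0, 5.0, 10.0})
        .addHyperParam("minImpurityDecrease", new Double[]{0.0, 0.25, 0.5});

    CrossValidation<DecisionTreeNode, Double, Integer, Object[]> scoreCalculator = new CrossValidation<>();

    // 3-fold cross-validation is run once per combination in the grid.
    CrossValidationResult cvRes = scoreCalculator.score(trainer, new Accuracy<>(), ignite, dataCache,
        (k, v) -> true, featureExtractor, lbExtractor, 3, paramGrid);

    System.out.println("Best maxDeep: " + cvRes.getBest("maxDeep"));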


http://git-wip-us.apache.org/repos/asf/ignite/blob/5cddf920/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
new file mode 100644
index 0000000..6104299
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.tutorial;
+
+import java.io.FileNotFoundException;
+import java.util.Arrays;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer;
+import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
+import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
+import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
+import org.apache.ignite.ml.selection.cv.CrossValidation;
+import org.apache.ignite.ml.selection.cv.CrossValidationResult;
+import org.apache.ignite.ml.selection.paramgrid.ParamGrid;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
+import org.apache.ignite.ml.selection.split.TrainTestSplit;
+import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
+import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.thread.IgniteThread;
+
+/**
+ * In this example cross-validation is used to choose the best hyperparameters.
+ *
+ * The purpose of cross-validation is model checking, not model building.
+ *
+ * We train k different models.
+ *
+ * They differ in that 1/(k-1)th of the training data is exchanged against other cases.
+ *
+ * These models are sometimes called surrogate models because the (average) performance measured for these models
+ * is taken as a surrogate of the performance of the model trained on all cases.
+ *
+ * All scenarios are described here: https://sebastianraschka.com/faq/docs/evaluate-a-model.html
+ */
+public class Step_8_CV_with_Param_Grid {
+    /** Run example. */
+    public static void main(String[] args) throws InterruptedException {
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
+                Step_8_CV_with_Param_Grid.class.getSimpleName(), () -> {
+                try {
+                    IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
+
+                    // Defines the first preprocessor that extracts features from the upstream data.
+                    // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare".
+                    IgniteBiFunction<Integer, Object[], Object[]> featureExtractor
+                        = (k, v) -> new Object[]{v[0], v[3], v[4], v[5], v[6], v[8], v[10]};
+
+                    IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double) v[1];
+
+                    TrainTestSplit<Integer, Object[]> split = new TrainTestDatasetSplitter<Integer, Object[]>()
+                        .split(0.75);
+
+                    IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>()
+                        .encodeFeature(1)
+                        .encodeFeature(6) // <--- Changed index here
+                        .fit(ignite,
+                            dataCache,
+                            featureExtractor
+                        );
+
+                    IgniteBiFunction<Integer, Object[], Vector> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>()
+                        .fit(ignite,
+                            dataCache,
+                            strEncoderPreprocessor
+                        );
+
+                    IgniteBiFunction<Integer, Object[], Vector> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>()
+                        .fit(
+                            ignite,
+                            dataCache,
+                            imputingPreprocessor
+                        );
+
+                    IgniteBiFunction<Integer, Object[], Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
+                        .withP(2)
+                        .fit(
+                            ignite,
+                            dataCache,
+                            minMaxScalerPreprocessor
+                        );
+
+                    // Tune hyperparameters with K-fold cross-validation on the split training set.
+
+                    DecisionTreeClassificationTrainer trainerCV = new DecisionTreeClassificationTrainer();
+
+                    CrossValidation<DecisionTreeNode, Double, Integer, Object[]> scoreCalculator
+                        = new CrossValidation<>();
+
+                    ParamGrid paramGrid = new ParamGrid()
+                        .addHyperParam("maxDeep", new Double[]{1.0, 2.0, 3.0, 4.0, 5.0, 10.0})
+                        .addHyperParam("minImpurityDecrease", new Double[]{0.0, 0.25, 0.5});
+
+                    CrossValidationResult crossValidationRes = scoreCalculator.score(
+                        trainerCV,
+                        new Accuracy<>(),
+                        ignite,
+                        dataCache,
+                        split.getTrainFilter(),
+                        normalizationPreprocessor,
+                        lbExtractor,
+                        3,
+                        paramGrid
+                    );
+
+                    System.out.println("Train with maxDeep: " + 
crossValidationRes.getBest("maxDeep") + " and minImpurityDecrease: " + 
crossValidationRes.getBest("minImpurityDecrease"));
+
+                    DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer()
+                        .withMaxDeep(crossValidationRes.getBest("maxDeep"))
+                        .withMinImpurityDecrease(crossValidationRes.getBest("minImpurityDecrease"));
+
+                    System.out.println(crossValidationRes);
+
+                    System.out.println("Best score: " + 
Arrays.toString(crossValidationRes.getBestScore()));
+                    System.out.println("Best hyper params: " + 
crossValidationRes.getBestHyperParams());
+                    System.out.println("Best average score: " + 
crossValidationRes.getBestAvgScore());
+
+                    crossValidationRes.getScoringBoard().forEach((hyperParams, 
score) -> {
+                        System.out.println("Score " + Arrays.toString(score) + 
" for hyper params " + hyperParams);
+                    });
+
+                    // Train decision tree model.
+                    DecisionTreeNode bestMdl = trainer.fit(
+                        ignite,
+                        dataCache,
+                        split.getTrainFilter(),
+                        normalizationPreprocessor,
+                        lbExtractor
+                    );
+
+                    double accuracy = Evaluator.evaluate(
+                        dataCache,
+                        split.getTestFilter(),
+                        bestMdl,
+                        normalizationPreprocessor,
+                        lbExtractor,
+                        new Accuracy<>()
+                    );
+
+                    System.out.println("\n>>> Accuracy " + accuracy);
+                    System.out.println("\n>>> Test Error " + (1 - accuracy));
+                } catch (FileNotFoundException e) {
+                    e.printStackTrace();
+                }
+            });
+
+            igniteThread.start();
+
+            igniteThread.join();
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/5cddf920/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java
index f417fab..1ade876 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java
@@ -17,6 +17,11 @@
 
 package org.apache.ignite.ml.selection.cv;
 
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.function.BiFunction;
 import java.util.function.Function;
@@ -29,6 +34,8 @@ import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.selection.paramgrid.ParamGrid;
+import org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator;
 import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursor;
 import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor;
 import org.apache.ignite.ml.selection.scoring.cursor.LocalLabelPairCursor;
@@ -42,9 +49,9 @@ import org.apache.ignite.ml.trainers.DatasetTrainer;
  * following way: the training set is split into k smaller sets. The following procedure is followed for each of the k
  * “folds”:
  * <ul>
- *     <li>A model is trained using k-1 of the folds as training data;</li>
- *     <li>the resulting model is validated on the remaining part of the data (i.e., it is used as a test set to compute
- *     a performance measure such as accuracy).</li>
+ *    <li>A model is trained using k-1 of the folds as training data;</li>
+ *    <li>the resulting model is validated on the remaining part of the data (i.e., it is used as a test set to compute
+ * a performance measure such as accuracy).</li>
  * </ul>
  *
  * @param <M> Type of model.
@@ -56,18 +63,18 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> {
     /**
      * Computes cross-validated metrics.
      *
-     * @param trainer Trainer of the model.
-     * @param scoreCalculator Score calculator.
-     * @param ignite Ignite instance.
-     * @param upstreamCache Ignite cache with {@code upstream} data.
+     * @param trainer          Trainer of the model.
+     * @param scoreCalculator  Score calculator.
+     * @param ignite           Ignite instance.
+     * @param upstreamCache    Ignite cache with {@code upstream} data.
      * @param featureExtractor Feature extractor.
-     * @param lbExtractor Label extractor.
-     * @param cv Number of folds.
+     * @param lbExtractor      Label extractor.
+     * @param cv               Number of folds.
      * @return Array of scores of the estimator for each run of the cross validation.
      */
     public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite,
-                          IgniteCache<K, V> upstreamCache, IgniteBiFunction<K, V, Vector> featureExtractor,
-                          IgniteBiFunction<K, V, L> lbExtractor, int cv) {
+        IgniteCache<K, V> upstreamCache, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor, int cv) {
         return score(trainer, scoreCalculator, ignite, upstreamCache, (k, v) -> true, featureExtractor, lbExtractor,
             new SHA256UniformMapper<>(), cv);
     }
@@ -75,35 +82,109 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> {
     /**
      * Computes cross-validated metrics.
      *
-     * @param trainer Trainer of the model.
-     * @param scoreCalculator Base score calculator.
-     * @param ignite Ignite instance.
-     * @param upstreamCache Ignite cache with {@code upstream} data.
-     * @param filter Base {@code upstream} data filter.
+     * @param trainer          Trainer of the model.
+     * @param scoreCalculator  Base score calculator.
+     * @param ignite           Ignite instance.
+     * @param upstreamCache    Ignite cache with {@code upstream} data.
+     * @param filter           Base {@code upstream} data filter.
      * @param featureExtractor Feature extractor.
-     * @param lbExtractor Label extractor.
-     * @param cv Number of folds.
+     * @param lbExtractor      Label extractor.
+     * @param cv               Number of folds.
      * @return Array of scores of the estimator for each run of the cross validation.
      */
     public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite,
-                          IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter,
-                          IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) {
+        IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) {
         return score(trainer, scoreCalculator, ignite, upstreamCache, filter, featureExtractor, lbExtractor,
             new SHA256UniformMapper<>(), cv);
     }
 
     /**
+     * Computes cross-validated metrics with a passed parameter grid.
+     *
+     * Cross-validation is performed once for each parameter set produced from the grid.
+     *
+     * @param trainer          Trainer of the model.
+     * @param scoreCalculator  Base score calculator.
+     * @param ignite           Ignite instance.
+     * @param upstreamCache    Ignite cache with {@code upstream} data.
+     * @param filter           Base {@code upstream} data filter.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor      Label extractor.
+     * @param cv               Number of folds.
+     * @param paramGrid        Parameter grid.
+     * @return Cross-validation result with the scores and the best set of hyper parameters.
+     */
+    public CrossValidationResult score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite,
+        IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv,
+        ParamGrid paramGrid) {
+
+        List<Double[]> paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIndex()).generate();
+
+        CrossValidationResult cvRes = new CrossValidationResult();
+
+        paramSets.forEach(paramSet -> {
+            Map<String, Double> paramMap = new HashMap<>();
+
+            for (int paramIdx = 0; paramIdx < paramSet.length; paramIdx++) {
+                String paramName = paramGrid.getParamNameByIndex(paramIdx);
+                Double paramVal = paramSet[paramIdx];
+
+                paramMap.put(paramName, paramVal);
+
+                try {
+                    final String mtdName = "with" +
+                        paramName.substring(0, 1).toUpperCase() +
+                        paramName.substring(1);
+
+                    Method trainerSetter = null;
+
+                    // Iterate over all declared methods since we have no info about the signature and parameter types.
+                    for (Method method : trainer.getClass().getDeclaredMethods()) {
+                        if (method.getName().equals(mtdName))
+                            trainerSetter = method;
+                    }
+
+                    if (trainerSetter != null)
+                        trainerSetter.invoke(trainer, paramVal);
+                    else
+                        throw new NoSuchMethodException(mtdName);
+
+                } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
+                    e.printStackTrace();
+                }
+            }
+
+            double[] locScores = score(trainer, scoreCalculator, ignite, upstreamCache, filter, featureExtractor, lbExtractor,
+                new SHA256UniformMapper<>(), cv);
+
+            cvRes.addScores(locScores, paramMap);
+
+            final double locAvgScore = Arrays.stream(locScores).average().orElse(Double.MIN_VALUE);
+
+            if (locAvgScore > cvRes.getBestAvgScore()) {
+                cvRes.setBestScore(locScores);
+                cvRes.setBestHyperParams(paramMap);
+            }
+        });
+
+        return cvRes;
+    }
+
+    /**
      * Computes cross-validated metrics.
      *
-     * @param trainer Trainer of the model.
-     * @param scoreCalculator Base score calculator.
-     * @param ignite Ignite instance.
-     * @param upstreamCache Ignite cache with {@code upstream} data.
-     * @param filter Base {@code upstream} data filter.
+     * @param trainer          Trainer of the model.
+     * @param scoreCalculator  Base score calculator.
+     * @param ignite           Ignite instance.
+     * @param upstreamCache    Ignite cache with {@code upstream} data.
+     * @param filter           Base {@code upstream} data filter.
      * @param featureExtractor Feature extractor.
-     * @param lbExtractor Label extractor.
-     * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1).
-     * @param cv Number of folds.
+     * @param lbExtractor      Label extractor.
+     * @param mapper           Mapper used to map a key-value pair to a point on the segment (0, 1).
+     * @param cv               Number of folds.
      * @return Array of scores of the estimator for each run of the cross validation.
      */
     public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator,
@@ -136,17 +217,17 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> {
     /**
      * Computes cross-validated metrics.
      *
-     * @param trainer Trainer of the model.
-     * @param scoreCalculator Base score calculator.
-     * @param upstreamMap Map with {@code upstream} data.
-     * @param parts Number of partitions.
+     * @param trainer          Trainer of the model.
+     * @param scoreCalculator  Base score calculator.
+     * @param upstreamMap      Map with {@code upstream} data.
+     * @param parts            Number of partitions.
      * @param featureExtractor Feature extractor.
-     * @param lbExtractor Label extractor.
-     * @param cv Number of folds.
+     * @param lbExtractor      Label extractor.
+     * @param cv               Number of folds.
      * @return Array of scores of the estimator for each run of the cross validation.
      */
     public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Map<K, V> upstreamMap,
-                          int parts, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) {
+        int parts, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) {
         return score(trainer, scoreCalculator, upstreamMap, (k, v) -> true, parts, featureExtractor, lbExtractor,
             new SHA256UniformMapper<>(), cv);
     }
@@ -154,19 +235,19 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> {
     /**
      * Computes cross-validated metrics.
      *
-     * @param trainer Trainer of the model.
-     * @param scoreCalculator Base score calculator.
-     * @param upstreamMap Map with {@code upstream} data.
-     * @param filter Base {@code upstream} data filter.
-     * @param parts Number of partitions.
+     * @param trainer          Trainer of the model.
+     * @param scoreCalculator  Base score calculator.
+     * @param upstreamMap      Map with {@code upstream} data.
+     * @param filter           Base {@code upstream} data filter.
+     * @param parts            Number of partitions.
      * @param featureExtractor Feature extractor.
-     * @param lbExtractor Label extractor.
-     * @param cv Number of folds.
+     * @param lbExtractor      Label extractor.
+     * @param cv               Number of folds.
      * @return Array of scores of the estimator for each run of the cross validation.
      */
     public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Map<K, V> upstreamMap,
-                          IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, Vector> featureExtractor,
-                          IgniteBiFunction<K, V, L> lbExtractor, int cv) {
+        IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor, int cv) {
         return score(trainer, scoreCalculator, upstreamMap, filter, parts, featureExtractor, lbExtractor,
             new SHA256UniformMapper<>(), cv);
     }
@@ -174,20 +255,20 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> {
     /**
      * Computes cross-validated metrics.
      *
-     * @param trainer Trainer of the model.
-     * @param scoreCalculator Base score calculator.
-     * @param upstreamMap Map with {@code upstream} data.
-     * @param filter Base {@code upstream} data filter.
-     * @param parts Number of partitions.
+     * @param trainer          Trainer of the model.
+     * @param scoreCalculator  Base score calculator.
+     * @param upstreamMap      Map with {@code upstream} data.
+     * @param filter           Base {@code upstream} data filter.
+     * @param parts            Number of partitions.
      * @param featureExtractor Feature extractor.
-     * @param lbExtractor Label extractor.
-     * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1).
-     * @param cv Number of folds.
+     * @param lbExtractor      Label extractor.
+     * @param mapper           Mapper used to map a key-value pair to a point on the segment (0, 1).
+     * @param cv               Number of folds.
      * @return Array of scores of the estimator for each run of the cross validation.
      */
     public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Map<K, V> upstreamMap,
-                          IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, Vector> featureExtractor,
-                          IgniteBiFunction<K, V, L> lbExtractor, UniformMapper<K, V> mapper, int cv) {
+        IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor, UniformMapper<K, V> mapper, int cv) {
         return score(
             trainer,
             predicate -> new LocalDatasetBuilder<>(
@@ -213,21 +294,21 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> {
     /**
      * Computes cross-validated metrics.
      *
-     * @param trainer Trainer of the model.
+     * @param trainer                Trainer of the model.
      * @param datasetBuilderSupplier Dataset builder supplier.
-     * @param testDataIterSupplier Test data iterator supplier.
-     * @param featureExtractor Feature extractor.
-     * @param lbExtractor Label extractor.
-     * @param scoreCalculator Base score calculator.
-     * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1).
-     * @param cv Number of folds.
+     * @param testDataIterSupplier   Test data iterator supplier.
+     * @param featureExtractor       Feature extractor.
+     * @param lbExtractor            Label extractor.
+     * @param scoreCalculator        Base score calculator.
+     * @param mapper                 Mapper used to map a key-value pair to a point on the segment (0, 1).
+     * @param cv                     Number of folds.
      * @return Array of scores of the estimator for each run of the cross validation.
      */
     private double[] score(DatasetTrainer<M, L> trainer, Function<IgniteBiPredicate<K, V>,
         DatasetBuilder<K, V>> datasetBuilderSupplier,
-                           BiFunction<IgniteBiPredicate<K, V>, M, LabelPairCursor<L>> testDataIterSupplier,
-                           IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor,
-                           Metric<L> scoreCalculator, UniformMapper<K, V> mapper, int cv) {
+        BiFunction<IgniteBiPredicate<K, V>, M, LabelPairCursor<L>> testDataIterSupplier,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor,
+        Metric<L> scoreCalculator, UniformMapper<K, V> mapper, int cv) {
 
         double[] scores = new double[cv];
 
@@ -246,8 +327,7 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> {
 
             try (LabelPairCursor<L> cursor = testDataIterSupplier.apply(trainSetFilter, mdl)) {
                 scores[i] = scoreCalculator.score(cursor.iterator());
-            }
-            catch (Exception e) {
+            } catch (Exception e) {
                 throw new RuntimeException(e);
             }
         }

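A note for implementors of custom trainers: the new overload finds hyperparameter setters reflectively by naming convention, mapping a parameter name such as "maxDeep" to a chained setter "withMaxDeep" taking a Double. A minimal, self-contained sketch of that lookup (the Trainer class below is hypothetical, for illustration only):

    import java.lang.reflect.Method;

    public class SetterLookupDemo {
        /** Hypothetical trainer exposing a chained setter in the expected naming scheme. */
        static class Trainer {
            double maxDeep;

            public Trainer withMaxDeep(Double v) {
                this.maxDeep = v;
                return this;
            }
        }

        public static void main(String[] args) throws Exception {
            Trainer trainer = new Trainer();
            String paramName = "maxDeep";

            // "maxDeep" -> "withMaxDeep", the same derivation as in the score(..., ParamGrid) method above.
            String mtdName = "with" + paramName.substring(0, 1).toUpperCase() + paramName.substring(1);

            // Scan all declared methods since the parameter types are not known in advance.
            for (Method method : trainer.getClass().getDeclaredMethods()) {
                if (method.getName().equals(mtdName))
                    method.invoke(trainer, 7.0); // 7.0 is autoboxed to the expected Double.
            }

            System.out.println(trainer.maxDeep); // Prints 7.0.
        }
    }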
http://git-wip-us.apache.org/repos/asf/ignite/blob/5cddf920/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationResult.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationResult.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationResult.java
new file mode 100644
index 0000000..55c20be
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationResult.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.cv;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Represents the result of a cross-validation procedure:
+ * the scores and the values of the hyper parameters associated with them.
+ */
+public class CrossValidationResult {
+    /** Best hyper params. */
+    private Map<String, Double> bestHyperParams;
+
+    /** Best score. */
+    private double[] bestScore;
+
+    /**
+     * Scoring board.
+     * The key is a map of hyper parameters and their values;
+     * the value is the score result associated with the set of hyper parameters in the key.
+     */
+    private Map<Map<String, Double>, double[]> scoringBoard = new HashMap<>();
+
+    /**
+     * Default constructor.
+     */
+    CrossValidationResult() {
+    }
+
+    /**
+     * Gets the best value for the specific hyper parameter.
+     *
+     * @param hyperParamName Hyper parameter name.
+     * @return The value.
+     */
+    public double getBest(String hyperParamName) {
+        return bestHyperParams.get(hyperParamName);
+    }
+
+    /**
+     * Gets the scores achieved with the best set of hyper parameters.
+     *
+     * @return The value.
+     */
+    public double[] getBestScore() {
+        return bestScore;
+    }
+
+    /**
+     * Adds local scores and associated parameter set to the scoring board.
+     *
+     * @param locScores The scores.
+     * @param paramMap  The parameter set associated with the given scores.
+     */
+    void addScores(double[] locScores, Map<String, Double> paramMap) {
+        scoringBoard.put(paramMap, locScores);
+    }
+
+    /**
+     * Gets the average value of the best score array.
+     *
+     * @return The value.
+     */
+    public double getBestAvgScore() {
+        if (bestScore == null)
+            return Double.MIN_VALUE;
+        return Arrays.stream(bestScore).average().orElse(Double.MIN_VALUE);
+    }
+
+    /**
+     * Helper method in cross-validation process.
+     *
+     * @param bestScore The best score.
+     */
+    void setBestScore(double[] bestScore) {
+        this.bestScore = bestScore;
+    }
+
+    /**
+     * Helper method in cross-validation process.
+     *
+     * @param bestHyperParams The best hyper parameters.
+     */
+    void setBestHyperParams(Map<String, Double> bestHyperParams) {
+        this.bestHyperParams = bestHyperParams;
+    }
+
+    /**
+     * Gets the Scoring Board.
+     *
+     * The key is a map of hyper parameters and their values;
+     * the value is the score result associated with the set of hyper parameters in the key.
+     *
+     * @return The Scoring Board.
+     */
+    public Map<Map<String, Double>, double[]> getScoringBoard() {
+        return scoringBoard;
+    }
+
+    /**
+     * Gets the best hyper parameters set.
+     *
+     * @return The value.
+     */
+    public Map<String, Double> getBestHyperParams() {
+        return bestHyperParams;
+    }
+
+    /** {@inheritDoc} */
+    @Override public String toString() {
+        return "CrossValidationResult{" +
+            "bestHyperParams=" + bestHyperParams +
+            ", bestScore=" + Arrays.toString(bestScore) +
+            '}';
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/5cddf920/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParamGrid.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParamGrid.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParamGrid.java
new file mode 100644
index 0000000..3279d93
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParamGrid.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.paramgrid;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Keeps the grid of parameters.
+ */
+public class ParamGrid {
+    /** Parameter values by parameter index. */
+    private Map<Integer, Double[]> paramValuesByParamIndex = new HashMap<>();
+
+    /** Parameter names by parameter index. */
+    private Map<Integer, String> paramNamesByParamIndex = new HashMap<>();
+
+    /** Parameter counter. */
+    private int paramCntr;
+
+    /** */
+    public Map<Integer, Double[]> getParamValuesByParamIndex() {
+        return paramValuesByParamIndex;
+    }
+
+    /**
+     * Adds a grid for the specific hyper parameter.
+     *
+     * @param paramName The parameter name.
+     * @param params The array of the given hyper parameter values.
+     * @return The updated ParamGrid.
+     */
+    public ParamGrid addHyperParam(String paramName, Double[] params) {
+        paramValuesByParamIndex.put(paramCntr, params);
+        paramNamesByParamIndex.put(paramCntr, paramName);
+        paramCntr++;
+        return this;
+    }
+
+    /** */
+    public String getParamNameByIndex(int idx) {
+        return paramNamesByParamIndex.get(idx);
+    }
+}

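ParamGrid indexes parameters in insertion order: the first addHyperParam call registers parameter 0, the next parameter 1, and so on. A short sketch of how the accessors line up (assuming the class above on the classpath):

    ParamGrid grid = new ParamGrid()
        .addHyperParam("maxDeep", new Double[]{1.0, 5.0})               // Registered under index 0.
        .addHyperParam("minImpurityDecrease", new Double[]{0.0, 0.5});  // Registered under index 1.

    System.out.println(grid.getParamNameByIndex(0));                  // Prints "maxDeep".
    System.out.println(grid.getParamValuesByParamIndex().get(1)[1]);  // Prints 0.5.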
http://git-wip-us.apache.org/repos/asf/ignite/blob/5cddf920/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParameterSetGenerator.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParameterSetGenerator.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParameterSetGenerator.java
new file mode 100644
index 0000000..b6c86a1
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParameterSetGenerator.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.paramgrid;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Generates tuples of hyper parameter values from the given map.
+ *
+ * In the given map, keys are indices of hyper parameters
+ * and values are arrays of values for the hyper parameter at that index.
+ */
+public class ParameterSetGenerator {
+    /** Size of the parameter vector. Overwritten in the constructor with the map size. */
+    private int sizeOfParamVector = 100;
+
+    /** Params. */
+    private List<Double[]> params = new ArrayList<>();
+
+    /** The given map of hyper parameters and its values. */
+    private Map<Integer, Double[]> map;
+
+    /**
+     * Creates an instance of the generator.
+     *
+     * @param map In the given map, keys are indices of hyper parameters
+     * and values are arrays of values for the hyper parameter at that index.
+     */
+    public ParameterSetGenerator(Map<Integer, Double[]> map) {
+        assert map != null;
+        assert !map.isEmpty();
+
+        this.map = map;
+        this.sizeOfParamVector = map.size();
+    }
+
+    /**
+     * Returns the list of tuples presented as arrays.
+     */
+    public List<Double[]> generate() {
+
+        Double[] nextPnt = new Double[sizeOfParamVector];
+
+        traverseTree(map, nextPnt, -1);
+
+        return params;
+    }
+
+    /**
+     * Traverses the tree at the current level and starts the traversal of its children.
+     *
+     * @param map The current state of the data.
+     * @param nextPnt Next point.
+     * @param dimensionNum Dimension number.
+     */
+    private void traverseTree(Map<Integer, Double[]> map, Double[] nextPnt, int dimensionNum) {
+        dimensionNum++;
+
+        if (dimensionNum == sizeOfParamVector) {
+            Double[] paramSet = Arrays.copyOf(nextPnt, sizeOfParamVector);
+            params.add(paramSet);
+            return;
+        }
+
+        Double[] valuesOfCurrDimension = map.get(dimensionNum);
+
+        for (Double specificValue : valuesOfCurrDimension) {
+            nextPnt[dimensionNum] = specificValue;
+            traverseTree(map, nextPnt, dimensionNum);
+        }
+    }
+}

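The generator walks the value arrays depth-first and emits the full Cartesian product, so the number of tuples is the product of the array lengths (as the test below checks: 2 * 4 * 2 * 1 = 16). A small sketch, assuming imports from java.util and the class above:

    Map<Integer, Double[]> grid = new TreeMap<>();
    grid.put(0, new Double[]{1.0, 2.0});      // Two values for parameter 0.
    grid.put(1, new Double[]{0.1, 0.2, 0.3}); // Three values for parameter 1.

    List<Double[]> tuples = new ParameterSetGenerator(grid).generate();

    System.out.println(tuples.size()); // Prints 6 = 2 * 3.
    for (Double[] tuple : tuples)
        System.out.println(Arrays.toString(tuple)); // [1.0, 0.1], [1.0, 0.2], ..., [2.0, 0.3]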
http://git-wip-us.apache.org/repos/asf/ignite/blob/5cddf920/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/package-info.java
new file mode 100644
index 0000000..277083f
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Root package for the parameter grid.
+ */
+package org.apache.ignite.ml.selection.paramgrid;

http://git-wip-us.apache.org/repos/asf/ignite/blob/5cddf920/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
index c94d2dd..c1e3abf 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
@@ -42,13 +42,13 @@ import org.apache.ignite.ml.tree.leaf.DecisionTreeLeafBuilder;
  */
 public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends DatasetTrainer<DecisionTreeNode, Double> {
     /** Max tree deep. */
-    private final int maxDeep;
+    int maxDeep;
 
     /** Min impurity decrease. */
-    private final double minImpurityDecrease;
+    double minImpurityDecrease;
 
     /** Step function compressor. */
-    private final StepFunctionCompressor<T> compressor;
+    StepFunctionCompressor<T> compressor;
 
     /** Decision tree leaf builder. */
     private final DecisionTreeLeafBuilder decisionTreeLeafBuilder;

http://git-wip-us.apache.org/repos/asf/ignite/blob/5cddf920/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
index ce75190..71e387f 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
@@ -46,6 +46,14 @@ public class DecisionTreeClassificationTrainer extends DecisionTree<GiniImpurity
     }
 
     /**
+     * Constructs a new decision tree classifier with the default impurity function compressor
+     * and defaults maxDeep = 5 and minImpurityDecrease = 0.
+     */
+    public DecisionTreeClassificationTrainer() {
+        this(5, 0, null);
+    }
+
+    /**
      * Constructs a new instance of decision tree classifier.
      *
      * @param maxDeep Max tree deep.
@@ -56,6 +64,36 @@ public class DecisionTreeClassificationTrainer extends DecisionTree<GiniImpurity
         super(maxDeep, minImpurityDecrease, compressor, new MostCommonDecisionTreeLeafBuilder());
     }
 
+    /**
+     * Sets the max tree depth.
+     *
+     * @param maxDeep The parameter value.
+     * @return Trainer with the new maxDeep parameter value.
+     */
+    public DecisionTreeClassificationTrainer withMaxDeep(Double maxDeep) {
+        this.maxDeep = maxDeep.intValue();
+        return this;
+    }
+
+    /**
+     * Sets the min impurity decrease.
+     *
+     * @param minImpurityDecrease The parameter value.
+     * @return Trainer with the new minImpurityDecrease parameter value.
+     */
+    public DecisionTreeClassificationTrainer withMinImpurityDecrease(Double minImpurityDecrease) {
+        this.minImpurityDecrease = minImpurityDecrease;
+        return this;
+    }
+
+    /**
+     * Sets the step function compressor.
+     *
+     * @param compressor The parameter value.
+     * @return Trainer with the new compressor parameter value.
+     */
+    public DecisionTreeClassificationTrainer withCompressor(StepFunctionCompressor compressor) {
+        this.compressor = compressor;
+        return this;
+    }
+
     /** {@inheritDoc} */
     @Override ImpurityMeasureCalculator<GiniImpurityMeasure> getImpurityMeasureCalculator(
         Dataset<EmptyContext, DecisionTreeData> dataset) {

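With the new no-arg constructor and chained setters a trainer can now be configured fluently; the values below simply restate the documented defaults (maxDeep = 5, minImpurityDecrease = 0):

    DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer()
        .withMaxDeep(5.0)              // Stored internally as an int.
        .withMinImpurityDecrease(0.0);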
http://git-wip-us.apache.org/repos/asf/ignite/blob/5cddf920/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java
index cc69074..3adae79 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java
@@ -18,6 +18,7 @@
 package org.apache.ignite.ml.selection;
 
 import org.apache.ignite.ml.selection.cv.CrossValidationTest;
+import org.apache.ignite.ml.selection.paramgrid.ParameterSetGeneratorTest;
 import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursorTest;
 import org.apache.ignite.ml.selection.scoring.cursor.LocalLabelPairCursorTest;
 import org.apache.ignite.ml.selection.scoring.metric.AccuracyTest;
@@ -35,6 +36,7 @@ import org.junit.runners.Suite;
 @RunWith(Suite.class)
 @Suite.SuiteClasses({
     CrossValidationTest.class,
+    ParameterSetGeneratorTest.class,
     CacheBasedLabelPairCursorTest.class,
     LocalLabelPairCursorTest.class,
     AccuracyTest.class,

http://git-wip-us.apache.org/repos/asf/ignite/blob/5cddf920/modules/ml/src/test/java/org/apache/ignite/ml/selection/paramgrid/ParameterSetGeneratorTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/paramgrid/ParameterSetGeneratorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/paramgrid/ParameterSetGeneratorTest.java
new file mode 100644
index 0000000..8e36024
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/paramgrid/ParameterSetGeneratorTest.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.paramgrid;
+
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link ParameterSetGenerator}.
+ */
+public class ParameterSetGeneratorTest {
+    /** */
+    @Test
+    public void testParamSetGenerator() {
+        Map<Integer, Double[]> map = new TreeMap<>();
+        map.put(0, new Double[]{1.1, 2.1});
+        map.put(1, new Double[]{1.2, 2.2, 3.2, 4.2});
+        map.put(2, new Double[]{1.3, 2.3});
+        map.put(3, new Double[]{1.4});
+
+        List<Double[]> res = new ParameterSetGenerator(map).generate();
+        assertEquals(16, res.size());
+    }
+
+    /** */
+    @Test(expected = java.lang.AssertionError.class)
+    public void testParamSetGeneratorWithEmptyMap() {
+        Map<Integer, Double[]> map = new TreeMap<>();
+        new ParameterSetGenerator(map).generate();
+    }
+
+    /** */
+    @Test(expected = java.lang.AssertionError.class)
+    public void testNullHandling() {
+        new ParameterSetGenerator(null).generate();
+    }
+}
