IGNITE-9482: [ML] Refactor all trainers' setters to withFieldName format for meta-algorithms
this closes #4699 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/b10ba044 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/b10ba044 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/b10ba044 Branch: refs/heads/ignite-5960 Commit: b10ba044d2d4357738c70f924b94e2e6a50c5f20 Parents: 0a23658 Author: zaleslaw <zaleslaw....@gmail.com> Authored: Fri Sep 7 13:31:58 2018 +0300 Committer: Yury Babak <yba...@gridgain.com> Committed: Fri Sep 7 13:31:58 2018 +0300 ---------------------------------------------------------------------- .../clustering/KMeansClusterizationExample.java | 4 +- .../RandomForestClassificationExample.java | 2 +- .../RandomForestRegressionExample.java | 2 +- .../clustering/kmeans/ClusterizationModel.java | 4 +- .../ml/clustering/kmeans/KMeansModel.java | 4 +- .../ml/clustering/kmeans/KMeansTrainer.java | 8 +- .../ml/knn/ann/ANNClassificationTrainer.java | 4 +- .../org/apache/ignite/ml/nn/MLPTrainer.java | 165 ++++++++++++++++++- .../binarization/BinarizationPreprocessor.java | 4 +- .../binarization/BinarizationTrainer.java | 9 +- .../preprocessing/encoding/EncoderTrainer.java | 11 ++ .../linear/LinearRegressionSGDTrainer.java | 58 ++++++- .../binomial/LogisticRegressionSGDTrainer.java | 107 +++++++++++- .../LogRegressionMultiClassTrainer.java | 26 +-- .../ignite/ml/selection/cv/CrossValidation.java | 2 +- .../ml/selection/paramgrid/ParamGrid.java | 14 +- .../SVMLinearBinaryClassificationTrainer.java | 22 +-- ...VMLinearMultiClassClassificationTrainer.java | 22 +-- .../org/apache/ignite/ml/tree/DecisionTree.java | 4 +- .../tree/DecisionTreeClassificationTrainer.java | 4 +- .../ml/tree/DecisionTreeRegressionTrainer.java | 10 +- .../GDBBinaryClassifierOnTreesTrainer.java | 73 ++++++-- .../boosting/GDBRegressionOnTreesTrainer.java | 67 +++++++- .../tree/randomforest/RandomForestTrainer.java | 14 +- .../ignite/ml/clustering/KMeansModelTest.java | 4 +- 
.../ignite/ml/clustering/KMeansTrainerTest.java | 8 +- .../ignite/ml/common/LocalModelsTest.java | 2 +- .../ml/composition/boosting/GDBTrainerTest.java | 8 +- .../ml/environment/LearningEnvironmentTest.java | 2 +- .../binarization/BinarizationTrainerTest.java | 6 +- .../logistic/LogRegMultiClassTrainerTest.java | 8 +- .../tree/DecisionTreeRegressionTrainerTest.java | 2 +- .../ml/tree/randomforest/RandomForestTest.java | 2 +- 33 files changed, 540 insertions(+), 142 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java index b96cbce..152375a 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java @@ -70,8 +70,8 @@ public class KMeansClusterizationExample { ); System.out.println(">>> KMeans centroids"); - Tracer.showAscii(mdl.centers()[0]); - Tracer.showAscii(mdl.centers()[1]); + Tracer.showAscii(mdl.getCenters()[0]); + Tracer.showAscii(mdl.getCenters()[1]); System.out.println(">>>"); System.out.println(">>> -----------------------------------"); http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java 
b/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java index aa13943..6194153 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java @@ -74,7 +74,7 @@ public class RandomForestClassificationExample { .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD) .withMaxDepth(4) .withMinImpurityDelta(0.) - .withSubsampleSize(0.3) + .withSubSampleSize(0.3) .withSeed(0); System.out.println(">>> Configured trainer: " + classifier.getClass().getSimpleName()); http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java index e2bfe8b..5f010f2 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java @@ -78,7 +78,7 @@ public class RandomForestRegressionExample { .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD) .withMaxDepth(4) .withMinImpurityDelta(0.) 
- .withSubsampleSize(0.3) + .withSubSampleSize(0.3) .withSeed(0); trainer.setEnvironment(LearningEnvironment.builder() http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/ClusterizationModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/ClusterizationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/ClusterizationModel.java index 474a463..43e1899 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/ClusterizationModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/ClusterizationModel.java @@ -22,8 +22,8 @@ import org.apache.ignite.ml.Model; /** Base interface for all clusterization models. */ public interface ClusterizationModel<P, V> extends Model<P, V> { /** Gets the clusters count. */ - public int amountOfClusters(); + public int getAmountOfClusters(); /** Get cluster centers. */ - public P[] centers(); + public P[] getCenters(); } http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansModel.java index bdfa1b6..e07f4f0 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansModel.java @@ -54,12 +54,12 @@ public class KMeansModel implements ClusterizationModel<Vector, Integer>, Export } /** Amount of centers in clusterization. */ - @Override public int amountOfClusters() { + @Override public int getAmountOfClusters() { return centers.length; } /** Get centers of clusters. 
*/ - @Override public Vector[] centers() { + @Override public Vector[] getCenters() { return Arrays.copyOf(centers, centers.length); } http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java index 2596dbc..a20d5da 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java @@ -106,7 +106,7 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> { return getLastTrainedModelOrThrowEmptyDatasetException(mdl); centers = Optional.ofNullable(mdl) - .map(KMeansModel::centers) + .map(KMeansModel::getCenters) .orElseGet(() -> initClusterCentersRandomly(dataset, k)); boolean converged = false; @@ -143,7 +143,7 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> { /** {@inheritDoc} */ @Override protected boolean checkState(KMeansModel mdl) { - return mdl.centers().length == k && mdl.distanceMeasure().equals(distance); + return mdl.getCenters().length == k && mdl.distanceMeasure().equals(distance); } /** @@ -313,7 +313,7 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> { * * @return The parameter value. */ - public int getK() { + public int getAmountOfClusters() { return k; } @@ -323,7 +323,7 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> { * @param k The parameter value. * @return Model with new amount of clusters parameter value. 
*/ - public KMeansTrainer withK(int k) { + public KMeansTrainer withAmountOfClusters(int k) { this.k = k; return this; } http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java index 3e32b67..e56a10a 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java @@ -131,7 +131,7 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass private <K, V> List<Vector> getCentroids(IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, DatasetBuilder<K, V> datasetBuilder) { KMeansTrainer trainer = new KMeansTrainer() - .withK(k) + .withAmountOfClusters(k) .withMaxIterations(maxIterations) .withSeed(seed) .withDistance(distance) @@ -143,7 +143,7 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass lbExtractor ); - return Arrays.asList(mdl.centers()); + return Arrays.asList(mdl.getCenters()); } /** */ http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java index 8f1a4cb..1cac909 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java @@ -46,25 +46,25 @@ import org.apache.ignite.ml.util.Utils; */ public class MLPTrainer<P extends Serializable> extends 
MultiLabelDatasetTrainer<MultilayerPerceptron> { /** Multilayer perceptron architecture supplier that defines layers and activators. */ - private final IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier; + private IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier; /** Loss function to be minimized during the training. */ - private final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss; + private IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss; /** Update strategy that defines how to update model parameters during the training. */ - private final UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy; + private UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy; /** Maximal number of iterations before the training will be stopped. */ - private final int maxIterations; + private int maxIterations = 100; /** Batch size (per every partition). */ - private final int batchSize; + private int batchSize = 100; /** Maximal number of local iterations before synchronization. */ - private final int locIterations; + private int locIterations = 100; /** Multilayer perceptron model initializer. */ - private final long seed; + private long seed = 1234L; /** * Constructs a new instance of multilayer perceptron trainer. 
@@ -119,14 +119,18 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) { + assert archSupplier != null; + assert loss!= null; + assert updatesStgy!= null; + try (Dataset<EmptyContext, SimpleLabeledDatasetData> dataset = datasetBuilder.build( new EmptyContextBuilder<>(), new SimpleLabeledDatasetDataBuilder<>(featureExtractor, lbExtractor) )) { MultilayerPerceptron mdl; - if (lastLearnedModel != null) { + if (lastLearnedModel != null) mdl = lastLearnedModel; - } else { + else { MLPArchitecture arch = archSupplier.apply(dataset); mdl = new MultilayerPerceptron(arch, new RandomInitializer(seed)); } @@ -205,6 +209,149 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer } } + /** + * Get the multilayer perceptron architecture supplier that defines layers and activators. + * + * @return The property value. + */ + public IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> getArchSupplier() { + return archSupplier; + } + + /** + * Set up the multilayer perceptron architecture supplier that defines layers and activators. + * + * @param archSupplier The parameter value. + * @return Model with the multilayer perceptron architecture supplier that defines layers and activators. + */ + public MLPTrainer<P> withArchSupplier( + IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier) { + this.archSupplier = archSupplier; + return this; + } + + /** + * Get the loss function to be minimized during the training. + * + * @return The property value. + */ + public IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> getLoss() { + return loss; + } + + /** + * Set up the loss function to be minimized during the training. + * + * @param loss The parameter value. 
+ * @return Model with the loss function to be minimized during the training. + */ + public MLPTrainer<P> withLoss( + IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) { + this.loss = loss; + return this; + } + + /** + * Get the update strategy that defines how to update model parameters during the training. + * + * @return The property value. + */ + public UpdatesStrategy<? super MultilayerPerceptron, P> getUpdatesStgy() { + return updatesStgy; + } + + /** + * Set up the update strategy that defines how to update model parameters during the training. + * + * @param updatesStgy The parameter value. + * @return Model with the update strategy that defines how to update model parameters during the training. + */ + public MLPTrainer<P> withUpdatesStgy( + UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy) { + this.updatesStgy = updatesStgy; + return this; + } + + /** + * Get the maximal number of iterations before the training will be stopped. + * + * @return The property value. + */ + public int getMaxIterations() { + return maxIterations; + } + + /** + * Set up the maximal number of iterations before the training will be stopped. + * + * @param maxIterations The parameter value. + * @return Model with the maximal number of iterations before the training will be stopped. + */ + public MLPTrainer<P> withMaxIterations(int maxIterations) { + this.maxIterations = maxIterations; + return this; + } + + /** + * Get the batch size (per every partition). + * + * @return The property value. + */ + public int getBatchSize() { + return batchSize; + } + + /** + * Set up the batch size (per every partition). + * + * @param batchSize The parameter value. + * @return Model with the batch size (per every partition). + */ + public MLPTrainer<P> withBatchSize(int batchSize) { + this.batchSize = batchSize; + return this; + } + + /** + * Get the maximal number of local iterations before synchronization. + * + * @return The property value. 
+ */ + public int getLocIterations() { + return locIterations; + } + + /** + * Set up the maximal number of local iterations before synchronization. + * + * @param locIterations The parameter value. + * @return Model with the maximal number of local iterations before synchronization. + */ + public MLPTrainer<P> withLocIterations(int locIterations) { + this.locIterations = locIterations; + return this; + } + + /** + * Get the multilayer perceptron model initializer. + * + * @return The property value. + */ + public long getSeed() { + return seed; + } + + /** + * Set up the multilayer perceptron model initializer. + * + * @param seed The parameter value. + * @return Model with the multilayer perceptron model initializer. + */ + public MLPTrainer<P> withSeed(long seed) { + this.seed = seed; + return this; + } + /** {@inheritDoc} */ @Override protected boolean checkState(MultilayerPerceptron mdl) { return true; http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationPreprocessor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationPreprocessor.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationPreprocessor.java index 8300820..2e1bd5c 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationPreprocessor.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationPreprocessor.java @@ -68,8 +68,8 @@ public class BinarizationPreprocessor<K, V> implements IgniteBiFunction<K, V, Ve return res; } - /** Gets the threshold parameter. */ - public double threshold() { + /** Get the threshold parameter. 
*/ + public double getThreshold() { return threshold; } } http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainer.java index 26541e0..ad8c90e 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainer.java @@ -39,15 +39,16 @@ public class BinarizationTrainer<K, V> implements PreprocessingTrainer<K, V, Vec } /** - * Gets the threshold parameter value. - * @return The parameter value. + * Get the threshold parameter value. + * + * @return The property value. */ - public double threshold() { + public double getThreshold() { return threshold; } /** - * Sets the threshold parameter value. + * Set the threshold parameter value. * * @param threshold The given value. * @return The Binarization trainer. 
http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java index f716d96..a23d642 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java @@ -222,4 +222,15 @@ public class EncoderTrainer<K, V> implements PreprocessingTrainer<K, V, Object[] this.encoderType = type; return this; } + + /** + * Sets the indices of features which should be encoded. + * + * @param handledIndices Indices of features which should be encoded. + * @return The changed trainer. + */ + public EncoderTrainer<K, V> withEncoderType(Set<Integer> handledIndices) { + this.handledIndices = handledIndices; + return this; + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java index 125ed24..4132d35 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java @@ -44,17 +44,17 @@ public class LinearRegressionSGDTrainer<P extends Serializable> extends SingleLa /** Update strategy. */ private final UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy; - /** Max number of iteration. 
*/ + /** Max amount of iterations. */ private int maxIterations = 1000; /** Batch size. */ private int batchSize = 10; - /** Number of local iterations. */ + /** Amount of local iterations. */ private int locIterations = 100; /** Seed for random generator. */ - private long seed = System.currentTimeMillis(); + private long seed = 1234L; /** * Constructs a new instance of linear regression SGD trainer. @@ -89,9 +89,12 @@ public class LinearRegressionSGDTrainer<P extends Serializable> extends SingleLa } /** {@inheritDoc} */ - @Override protected <K, V> LinearRegressionModel updateModel(LinearRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder, + @Override protected <K, V> LinearRegressionModel updateModel(LinearRegressionModel mdl, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + assert updatesStgy != null; + IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> { int cols = dataset.compute(data -> { @@ -162,7 +165,7 @@ public class LinearRegressionSGDTrainer<P extends Serializable> extends SingleLa } /** - * Set up the max number of iterations before convergence. + * Set up the max amount of iterations before convergence. * * @param maxIterations The parameter value. * @return Model with new max number of iterations before convergence parameter value. @@ -204,4 +207,49 @@ public class LinearRegressionSGDTrainer<P extends Serializable> extends SingleLa this.seed = seed; return this; } + + /** + * Get the update strategy. + * + * @return The property value. + */ + public UpdatesStrategy<? super MultilayerPerceptron, P> getUpdatesStgy() { + return updatesStgy; + } + + /** + * Get the max amount of iterations. + * + * @return The property value. + */ + public int getMaxIterations() { + return maxIterations; + } + + /** + * Get the batch size. + * + * @return The property value. 
+ */ + public int getBatchSize() { + return batchSize; + } + + /** + * Get the amount of local iterations. + * + * @return The property value. + */ + public int getLocIterations() { + return locIterations; + } + + /** + * Get the seed for random generator. + * + * @return The property value. + */ + public long getSeed() { + return seed; + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java index 839dab5..fb5d5a0 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java @@ -41,19 +41,19 @@ import org.jetbrains.annotations.NotNull; */ public class LogisticRegressionSGDTrainer<P extends Serializable> extends SingleLabelDatasetTrainer<LogisticRegressionModel> { /** Update strategy. */ - private final UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy; + private UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy; /** Max number of iteration. */ - private final int maxIterations; + private int maxIterations; /** Batch size. */ - private final int batchSize; + private int batchSize; /** Number of local iterations. */ - private final int locIterations; + private int locIterations; /** Seed for random generator. */ - private final long seed; + private long seed; /** * Constructs a new instance of linear regression SGD trainer. 
@@ -116,10 +116,11 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single IgniteBiFunction<K, V, double[]> lbExtractorWrapper = (k, v) -> new double[] {lbExtractor.apply(k, v)}; MultilayerPerceptron mlp; - if(mdl != null) { + if (mdl != null) { mlp = restoreMLPState(mdl); mlp = trainer.update(mlp, datasetBuilder, featureExtractor, lbExtractorWrapper); - } else + } + else mlp = trainer.fit(datasetBuilder, featureExtractor, lbExtractorWrapper); double[] params = mlp.parameters().getStorage().data(); @@ -136,8 +137,10 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single @NotNull private MultilayerPerceptron restoreMLPState(LogisticRegressionModel mdl) { Vector weights = mdl.weights(); double intercept = mdl.intercept(); + MLPArchitecture architecture1 = new MLPArchitecture(weights.size()); architecture1 = architecture1.withAddedLayer(1, true, Activators.SIGMOID); + MLPArchitecture architecture = architecture1; MultilayerPerceptron perceptron = new MultilayerPerceptron(architecture); @@ -145,6 +148,7 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single weights.nonZeroes().forEach(ith -> mlpState.set(ith.index(), ith.get())); mlpState.set(mlpState.size() - 1, intercept); perceptron.setParameters(mlpState); + return perceptron; } @@ -152,4 +156,93 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single @Override protected boolean checkState(LogisticRegressionModel mdl) { return true; } + + /** + * Set up the max amount of iterations before convergence. + * + * @param maxIterations The parameter value. + * @return Model with new max number of iterations before convergence parameter value. + */ + public LogisticRegressionSGDTrainer<P> withMaxIterations(int maxIterations) { + this.maxIterations = maxIterations; + return this; + } + + /** + * Set up the batchSize parameter. + * + * @param batchSize The size of learning batch. 
+ * @return Trainer with new batch size parameter value. + */ + public LogisticRegressionSGDTrainer<P> withBatchSize(int batchSize) { + this.batchSize = batchSize; + return this; + } + + /** + * Set up the amount of local iterations of SGD algorithm. + * + * @param amountOfLocIterations The parameter value. + * @return Trainer with new locIterations parameter value. + */ + public LogisticRegressionSGDTrainer<P> withLocIterations(int amountOfLocIterations) { + this.locIterations = amountOfLocIterations; + return this; + } + + /** + * Set up the random seed parameter. + * + * @param seed Seed for random generator. + * @return Trainer with new seed parameter value. + */ + public LogisticRegressionSGDTrainer<P> withSeed(long seed) { + this.seed = seed; + return this; + } + + /** + * Get the update strategy. + * + * @return The property value. + */ + public UpdatesStrategy<? super MultilayerPerceptron, P> getUpdatesStgy() { + return updatesStgy; + } + + /** + * Get the max amount of iterations. + * + * @return The property value. + */ + public int getMaxIterations() { + return maxIterations; + } + + /** + * Get the batch size. + * + * @return The property value. + */ + public int getBatchSize() { + return batchSize; + } + + /** + * Get the amount of local iterations. + * + * @return The property value. + */ + public int getLocIterations() { + return locIterations; + } + + /** + * Get the seed for random generator. + * + * @return The property value. 
+ */ + public long getSeed() { + return seed; + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java index eb44301..b9cdcc7 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java @@ -77,14 +77,14 @@ public class LogRegressionMultiClassTrainer<P extends Serializable> } /** {@inheritDoc} */ - @Override public <K, V> LogRegressionMultiClassModel updateModel(LogRegressionMultiClassModel mdl, + @Override public <K, V> LogRegressionMultiClassModel updateModel(LogRegressionMultiClassModel newMdl, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor); if(classes.isEmpty()) - return getLastTrainedModelOrThrowEmptyDatasetException(mdl); + return getLastTrainedModelOrThrowEmptyDatasetException(newMdl); LogRegressionMultiClassModel multiClsMdl = new LogRegressionMultiClassModel(); @@ -101,12 +101,12 @@ public class LogRegressionMultiClassTrainer<P extends Serializable> return 0.0; }; - LogisticRegressionModel model = Optional.ofNullable(mdl) + LogisticRegressionModel mdl = Optional.ofNullable(newMdl) .flatMap(multiClassModel -> multiClassModel.getModel(clsLb)) .map(learnedModel -> trainer.update(learnedModel, datasetBuilder, featureExtractor, lbTransformer)) .orElseGet(() -> trainer.fit(datasetBuilder, 
featureExtractor, lbTransformer)); - multiClsMdl.add(clsLb, model); + multiClsMdl.add(clsLb, mdl); }); return multiClsMdl; @@ -169,20 +169,20 @@ public class LogRegressionMultiClassTrainer<P extends Serializable> } /** - * Gets the batch size. + * Get the batch size. * * @return The parameter value. */ - public double batchSize() { + public double getBatchSize() { return batchSize; } /** - * Gets the amount of outer iterations of SGD algorithm. + * Get the amount of outer iterations of SGD algorithm. * * @return The parameter value. */ - public int amountOfIterations() { + public int getAmountOfIterations() { return amountOfIterations; } @@ -198,11 +198,11 @@ public class LogRegressionMultiClassTrainer<P extends Serializable> } /** - * Gets the amount of local iterations. + * Get the amount of local iterations. * * @return The parameter value. */ - public int amountOfLocIterations() { + public int getAmountOfLocIterations() { return amountOfLocIterations; } @@ -229,7 +229,7 @@ public class LogRegressionMultiClassTrainer<P extends Serializable> } /** - * Gets the seed for random generator. + * Get the seed for random generator. * * @return The parameter value. */ @@ -249,11 +249,11 @@ public class LogRegressionMultiClassTrainer<P extends Serializable> } /** - * Gets the update strategy.. + * Get the update strategy.. * * @return The parameter value. */ - public UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy() { + public UpdatesStrategy<? 
super MultilayerPerceptron, P> getUpdatesStgy() { return updatesStgy; } } http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java index 1ade876..ef4f30f 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java @@ -120,7 +120,7 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> { IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv, ParamGrid paramGrid) { - List<Double[]> paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIndex()).generate(); + List<Double[]> paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate(); CrossValidationResult cvRes = new CrossValidationResult(); http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParamGrid.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParamGrid.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParamGrid.java index 3279d93..f9c5bd2 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParamGrid.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/paramgrid/ParamGrid.java @@ -25,17 +25,17 @@ import java.util.Map; */ public class ParamGrid { /** Parameter values by parameter index. 
*/ - private Map<Integer, Double[]> paramValuesByParamIndex = new HashMap<>(); + private Map<Integer, Double[]> paramValuesByParamIdx = new HashMap<>(); /** Parameter names by parameter index. */ - private Map<Integer, String> paramNamesByParamIndex = new HashMap<>(); + private Map<Integer, String> paramNamesByParamIdx = new HashMap<>(); /** Parameter counter. */ private int paramCntr; /** */ - public Map<Integer, Double[]> getParamValuesByParamIndex() { - return paramValuesByParamIndex; + public Map<Integer, Double[]> getParamValuesByParamIdx() { + return paramValuesByParamIdx; } /** @@ -45,14 +45,14 @@ public class ParamGrid { * @return The updated ParamGrid. */ public ParamGrid addHyperParam(String paramName, Double[] params) { - paramValuesByParamIndex.put(paramCntr, params); - paramNamesByParamIndex.put(paramCntr, paramName); + paramValuesByParamIdx.put(paramCntr, params); + paramNamesByParamIdx.put(paramCntr, paramName); paramCntr++; return this; } /** */ public String getParamNameByIndex(int idx) { - return paramNamesByParamIndex.get(idx); + return paramNamesByParamIdx.get(idx); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java index 573df1a..8fb98d2 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java @@ -214,7 +214,7 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai double qii = v.dot(v); double newAlpha = calcNewAlpha(alpha, gradient, qii); - Vector deltaWeights = v.times(lb * (newAlpha - alpha) / 
(this.lambda() * amountOfObservation)); + Vector deltaWeights = v.times(lb * (newAlpha - alpha) / (this.getLambda() * amountOfObservation)); return new Deltas(newAlpha - alpha, deltaWeights); } @@ -233,7 +233,7 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai /** */ private double calcGradient(double lb, Vector v, Vector weights, int amountOfObservation) { double dotProduct = v.dot(weights); - return (lb * dotProduct - 1.0) * (this.lambda() * amountOfObservation); + return (lb * dotProduct - 1.0) * (this.getLambda() * amountOfObservation); } /** */ @@ -261,18 +261,18 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai } /** - * Gets the regularization lambda. + * Get the regularization lambda. * - * @return The parameter value. + * @return The property value. */ - public double lambda() { + public double getLambda() { return lambda; } /** - * Gets the amount of outer iterations of SCDA algorithm. + * Get the amount of outer iterations of SCDA algorithm. * - * @return The parameter value. + * @return The property value. */ public int getAmountOfIterations() { return amountOfIterations; @@ -290,9 +290,9 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai } /** - * Gets the amount of local iterations of SCDA algorithm. + * Get the amount of local iterations of SCDA algorithm. * - * @return The parameter value. + * @return The property value. */ public int getAmountOfLocIterations() { return amountOfLocIterations; @@ -310,9 +310,9 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai } /** - * Gets the seed number. + * Get the seed number. * - * @return The parameter value. + * @return The property value. 
*/ public long getSeed() { return seed; http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java index b77baa2..aeee178 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java @@ -82,9 +82,9 @@ public class SVMLinearMultiClassClassificationTrainer classes.forEach(clsLb -> { SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer() - .withAmountOfIterations(this.amountOfIterations()) - .withAmountOfLocIterations(this.amountOfLocIterations()) - .withLambda(this.lambda()) + .withAmountOfIterations(this.getAmountOfIterations()) + .withAmountOfLocIterations(this.getAmountOfLocIterations()) + .withLambda(this.getLambda()) .withSeed(this.seed); IgniteBiFunction<K, V, Double> lbTransformer = (k, v) -> { @@ -197,20 +197,20 @@ public class SVMLinearMultiClassClassificationTrainer } /** - * Gets the regularization lambda. + * Get the regularization lambda. * - * @return The parameter value. + * @return The property value. */ - public double lambda() { + public double getLambda() { return lambda; } /** * Gets the amount of outer iterations of SCDA algorithm. * - * @return The parameter value. + * @return The property value. */ - public int amountOfIterations() { + public int getAmountOfIterations() { return amountOfIterations; } @@ -228,9 +228,9 @@ public class SVMLinearMultiClassClassificationTrainer /** * Gets the amount of local iterations of SCDA algorithm. * - * @return The parameter value. 
+ * @return The property value. */ - public int amountOfLocIterations() { + public int getAmountOfLocIterations() { return amountOfLocIterations; } @@ -248,7 +248,7 @@ public class SVMLinearMultiClassClassificationTrainer /** * Gets the seed number. * - * @return The parameter value. + * @return The property value. */ public long getSeed() { return seed; http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java index 355048a..45774cb 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java @@ -54,7 +54,7 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset private final DecisionTreeLeafBuilder decisionTreeLeafBuilder; /** Use index structure instead of using sorting while learning. */ - protected boolean useIndex = true; + protected boolean usingIdx = true; /** * Constructs a new distributed decision tree trainer. 
@@ -77,7 +77,7 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build( new EmptyContextBuilder<>(), - new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, useIndex) + new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, usingIdx) )) { return fit(dataset); } http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java index f8fc769..91ec8e1 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java @@ -91,7 +91,7 @@ public class DecisionTreeClassificationTrainer extends DecisionTree<GiniImpurity * @return Decision tree trainer. 
*/ public DecisionTreeClassificationTrainer withUseIndex(boolean useIndex) { - this.useIndex = useIndex; + this.usingIdx = useIndex; return this; } @@ -127,6 +127,6 @@ public class DecisionTreeClassificationTrainer extends DecisionTree<GiniImpurity for (Double lb : labels) encoder.put(lb, idx++); - return new GiniImpurityMeasureCalculator(encoder, useIndex); + return new GiniImpurityMeasureCalculator(encoder, usingIdx); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java index 4c9aac9..ea57bcc 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java @@ -53,13 +53,13 @@ public class DecisionTreeRegressionTrainer extends DecisionTree<MSEImpurityMeasu } /** - * Sets useIndex parameter and returns trainer instance. + * Sets usingIdx parameter and returns trainer instance. * - * @param useIndex Use index. + * @param usingIdx Use index. * @return Decision tree trainer. 
*/ - public DecisionTreeRegressionTrainer withUseIndex(boolean useIndex) { - this.useIndex = useIndex; + public DecisionTreeRegressionTrainer withUsingIdx(boolean usingIdx) { + this.usingIdx = usingIdx; return this; } @@ -67,6 +67,6 @@ public class DecisionTreeRegressionTrainer extends DecisionTree<MSEImpurityMeasu @Override protected ImpurityMeasureCalculator<MSEImpurityMeasure> getImpurityMeasureCalculator( Dataset<EmptyContext, DecisionTreeData> dataset) { - return new MSEImpurityMeasureCalculator(useIndex); + return new MSEImpurityMeasureCalculator(usingIdx); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java index 4d87b47..b99dc2f 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java @@ -27,13 +27,13 @@ import org.jetbrains.annotations.NotNull; */ public class GDBBinaryClassifierOnTreesTrainer extends GDBBinaryClassifierTrainer { /** Max depth. */ - private final int maxDepth; + private int maxDepth; /** Min impurity decrease. */ - private final double minImpurityDecrease; + private double minImpurityDecrease; - /** Use index structure instead of using sorting while learning. */ - private boolean useIndex = true; + /** Use index structure instead of using sorting during the learning process. */ + private boolean usingIdx = true; /** * Constructs instance of GDBBinaryClassifierOnTreesTrainer. 
@@ -53,22 +53,71 @@ public class GDBBinaryClassifierOnTreesTrainer extends GDBBinaryClassifierTraine /** {@inheritDoc} */ @NotNull @Override protected DecisionTreeRegressionTrainer buildBaseModelTrainer() { - return new DecisionTreeRegressionTrainer(maxDepth, minImpurityDecrease).withUseIndex(useIndex); + return new DecisionTreeRegressionTrainer(maxDepth, minImpurityDecrease).withUsingIdx(usingIdx); + } + + /** {@inheritDoc} */ + @Override protected GDBLearningStrategy getLearningStrategy() { + return new GDBOnTreesLearningStrategy(usingIdx); } /** - * Sets useIndex parameter and returns trainer instance. + * Set useIndex parameter and returns trainer instance. * - * @param useIndex Use index. + * @param usingIdx Use index. * @return Decision tree trainer. */ - public GDBBinaryClassifierOnTreesTrainer withUseIndex(boolean useIndex) { - this.useIndex = useIndex; + public GDBBinaryClassifierOnTreesTrainer withUsingIdx(boolean usingIdx) { + this.usingIdx = usingIdx; return this; } - /** {@inheritDoc} */ - @Override protected GDBLearningStrategy getLearningStrategy() { - return new GDBOnTreesLearningStrategy(useIndex); + /** + * Get the max depth. + * + * @return The property value. + */ + public int getMaxDepth() { + return maxDepth; + } + + /** + * Set up the max depth. + * + * @param maxDepth The parameter value. + * @return Decision tree trainer. + */ + public GDBBinaryClassifierOnTreesTrainer setMaxDepth(int maxDepth) { + this.maxDepth = maxDepth; + return this; + } + + /** + * Get the min impurity decrease. + * + * @return The property value. + */ + public double getMinImpurityDecrease() { + return minImpurityDecrease; + } + + /** + * Set up the min impurity decrease. + * + * @param minImpurityDecrease The parameter value. + * @return Decision tree trainer. 
+ */ + public GDBBinaryClassifierOnTreesTrainer setMinImpurityDecrease(double minImpurityDecrease) { + this.minImpurityDecrease = minImpurityDecrease; + return this; + } + + /** + * Get the using index structure property instead of using sorting during the learning process. + * + * @return The property value. + */ + public boolean isUsingIdx() { + return usingIdx; } } http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java index e2a183c..b6c0b48 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java @@ -27,13 +27,13 @@ import org.jetbrains.annotations.NotNull; */ public class GDBRegressionOnTreesTrainer extends GDBRegressionTrainer { /** Max depth. */ - private final int maxDepth; + private int maxDepth; /** Min impurity decrease. */ - private final double minImpurityDecrease; + private double minImpurityDecrease; /** Use index structure instead of using sorting while learning. */ - private boolean useIndex = true; + private boolean usingIdx = true; /** * Constructs instance of GDBRegressionOnTreesTrainer. @@ -53,22 +53,71 @@ public class GDBRegressionOnTreesTrainer extends GDBRegressionTrainer { /** {@inheritDoc} */ @NotNull @Override protected DecisionTreeRegressionTrainer buildBaseModelTrainer() { - return new DecisionTreeRegressionTrainer(maxDepth, minImpurityDecrease).withUseIndex(useIndex); + return new DecisionTreeRegressionTrainer(maxDepth, minImpurityDecrease).withUsingIdx(usingIdx); } /** - * Sets useIndex parameter and returns trainer instance. 
+ * Set useIndex parameter and returns trainer instance. * - * @param useIndex Use index. + * @param usingIdx Use index. * @return Decision tree trainer. */ - public GDBRegressionOnTreesTrainer withUseIndex(boolean useIndex) { - this.useIndex = useIndex; + public GDBRegressionOnTreesTrainer withUsingIdx(boolean usingIdx) { + this.usingIdx = usingIdx; return this; } + /** + * Get the max depth. + * + * @return The property value. + */ + public int getMaxDepth() { + return maxDepth; + } + + /** + * Set up the max depth. + * + * @param maxDepth The parameter value. + * @return Decision tree trainer. + */ + public GDBRegressionOnTreesTrainer setMaxDepth(int maxDepth) { + this.maxDepth = maxDepth; + return this; + } + + /** + * Get the min impurity decrease. + * + * @return The property value. + */ + public double getMinImpurityDecrease() { + return minImpurityDecrease; + } + + /** + * Set up the min impurity decrease. + * + * @param minImpurityDecrease The parameter value. + * @return Decision tree trainer. + */ + public GDBRegressionOnTreesTrainer setMinImpurityDecrease(double minImpurityDecrease) { + this.minImpurityDecrease = minImpurityDecrease; + return this; + } + + /** + * Get the using index structure property instead of using sorting during the learning process. + * + * @return The property value. 
+ */ + public boolean isUsingIdx() { + return usingIdx; + } + /** {@inheritDoc} */ @Override protected GDBLearningStrategy getLearningStrategy() { - return new GDBOnTreesLearningStrategy(useIndex); + return new GDBOnTreesLearningStrategy(usingIdx); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java index 91fcf0a..c617d8d 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java @@ -76,7 +76,7 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra private int cntOfTrees = 1; /** Subsample size. */ - private double subsampleSize = 1.0; + private double subSampleSize = 1.0; /** Max depth. */ private int maxDepth = 5; @@ -88,10 +88,10 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra private List<FeatureMeta> meta; /** Features per tree. */ - private int featuresPerTree; + private int featuresPerTree = 5; /** Seed. */ - private long seed = System.currentTimeMillis(); + private long seed = 1234L; /** Random generator. 
*/ private Random random = new Random(seed); @@ -115,7 +115,7 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra List<TreeRoot> models = null; try (Dataset<EmptyContext, BootstrappedDatasetPartition> dataset = datasetBuilder.build( new EmptyContextBuilder<>(), - new BootstrappedDatasetBuilder<>(featureExtractor, lbExtractor, cntOfTrees, subsampleSize))) { + new BootstrappedDatasetBuilder<>(featureExtractor, lbExtractor, cntOfTrees, subSampleSize))) { if(!init(dataset)) return buildComposition(Collections.emptyList()); @@ -144,11 +144,11 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra } /** - * @param subsampleSize Subsample size. + * @param subSampleSize Subsample size. * @return an instance of current object with valid type in according to inheritance. */ - public T withSubsampleSize(double subsampleSize) { - this.subsampleSize = subsampleSize; + public T withSubSampleSize(double subSampleSize) { + this.subSampleSize = subSampleSize; return instance(); } http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansModelTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansModelTest.java index 03e0e6d..f71b7b3 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansModelTest.java @@ -54,7 +54,7 @@ public class KMeansModelTest { Assert.assertEquals(mdl.apply(new DenseVector(new double[]{-1.1, -1.1})), 3.0, PRECISION); Assert.assertEquals(mdl.distanceMeasure(), distanceMeasure); - Assert.assertEquals(mdl.amountOfClusters(), 4); - Assert.assertArrayEquals(mdl.centers(), centers); + Assert.assertEquals(mdl.getAmountOfClusters(), 4); + 
Assert.assertArrayEquals(mdl.getCenters(), centers); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java index 03f044a..74ff8f1 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java @@ -58,7 +58,7 @@ public class KMeansTrainerTest { @Test public void findOneClusters() { KMeansTrainer trainer = createAndCheckTrainer(); - KMeansModel knnMdl = trainer.withK(1).fit( + KMeansModel knnMdl = trainer.withAmountOfClusters(1).fit( new LocalDatasetBuilder<>(data, 2), (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[2] @@ -76,7 +76,7 @@ public class KMeansTrainerTest { @Test public void testUpdateMdl() { KMeansTrainer trainer = createAndCheckTrainer(); - KMeansModel originalMdl = trainer.withK(1).fit( + KMeansModel originalMdl = trainer.withAmountOfClusters(1).fit( new LocalDatasetBuilder<>(data, 2), (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[2] @@ -106,11 +106,11 @@ public class KMeansTrainerTest { @NotNull private KMeansTrainer createAndCheckTrainer() { KMeansTrainer trainer = new KMeansTrainer() .withDistance(new EuclideanDistance()) - .withK(10) + .withAmountOfClusters(10) .withMaxIterations(1) .withEpsilon(PRECISION) .withSeed(2); - assertEquals(10, trainer.getK()); + assertEquals(10, trainer.getAmountOfClusters()); assertEquals(2, trainer.getSeed()); assertTrue(trainer.getDistance() instanceof EuclideanDistance); return trainer; 
http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java index 9315850..ca3f0b5 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java @@ -196,7 +196,7 @@ public class LocalModelsTest { data.put(1, new double[] {1.0, 1960, 373200}); KMeansTrainer trainer = new KMeansTrainer() - .withK(1); + .withAmountOfClusters(1); return trainer.fit( new LocalDatasetBuilder<>(data, 2), http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java index 89b8c9c..4c3655b 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java @@ -57,7 +57,7 @@ public class GDBTrainerTest { } GDBTrainer trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 3, 0.0) - .withUseIndex(true); + .withUsingIdx(true); Model<Vector, Double> mdl = trainer.fit( learningSample, 1, @@ -131,7 +131,7 @@ public class GDBTrainerTest { learningSample.put(i, new double[] {xs[i], ys[i]}); GDBTrainer trainer = new GDBBinaryClassifierOnTreesTrainer(0.3, 500, 3, 0.0) - .withUseIndex(true) + .withUsingIdx(true) .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.3)); Model<Vector, Double> mdl = 
fitter.apply(trainer, learningSample); @@ -177,10 +177,10 @@ public class GDBTrainerTest { IgniteBiFunction<Integer, double[], Double> lExtr = (k, v) -> v[1]; GDBTrainer classifTrainer = new GDBBinaryClassifierOnTreesTrainer(0.3, 500, 3, 0.0) - .withUseIndex(true) + .withUsingIdx(true) .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.3)); GDBTrainer regressTrainer = new GDBRegressionOnTreesTrainer(0.3, 500, 3, 0.0) - .withUseIndex(true) + .withUsingIdx(true) .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.3)); testUpdate(learningSample, fExtr, lExtr, classifTrainer); http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java index b06fd67..7e5a079 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java @@ -45,7 +45,7 @@ public class LearningEnvironmentTest { .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD) .withMaxDepth(4) .withMinImpurityDelta(0.) 
- .withSubsampleSize(0.3) + .withSubSampleSize(0.3) .withSeed(0); LearningEnvironment environment = LearningEnvironment.builder() http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java index 857d9bd..d465e82 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java @@ -69,14 +69,14 @@ public class BinarizationTrainerTest { BinarizationTrainer<Integer, double[]> binarizationTrainer = new BinarizationTrainer<Integer, double[]>() .withThreshold(10); - assertEquals(10., binarizationTrainer.threshold(), 0); + assertEquals(10., binarizationTrainer.getThreshold(), 0); BinarizationPreprocessor<Integer, double[]> preprocessor = binarizationTrainer.fit( datasetBuilder, (k, v) -> VectorUtils.of(v) ); - assertEquals(binarizationTrainer.threshold(), preprocessor.threshold(), 0); + assertEquals(binarizationTrainer.getThreshold(), preprocessor.getThreshold(), 0); assertArrayEquals(new double[] {0, 0, 1}, preprocessor.apply(5, new double[] {1, 10, 100}).asArray(), 1e-8); } @@ -93,7 +93,7 @@ public class BinarizationTrainerTest { BinarizationTrainer<Integer, double[]> binarizationTrainer = new BinarizationTrainer<Integer, double[]>() .withThreshold(10); - assertEquals(10., binarizationTrainer.threshold(), 0); + assertEquals(10., binarizationTrainer.getThreshold(), 0); IgniteBiFunction<Integer, double[], Vector> preprocessor = binarizationTrainer.fit( data, 
http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java index f08501c..73c8842 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java @@ -61,11 +61,11 @@ public class LogRegMultiClassTrainerTest extends TrainerTest { .withBatchSize(100) .withSeed(123L); - Assert.assertEquals(trainer.amountOfIterations(), 1000); - Assert.assertEquals(trainer.amountOfLocIterations(), 10); - Assert.assertEquals(trainer.batchSize(), 100, PRECISION); + Assert.assertEquals(trainer.getAmountOfIterations(), 1000); + Assert.assertEquals(trainer.getAmountOfLocIterations(), 10); + Assert.assertEquals(trainer.getBatchSize(), 100, PRECISION); Assert.assertEquals(trainer.seed(), 123L); - Assert.assertEquals(trainer.updatesStgy(), stgy); + Assert.assertEquals(trainer.getUpdatesStgy(), stgy); LogRegressionMultiClassModel mdl = trainer.fit( cacheMock, http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java index 4e64925..84975a8 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java +++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java @@ -73,7 +73,7 @@ public class DecisionTreeRegressionTrainerTest { } DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0) - .withUseIndex(useIndex == 1); + .withUsingIdx(useIndex == 1); DecisionTreeNode tree = trainer.fit( data, http://git-wip-us.apache.org/repos/asf/ignite/blob/b10ba044/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestTest.java index ed474fe..9fa7f0e 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestTest.java @@ -60,7 +60,7 @@ public class RandomForestTest { .withFeaturesCountSelectionStrgy(x -> 4) .withMaxDepth(maxDepth) .withMinImpurityDelta(minImpDelta) - .withSubsampleSize(0.1); + .withSubSampleSize(0.1); /** */ @Test