Repository: ignite Updated Branches: refs/heads/master 02afb437f -> f4c18f114
http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java index 2f5d5d6..fb34c93 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java @@ -26,8 +26,10 @@ import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.environment.LearningEnvironment; +import org.apache.ignite.ml.environment.logging.MLLogger; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.jetbrains.annotations.NotNull; /** * Interface for trainers. Trainer is just a function which produces model from the data. @@ -53,6 +55,71 @@ public abstract class DatasetTrainer<M extends Model, L> { IgniteBiFunction<K, V, L> lbExtractor); /** + * Gets state of model in arguments, compare it with training parameters of trainer and if they are fit then + * trainer updates model in according to new data and return new model. In other case trains new model. + * + * @param mdl Learned model. + * @param datasetBuilder Dataset builder. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + * @return Updated model. 
+ */ + public <K,V> M update(M mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + + if(mdl != null) { + if(checkState(mdl)) { + return updateModel(mdl, datasetBuilder, featureExtractor, lbExtractor); + } else { + environment.logger(getClass()).log( + MLLogger.VerboseLevel.HIGH, + "Model cannot be updated because of initial state of " + + "it doesn't corresponds to trainer parameters" + ); + } + } + + return fit(datasetBuilder, featureExtractor, lbExtractor); + } + + /** + * @param mdl Model. + * @return true if current critical for training parameters correspond to parameters from last training. + */ + protected abstract boolean checkState(M mdl); + + /** + * Used on update phase when given dataset is empty. + * If last trained model exist then method returns it. In other case throws IllegalArgumentException. + * + * @param lastTrainedMdl Model. + */ + @NotNull protected M getLastTrainedModelOrThrowEmptyDatasetException(M lastTrainedMdl) { + String msg = "Cannot train model on empty dataset"; + if (lastTrainedMdl != null) { + environment.logger(getClass()).log(MLLogger.VerboseLevel.HIGH, msg); + return lastTrainedMdl; + } else + throw new EmptyDatasetException(); + } + + /** + * Gets state of model in arguments, update in according to new data and return new model. + * + * @param mdl Learned model. + * @param datasetBuilder Dataset builder. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + * @return Updated model. + */ + protected abstract <K, V> M updateModel(M mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor); + + /** * Trains model based on the specified data. * * @param ignite Ignite instance. 
@@ -73,6 +140,27 @@ public abstract class DatasetTrainer<M extends Model, L> { } /** + * Gets state of model in arguments, update in according to new data and return new model. + * + * @param mdl Learned model. + * @param ignite Ignite instance. + * @param cache Ignite cache. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + * @return Updated model. + */ + public <K, V> M update(M mdl, Ignite ignite, IgniteCache<K, V> cache, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + return update( + mdl, new CacheBasedDatasetBuilder<>(ignite, cache), + featureExtractor, + lbExtractor + ); + } + + /** * Trains model based on the specified data. * * @param ignite Ignite instance. @@ -94,6 +182,28 @@ public abstract class DatasetTrainer<M extends Model, L> { } /** + * Gets state of model in arguments, update in according to new data and return new model. + * + * @param mdl Learned model. + * @param ignite Ignite instance. + * @param cache Ignite cache. + * @param filter Filter for {@code upstream} data. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + * @return Updated model. + */ + public <K, V> M update(M mdl, Ignite ignite, IgniteCache<K, V> cache, IgniteBiPredicate<K, V> filter, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + return update( + mdl, new CacheBasedDatasetBuilder<>(ignite, cache, filter), + featureExtractor, + lbExtractor + ); + } + + /** * Trains model based on the specified data. * * @param data Data. 
@@ -114,6 +224,27 @@ public abstract class DatasetTrainer<M extends Model, L> { } /** + * Wraps {@code data} into a local dataset builder and delegates to the builder-based {@code update} for {@code mdl}. + * + * @param mdl Learned model. + * @param data Data. + * @param parts Number of partitions. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + * @return Updated model. + */ + public <K, V> M update(M mdl, Map<K, V> data, int parts, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, L> lbExtractor) { + return update( + mdl, new LocalDatasetBuilder<>(data, parts), + featureExtractor, + lbExtractor + ); + } + + /** * Trains model based on the specified data. * * @param data Data. @@ -136,10 +267,45 @@ public abstract class DatasetTrainer<M extends Model, L> { } /** + * Wraps {@code data} and {@code filter} into a local dataset builder and delegates to the builder-based {@code update} for the learned model {@code mdl}. + * + * @param data Data. + * @param filter Filter for {@code upstream} data. + * @param parts Number of partitions. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + * @return Updated model. + */ + public <K, V> M update(M mdl, Map<K, V> data, IgniteBiPredicate<K, V> filter, int parts, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, L> lbExtractor) { + return update( + mdl, new LocalDatasetBuilder<>(data, filter, parts), + featureExtractor, + lbExtractor + ); + } + + /** * Sets learning Environment * @param environment Environment. */ public void setEnvironment(LearningEnvironment environment) { this.environment = environment; } + + /** */ + public static class EmptyDatasetException extends IllegalArgumentException { + /** Serial version uid. 
*/ + private static final long serialVersionUID = 6914650522523293521L; + + /** + * Constructs an instance of EmptyDatasetException. + */ + public EmptyDatasetException() { + super("Cannot train model on empty dataset"); + } + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java index de8994a..355048a 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java @@ -86,6 +86,29 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset } } + /** + * Trains new model based on dataset because there is no valid approach to update decision trees. + * + * @param mdl Learned model. + * @param datasetBuilder Dataset builder. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + * @return New model based on new dataset. 
+ */ + @Override public <K, V> DecisionTreeNode updateModel(DecisionTreeNode mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + + return fit(datasetBuilder, featureExtractor, lbExtractor); + } + + /** {@inheritDoc} */ + @Override protected boolean checkState(DecisionTreeNode mdl) { + return true; + } + + /** */ public <K,V> DecisionTreeNode fit(Dataset<EmptyContext, DecisionTreeData> dataset) { return split(dataset, e -> true, 0, getImpurityMeasureCalculator(dataset)); } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java index 559dfff..7832584 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java @@ -64,8 +64,9 @@ public class RandomForestClassifierTrainer * This id can be used as index in arrays or lists. * * @param dataset Dataset. + * @return true if initialization was done. 
*/ - @Override protected void init(Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) { + @Override protected boolean init(Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) { Set<Double> uniqLabels = dataset.compute( x -> { Set<Double> labels = new HashSet<>(); @@ -85,11 +86,14 @@ public class RandomForestClassifierTrainer } ); + if(uniqLabels == null) + return false; + int i = 0; for (Double label : uniqLabels) lblMapping.put(label, i++); - super.init(dataset); + return super.init(dataset); } /** {@inheritDoc} */ http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java index cb25aa3..91fcf0a 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java @@ -30,6 +30,7 @@ import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.apache.ignite.ml.Model; import org.apache.ignite.ml.composition.ModelsComposition; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; @@ -116,7 +117,8 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra new EmptyContextBuilder<>(), new BootstrappedDatasetBuilder<>(featureExtractor, lbExtractor, cntOfTrees, subsampleSize))) { - init(dataset); + if(!init(dataset)) + return buildComposition(Collections.emptyList()); models = fit(dataset); } catch (Exception e) { @@ -202,7 +204,8 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra * * @param dataset 
Dataset. */ - protected void init(Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) { + protected boolean init(Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) { + return true; } /** @@ -215,6 +218,8 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra Queue<TreeNode> treesQueue = createRootsQueue(); ArrayList<TreeRoot> roots = initTrees(treesQueue); Map<Integer, BucketMeta> histMeta = computeHistogramMeta(meta, dataset); + if(histMeta.isEmpty()) + return Collections.emptyList(); ImpurityHistogramsComputer<S> histogramsComputer = createImpurityHistogramsComputer(); while (!treesQueue.isEmpty()) { @@ -232,6 +237,23 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra return roots; } + /** {@inheritDoc} */ + @Override protected boolean checkState(ModelsComposition mdl) { + ModelsComposition fakeComposition = buildComposition(Collections.emptyList()); + return mdl.getPredictionsAggregator().getClass() == fakeComposition.getPredictionsAggregator().getClass(); + } + + /** {@inheritDoc} */ + @Override protected <K, V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + + ArrayList<Model<Vector, Double>> oldModels = new ArrayList<>(mdl.getModels()); + ModelsComposition newModels = fit(datasetBuilder, featureExtractor, lbExtractor); + oldModels.addAll(newModels.getModels()); + + return new ModelsComposition(oldModels, mdl.getPredictionsAggregator()); + } + /** * Split node with NodeId if need. 
* @@ -302,6 +324,8 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra List<NormalDistributionStatistics> stats = new NormalDistributionStatisticsComputer() .computeStatistics(meta, dataset); + if(stats == null) + return Collections.emptyMap(); Map<Integer, BucketMeta> bucketsMeta = new HashMap<>(); for (int i = 0; i < stats.size(); i++) { http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java index aae5af1..03f044a 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java @@ -27,6 +27,7 @@ import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; +import org.jetbrains.annotations.NotNull; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -39,19 +40,70 @@ public class KMeansTrainerTest { /** Precision in test checks. */ private static final double PRECISION = 1e-2; + /** Data. 
*/ + private static final Map<Integer, double[]> data = new HashMap<>(); + + static { + data.put(0, new double[] {1.0, 1.0, 1.0}); + data.put(1, new double[] {1.0, 2.0, 1.0}); + data.put(2, new double[] {2.0, 1.0, 1.0}); + data.put(3, new double[] {-1.0, -1.0, 2.0}); + data.put(4, new double[] {-1.0, -2.0, 2.0}); + data.put(5, new double[] {-2.0, -1.0, 2.0}); + } + /** * A few points, one cluster, one iteration */ @Test public void findOneClusters() { - Map<Integer, double[]> data = new HashMap<>(); - data.put(0, new double[]{1.0, 1.0, 1.0}); - data.put(1, new double[]{1.0, 2.0, 1.0}); - data.put(2, new double[]{2.0, 1.0, 1.0}); - data.put(3, new double[]{-1.0, -1.0, 2.0}); - data.put(4, new double[]{-1.0, -2.0, 2.0}); - data.put(5, new double[]{-2.0, -1.0, 2.0}); + KMeansTrainer trainer = createAndCheckTrainer(); + KMeansModel knnMdl = trainer.withK(1).fit( + new LocalDatasetBuilder<>(data, 2), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[2] + ); + + Vector firstVector = new DenseVector(new double[] {2.0, 2.0}); + assertEquals(knnMdl.apply(firstVector), 0.0, PRECISION); + Vector secondVector = new DenseVector(new double[] {-2.0, -2.0}); + assertEquals(knnMdl.apply(secondVector), 0.0, PRECISION); + assertEquals(trainer.getMaxIterations(), 1); + assertEquals(trainer.getEpsilon(), PRECISION, PRECISION); + } + /** */ + @Test + public void testUpdateMdl() { + KMeansTrainer trainer = createAndCheckTrainer(); + KMeansModel originalMdl = trainer.withK(1).fit( + new LocalDatasetBuilder<>(data, 2), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[2] + ); + KMeansModel updatedMdlOnSameDataset = trainer.update( + originalMdl, + new LocalDatasetBuilder<>(data, 2), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[2] + ); + KMeansModel updatedMdlOnEmptyDataset = trainer.update( + originalMdl, + new LocalDatasetBuilder<>(new HashMap<Integer, double[]>(), 2), + (k, v) -> 
VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[2] + ); + + Vector firstVector = new DenseVector(new double[] {2.0, 2.0}); + Vector secondVector = new DenseVector(new double[] {-2.0, -2.0}); + assertEquals(originalMdl.apply(firstVector), updatedMdlOnSameDataset.apply(firstVector), PRECISION); + assertEquals(originalMdl.apply(secondVector), updatedMdlOnSameDataset.apply(secondVector), PRECISION); + assertEquals(originalMdl.apply(firstVector), updatedMdlOnEmptyDataset.apply(firstVector), PRECISION); + assertEquals(originalMdl.apply(secondVector), updatedMdlOnEmptyDataset.apply(secondVector), PRECISION); + } + + /** */ + @NotNull private KMeansTrainer createAndCheckTrainer() { KMeansTrainer trainer = new KMeansTrainer() .withDistance(new EuclideanDistance()) .withK(10) @@ -61,20 +113,6 @@ public class KMeansTrainerTest { assertEquals(10, trainer.getK()); assertEquals(2, trainer.getSeed()); assertTrue(trainer.getDistance() instanceof EuclideanDistance); - - KMeansModel knnMdl = trainer - .withK(1) - .fit( - new LocalDatasetBuilder<>(data, 2), - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), - (k, v) -> v[2] - ); - - Vector firstVector = new DenseVector(new double[]{2.0, 2.0}); - assertEquals(knnMdl.apply(firstVector), 0.0, PRECISION); - Vector secondVector = new DenseVector(new double[]{-2.0, -2.0}); - assertEquals(knnMdl.apply(secondVector), 0.0, PRECISION); - assertEquals(trainer.getMaxIterations(), 1); - assertEquals(trainer.getEpsilon(), PRECISION, PRECISION); + return trainer; } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java index acf28e9..745eac9 100644 --- 
a/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java @@ -22,6 +22,7 @@ import java.util.Set; import org.apache.ignite.ml.clustering.kmeans.KMeansModel; import org.apache.ignite.ml.clustering.kmeans.KMeansModelFormat; import org.apache.ignite.ml.knn.ann.ANNClassificationModel; +import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer; import org.apache.ignite.ml.knn.ann.ANNModelFormat; import org.apache.ignite.ml.knn.classification.KNNClassificationModel; import org.apache.ignite.ml.knn.classification.KNNModelFormat; @@ -103,11 +104,11 @@ public class CollectionsTest { test(new SVMLinearBinaryClassificationModel(null, 1.0), new SVMLinearBinaryClassificationModel(null, 0.5)); - test(new ANNClassificationModel(new LabeledVectorSet<>()), - new ANNClassificationModel(new LabeledVectorSet<>(1, 1, true))); + test(new ANNClassificationModel(new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat()), + new ANNClassificationModel(new LabeledVectorSet<>(1, 1, true), new ANNClassificationTrainer.CentroidStat())); - test(new ANNModelFormat(1, new ManhattanDistance(), NNStrategy.SIMPLE, new LabeledVectorSet<>()), - new ANNModelFormat(2, new ManhattanDistance(), NNStrategy.SIMPLE, new LabeledVectorSet<>())); + test(new ANNModelFormat(1, new ManhattanDistance(), NNStrategy.SIMPLE, new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat()), + new ANNModelFormat(2, new ManhattanDistance(), NNStrategy.SIMPLE, new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat())); } /** Test classes that have all instances equal (eg, metrics). 
*/ http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java index 17d9c1a..9315850 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java @@ -32,6 +32,7 @@ import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.knn.NNClassificationModel; import org.apache.ignite.ml.knn.ann.ANNClassificationModel; +import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer; import org.apache.ignite.ml.knn.ann.ANNModelFormat; import org.apache.ignite.ml.knn.ann.ProbableLabel; import org.apache.ignite.ml.knn.classification.KNNClassificationModel; @@ -237,7 +238,7 @@ public class LocalModelsTest { executeModelTest(mdlFilePath -> { final LabeledVectorSet<ProbableLabel, LabeledVector> centers = new LabeledVectorSet<>(); - NNClassificationModel mdl = new ANNClassificationModel(centers) + NNClassificationModel mdl = new ANNClassificationModel(centers, new ANNClassificationTrainer.CentroidStat()) .withK(4) .withDistanceMeasure(new ManhattanDistance()) .withStrategy(NNStrategy.WEIGHTED); @@ -250,7 +251,7 @@ public class LocalModelsTest { Assert.assertNotNull(load); - NNClassificationModel importedMdl = new ANNClassificationModel(load.getCandidates()) + NNClassificationModel importedMdl = new ANNClassificationModel(load.getCandidates(), new ANNClassificationTrainer.CentroidStat()) .withK(load.getK()) .withDistanceMeasure(load.getDistanceMeasure()) .withStrategy(load.getStgy()); 
http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java index 4452668..3e340f6 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java @@ -54,7 +54,7 @@ public class GDBTrainerTest { learningSample.put(i, new double[] {xs[i], ys[i]}); } - DatasetTrainer<Model<Vector, Double>, Double> trainer + DatasetTrainer<ModelsComposition, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 3, 0.0).withUseIndex(true); Model<Vector, Double> mdl = trainer.fit( http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java index 7289b1d..d8fb620 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java @@ -26,6 +26,7 @@ import org.apache.ignite.ml.knn.ann.ANNClassificationModel; import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer; import org.apache.ignite.ml.knn.classification.NNStrategy; import org.apache.ignite.ml.math.distances.EuclideanDistance; +import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Assert; import org.junit.Test; @@ -68,4 +69,47 @@ public class 
ANNClassificationTest extends TrainerTest { Assert.assertTrue(mdl.toString(true).contains(NNStrategy.SIMPLE.name())); Assert.assertTrue(mdl.toString(false).contains(NNStrategy.SIMPLE.name())); } + + /** Checks that updating on the same or an empty dataset keeps predictions; extractors match {@code fit}: label {@code v[0]}, features {@code v[1..]}. */ + @Test + public void testUpdate() { + Map<Integer, double[]> cacheMock = new HashMap<>(); + + for (int i = 0; i < twoClusters.length; i++) + cacheMock.put(i, twoClusters[i]); + + ANNClassificationTrainer trainer = new ANNClassificationTrainer() + .withK(10) + .withMaxIterations(10) + .withEpsilon(1e-4) + .withDistance(new EuclideanDistance()); + + ANNClassificationModel originalMdl = (ANNClassificationModel) trainer.withSeed(1234L).fit( + cacheMock, + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ).withK(3) + .withDistanceMeasure(new EuclideanDistance()) + .withStrategy(NNStrategy.SIMPLE); + + ANNClassificationModel updatedOnSameDataset = trainer.withSeed(1234L).update(originalMdl, + cacheMock, parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + ANNClassificationModel updatedOnEmptyDataset = trainer.withSeed(1234L).update(originalMdl, + new HashMap<Integer, double[]>(), parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + Vector v1 = VectorUtils.of(550, 550); + Vector v2 = VectorUtils.of(-550, -550); + TestUtils.assertEquals(originalMdl.apply(v1), updatedOnSameDataset.apply(v1), PRECISION); + TestUtils.assertEquals(originalMdl.apply(v2), updatedOnSameDataset.apply(v2), PRECISION); + TestUtils.assertEquals(originalMdl.apply(v1), updatedOnEmptyDataset.apply(v1), PRECISION); + TestUtils.assertEquals(originalMdl.apply(v2), updatedOnEmptyDataset.apply(v2), PRECISION); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java ---------------------------------------------------------------------- diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java index c5a5c1c..748123a 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java @@ -174,4 +174,43 @@ public class KNNClassificationTest { Vector vector = new DenseVector(new double[] {-1.01, -1.01}); assertEquals(knnMdl.apply(vector), 1.0); } + + /** */ + @Test + public void testUpdate() { + Map<Integer, double[]> data = new HashMap<>(); + data.put(0, new double[] {10.0, 10.0, 1.0}); + data.put(1, new double[] {10.0, 20.0, 1.0}); + data.put(2, new double[] {-1, -1, 1.0}); + data.put(3, new double[] {-2, -2, 2.0}); + data.put(4, new double[] {-1.0, -2.0, 2.0}); + data.put(5, new double[] {-2.0, -1.0, 2.0}); + + KNNClassificationTrainer trainer = new KNNClassificationTrainer(); + + KNNClassificationModel originalMdl = (KNNClassificationModel)trainer.fit( + data, + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[2] + ).withK(3) + .withDistanceMeasure(new EuclideanDistance()) + .withStrategy(NNStrategy.WEIGHTED); + + KNNClassificationModel updatedOnSameDataset = trainer.update(originalMdl, + data, parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[2] + ); + + KNNClassificationModel updatedOnEmptyDataset = trainer.update(originalMdl, + new HashMap<Integer, double[]>(), parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[2] + ); + + Vector vector = new DenseVector(new double[] {-1.01, -1.01}); + assertEquals(originalMdl.apply(vector), updatedOnSameDataset.apply(vector)); + assertEquals(originalMdl.apply(vector), updatedOnEmptyDataset.apply(vector)); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java 
---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java index 5504e1a..52ff1ec 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java @@ -35,6 +35,8 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import static junit.framework.TestCase.assertEquals; + /** * Tests for {@link KNNRegressionTrainer}. */ @@ -135,4 +137,42 @@ public class KNNRegressionTest { Assert.assertTrue(knnMdl.toString(true).contains(stgy.name())); Assert.assertTrue(knnMdl.toString(false).contains(stgy.name())); } + + /** Checks that updating on the same or an empty dataset keeps predictions; extractors match {@code fit}: label {@code v[0]}, features {@code v[1..]}. */ + @Test + public void testUpdate() { + Map<Integer, double[]> data = new HashMap<>(); + data.put(0, new double[] {11.0, 0, 0, 0, 0, 0}); + data.put(1, new double[] {12.0, 2.0, 0, 0, 0, 0}); + data.put(2, new double[] {13.0, 0, 3.0, 0, 0, 0}); + data.put(3, new double[] {14.0, 0, 0, 4.0, 0, 0}); + data.put(4, new double[] {15.0, 0, 0, 0, 5.0, 0}); + data.put(5, new double[] {16.0, 0, 0, 0, 0, 6.0}); + + KNNRegressionTrainer trainer = new KNNRegressionTrainer(); + + KNNRegressionModel originalMdl = (KNNRegressionModel) trainer.fit( + new LocalDatasetBuilder<>(data, parts), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ).withK(1) + .withDistanceMeasure(new EuclideanDistance()) + .withStrategy(NNStrategy.SIMPLE); + + KNNRegressionModel updatedOnSameDataset = trainer.update(originalMdl, + data, parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + KNNRegressionModel updatedOnEmptyDataset = trainer.update(originalMdl, + new HashMap<Integer, double[]>(), parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + Vector vector = 
new DenseVector(new double[] {0, 0, 0, 5.0, 0.0}); + assertEquals(originalMdl.apply(vector), updatedOnSameDataset.apply(vector)); + assertEquals(originalMdl.apply(vector), updatedOnEmptyDataset.apply(vector)); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java index a1d601c..6a6555e 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java @@ -29,6 +29,7 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.optimization.LossFunctions; +import org.apache.ignite.ml.optimization.SmoothParametrized; import org.apache.ignite.ml.optimization.updatecalculators.NesterovParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.NesterovUpdateCalculator; import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; @@ -154,6 +155,69 @@ public class MLPTrainerTest { TestUtils.checkIsInEpsilonNeighbourhood(new DenseVector(new double[]{0.0}), predict.getRow(0), 1E-1); } + + /** */ + @Test + public void testUpdate() { + UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>( + new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, + SimpleGDParameterUpdate::avg + ); + + Map<Integer, double[][]> xorData = new HashMap<>(); + xorData.put(0, new double[][]{{0.0, 0.0}, {0.0}}); + xorData.put(1, new double[][]{{0.0, 1.0}, {1.0}}); + xorData.put(2, new double[][]{{1.0, 0.0}, {1.0}}); + xorData.put(3, new double[][]{{1.0, 
1.0}, {0.0}}); + + MLPArchitecture arch = new MLPArchitecture(2). + withAddedLayer(10, true, Activators.RELU). + withAddedLayer(1, false, Activators.SIGMOID); + + MLPTrainer<SimpleGDParameterUpdate> trainer = new MLPTrainer<>( + arch, + LossFunctions.MSE, + updatesStgy, + 3000, + batchSize, + 50, + 123L + ); + + MultilayerPerceptron originalMdl = trainer.fit( + xorData, + parts, + (k, v) -> VectorUtils.of(v[0]), + (k, v) -> v[1] + ); + + MultilayerPerceptron updatedOnSameDS = trainer.update( + originalMdl, + xorData, + parts, + (k, v) -> VectorUtils.of(v[0]), + (k, v) -> v[1] + ); + + MultilayerPerceptron updatedOnEmptyDS = trainer.update( + originalMdl, + new HashMap<Integer, double[][]>(), + parts, + (k, v) -> VectorUtils.of(v[0]), + (k, v) -> v[1] + ); + + DenseMatrix matrix = new DenseMatrix(new double[][] { + {0.0, 0.0}, + {0.0, 1.0}, + {1.0, 0.0}, + {1.0, 1.0} + }); + + TestUtils.checkIsInEpsilonNeighbourhood(originalMdl.apply(matrix).getRow(0), updatedOnSameDS.apply(matrix).getRow(0), 1E-1); + TestUtils.checkIsInEpsilonNeighbourhood(originalMdl.apply(matrix).getRow(0), updatedOnEmptyDS.apply(matrix).getRow(0), 1E-1); + } } /** http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java index d16ae72..9c35ac7 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java @@ -101,4 +101,55 @@ public class LinearRegressionLSQRTrainerTest extends TrainerTest { assertEquals(intercept, mdl.getIntercept(), 1e-6); } + + /** Fits an LSQR linear regression on 100000 generated rows (100 random features from a seed-0 generator, with one fixed value stored in the last column and used as the label), then calls {@code update} once with the same data and once with an empty map; the weights and intercept of both updated models must equal the original model's within 1e-6.
*/ + @Test + public void testUpdate() { + Random rnd = new Random(0); + Map<Integer, double[]> data = new HashMap<>(); + double[] coef = new double[100]; + double intercept = rnd.nextDouble() * 10; + + for (int i = 0; i < 100000; i++) { + double[] x = new double[coef.length + 1]; + + for (int j = 0; j < coef.length; j++) + x[j] = rnd.nextDouble() * 10; + + x[coef.length] = intercept; + + data.put(i, x); + } + + LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer(); + + LinearRegressionModel originalModel = trainer.fit( + data, + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[coef.length] + ); + + LinearRegressionModel updatedOnSameDS = trainer.update( + originalModel, + data, + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[coef.length] + ); + + LinearRegressionModel updatedOnEmpyDS = trainer.update( + originalModel, + new HashMap<Integer, double[]>(), + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[coef.length] + ); + + assertArrayEquals(originalModel.getWeights().getStorage().data(), updatedOnSameDS.getWeights().getStorage().data(), 1e-6); + assertEquals(originalModel.getIntercept(), updatedOnSameDS.getIntercept(), 1e-6); + + assertArrayEquals(originalModel.getWeights().getStorage().data(), updatedOnEmpyDS.getWeights().getStorage().data(), 1e-6); + assertEquals(originalModel.getIntercept(), updatedOnEmpyDS.getIntercept(), 1e-6); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java index 349e712..86b0f27 100644 --- 
a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java @@ -72,4 +72,66 @@ public class LinearRegressionSGDTrainerTest extends TrainerTest { assertEquals(2.8421709430404007e-14, mdl.getIntercept(), 1e-1); } + + /** Fits an SGD linear regression (RProp update strategy, seed 0) on ten fixed rows with four features and the label in column 4, then calls {@code update} once with the same data and once with an empty map; weights and intercept must match the original model within 1.0 for the same-data update and within 1e-1 for the empty-data update. */ + @Test + public void testUpdate() { + Map<Integer, double[]> data = new HashMap<>(); + data.put(0, new double[]{-1.0915526, 1.81983527, -0.91409478, 0.70890712, -24.55724107}); + data.put(1, new double[]{-0.61072904, 0.37545517, 0.21705352, 0.09516495, -26.57226867}); + data.put(2, new double[]{0.05485406, 0.88219898, -0.80584547, 0.94668307, 61.80919728}); + data.put(3, new double[]{-0.24835094, -0.34000053, -1.69984651, -1.45902635, -161.65525991}); + data.put(4, new double[]{0.63675392, 0.31675535, 0.38837437, -1.1221971, -14.46432611}); + data.put(5, new double[]{0.14194017, 2.18158997, -0.28397346, -0.62090588, -3.2122197}); + data.put(6, new double[]{-0.53487507, 1.4454797, 0.21570443, -0.54161422, -46.5469012}); + data.put(7, new double[]{-1.58812173, -0.73216803, -2.15670676, -1.03195988, -247.23559889}); + data.put(8, new double[]{0.20702671, 0.92864654, 0.32721202, -0.09047503, 31.61484949}); + data.put(9, new double[]{-0.37890345, -0.04846179, -0.84122753, -1.14667474, -124.92598583}); + + LinearRegressionSGDTrainer<?> trainer = new LinearRegressionSGDTrainer<>(new UpdatesStrategy<>( + new RPropUpdateCalculator(), + RPropParameterUpdate::sumLocal, + RPropParameterUpdate::avg + ), 100000, 10, 100, 0L); + + LinearRegressionModel originalModel = trainer.withSeed(0).fit( + data, + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[4] + ); + + + LinearRegressionModel updatedOnSameDS = trainer.withSeed(0).update( + originalModel, + data, + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[4] + ); + + 
LinearRegressionModel updatedOnEmptyDS = trainer.withSeed(0).update( + originalModel, + new HashMap<Integer, double[]>(), + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[4] + ); + + assertArrayEquals( + originalModel.getWeights().getStorage().data(), + updatedOnSameDS.getWeights().getStorage().data(), + 1.0 + ); + + assertEquals(originalModel.getIntercept(), updatedOnSameDS.getIntercept(), 1.0); + + assertArrayEquals( + originalModel.getWeights().getStorage().data(), + updatedOnEmptyDS.getWeights().getStorage().data(), + 1e-1 + ); + + assertEquals(originalModel.getIntercept(), updatedOnEmptyDS.getIntercept(), 1e-1); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java index 1f8c5d1..f08501c 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java @@ -19,9 +19,11 @@ package org.apache.ignite.ml.regressions.logistic; import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.common.TrainerTest; +import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.optimization.SmoothParametrized; @@ -81,4 +83,60 @@ public class LogRegMultiClassTrainerTest extends TrainerTest { TestUtils.assertEquals(2, mdl.apply(VectorUtils.of(-10, -10)), 
PRECISION); TestUtils.assertEquals(3, mdl.apply(VectorUtils.of(10, -10)), PRECISION); } + + /** Fits a multi-class logistic regression on {@code fourSetsInSquareVertices} (label in column 0, features in the remaining columns), then calls {@code update} once with the same data and once with an empty map; for four probe vectors, both updated models must return the original model's prediction within {@code PRECISION}. */ + @Test + public void testUpdate() { + Map<Integer, double[]> cacheMock = new HashMap<>(); + + for (int i = 0; i < fourSetsInSquareVertices.length; i++) + cacheMock.put(i, fourSetsInSquareVertices[i]); + + LogRegressionMultiClassTrainer<?> trainer = new LogRegressionMultiClassTrainer<>() + .withUpdatesStgy(new UpdatesStrategy<>( + new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, + SimpleGDParameterUpdate::avg + )) + .withAmountOfIterations(1000) + .withAmountOfLocIterations(10) + .withBatchSize(100) + .withSeed(123L); + + LogRegressionMultiClassModel originalModel = trainer.fit( + cacheMock, + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + LogRegressionMultiClassModel updatedOnSameDS = trainer.update( + originalModel, + cacheMock, + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + LogRegressionMultiClassModel updatedOnEmptyDS = trainer.update( + originalModel, + new HashMap<Integer, double[]>(), + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + List<Vector> vectors = Arrays.asList( + VectorUtils.of(10, 10), + VectorUtils.of(-10, 10), + VectorUtils.of(-10, -10), + VectorUtils.of(10, -10) + ); + + + for (Vector vec : vectors) { + TestUtils.assertEquals(originalModel.apply(vec), updatedOnSameDS.apply(vec), PRECISION); + TestUtils.assertEquals(originalModel.apply(vec), updatedOnEmptyDS.apply(vec), PRECISION); + } + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java 
b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java index 5bd2dbd..1da0d1a 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.Map; import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.common.TrainerTest; +import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; @@ -60,4 +61,49 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest { TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), PRECISION); TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION); } + + /** Fits an SGD logistic regression on {@code twoLinearlySeparableClasses} (label in column 0, features in the remaining columns), then calls {@code update} once with the same data and once with an empty map; for the probe vectors (100, 10) and (10, 100), both updated models must return the original model's prediction within {@code PRECISION}. */ + @Test + public void testUpdate() { + Map<Integer, double[]> cacheMock = new HashMap<>(); + + for (int i = 0; i < twoLinearlySeparableClasses.length; i++) + cacheMock.put(i, twoLinearlySeparableClasses[i]); + + LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>( + new SimpleGDUpdateCalculator().withLearningRate(0.2), + SimpleGDParameterUpdate::sumLocal, + SimpleGDParameterUpdate::avg + ), 100000, 10, 100, 123L); + + LogisticRegressionModel originalModel = trainer.fit( + cacheMock, + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + LogisticRegressionModel updatedOnSameDS = trainer.update( + originalModel, + cacheMock, + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + LogisticRegressionModel updatedOnEmptyDS = trainer.update( + originalModel, + new HashMap<Integer, double[]>(), + parts, + (k, v) -> 
VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + Vector v1 = VectorUtils.of(100, 10); + Vector v2 = VectorUtils.of(10, 100); + TestUtils.assertEquals(originalModel.apply(v1), updatedOnSameDS.apply(v1), PRECISION); + TestUtils.assertEquals(originalModel.apply(v2), updatedOnSameDS.apply(v2), PRECISION); + TestUtils.assertEquals(originalModel.apply(v2), updatedOnEmptyDS.apply(v2), PRECISION); + TestUtils.assertEquals(originalModel.apply(v1), updatedOnEmptyDS.apply(v1), PRECISION); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java index 5630bee..263bb6d 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.Map; import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.common.TrainerTest; +import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Test; @@ -52,4 +53,44 @@ public class SVMBinaryTrainerTest extends TrainerTest { TestUtils.assertEquals(-1, mdl.apply(VectorUtils.of(100, 10)), PRECISION); TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION); } + + /** Fits a linear binary SVM on {@code twoLinearlySeparableClasses} (label in column 0, features in the remaining columns), then calls {@code update} once with the same data and once with an empty map; for the probe vector (100, 10), both updated models must return the original model's prediction within {@code PRECISION}. */ + @Test + public void testUpdate() { + Map<Integer, double[]> cacheMock = new HashMap<>(); + + for (int i = 0; i < twoLinearlySeparableClasses.length; i++) + cacheMock.put(i, twoLinearlySeparableClasses[i]); + + SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer() + .withAmountOfIterations(1000) + .withSeed(1234L); + + 
SVMLinearBinaryClassificationModel originalModel = trainer.fit( + cacheMock, + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + SVMLinearBinaryClassificationModel updatedOnSameDS = trainer.update( + originalModel, + cacheMock, + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + SVMLinearBinaryClassificationModel updatedOnEmptyDS = trainer.update( + originalModel, + new HashMap<Integer, double[]>(), + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + Vector v = VectorUtils.of(100, 10); + TestUtils.assertEquals(originalModel.apply(v), updatedOnSameDS.apply(v), PRECISION); + TestUtils.assertEquals(originalModel.apply(v), updatedOnEmptyDS.apply(v), PRECISION); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java index 7ea28c2..e0c62af 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.Map; import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.common.TrainerTest; +import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Test; @@ -54,4 +55,46 @@ public class SVMMultiClassTrainerTest extends TrainerTest { TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), PRECISION); TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION); } + + /** Fits a linear multi-class SVM on {@code twoLinearlySeparableClasses} (label in column 0, features in the remaining columns), then calls {@code update} once with the same data and once with an empty map; for the probe vector (100, 10), both updated models must return the original model's prediction within {@code PRECISION}. */ + @Test + public void testUpdate() { + Map<Integer, 
double[]> cacheMock = new HashMap<>(); + + for (int i = 0; i < twoLinearlySeparableClasses.length; i++) + cacheMock.put(i, twoLinearlySeparableClasses[i]); + + SVMLinearMultiClassClassificationTrainer trainer = new SVMLinearMultiClassClassificationTrainer() + .withLambda(0.3) + .withAmountOfLocIterations(10) + .withAmountOfIterations(100) + .withSeed(1234L); + + SVMLinearMultiClassClassificationModel originalModel = trainer.fit( + cacheMock, + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + SVMLinearMultiClassClassificationModel updatedOnSameDS = trainer.update( + originalModel, + cacheMock, + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + SVMLinearMultiClassClassificationModel updatedOnEmptyDS = trainer.update( + originalModel, + new HashMap<Integer, double[]>(), + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + Vector v = VectorUtils.of(100, 10); + TestUtils.assertEquals(originalModel.apply(v), updatedOnSameDS.apply(v), PRECISION); + TestUtils.assertEquals(originalModel.apply(v), updatedOnEmptyDS.apply(v), PRECISION); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java index 4abf508..087f4e8 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java @@ -24,6 +24,7 @@ import java.util.Map; import org.apache.ignite.ml.composition.ModelsComposition; import 
org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator; import org.apache.ignite.ml.dataset.feature.FeatureMeta; +import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Test; import org.junit.runner.RunWith; @@ -75,7 +76,7 @@ public class RandomForestClassifierTrainerTest { } ArrayList<FeatureMeta> meta = new ArrayList<>(); - for(int i = 0; i < 4; i++) + for (int i = 0; i < 4; i++) meta.add(new FeatureMeta("", i, false)); RandomForestClassifierTrainer trainer = new RandomForestClassifierTrainer(meta) .withCountOfTrees(5) @@ -86,4 +87,34 @@ public class RandomForestClassifierTrainerTest { assertTrue(mdl.getPredictionsAggregator() instanceof OnMajorityPredictionsAggregator); assertEquals(5, mdl.getModels().size()); } + + /** Fits a 100-tree random forest classifier on 1000 generated samples (four features keyed by the feature array, label {@code i % 2}), then calls {@code update} once with the same sample and once with an empty map; for the probe vector (5, 0.5, 0.05, 0.005), both updated compositions must return the original composition's prediction within 0.01. */ + @Test + public void testUpdate() { + int sampleSize = 1000; + Map<double[], Double> sample = new HashMap<>(); + for (int i = 0; i < sampleSize; i++) { + double x1 = i; + double x2 = x1 / 10.0; + double x3 = x2 / 10.0; + double x4 = x3 / 10.0; + + sample.put(new double[] {x1, x2, x3, x4}, (double)(i % 2)); + } + + ArrayList<FeatureMeta> meta = new ArrayList<>(); + for (int i = 0; i < 4; i++) + meta.add(new FeatureMeta("", i, false)); + RandomForestClassifierTrainer trainer = new RandomForestClassifierTrainer(meta) + .withCountOfTrees(100) + .withFeaturesCountSelectionStrgy(x -> 2); + + ModelsComposition originalModel = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); + ModelsComposition updatedOnSameDS = trainer.update(originalModel, sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); + ModelsComposition updatedOnEmptyDS = trainer.update(originalModel, new HashMap<double[], Double>(), parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); + + Vector v = VectorUtils.of(5, 0.5, 0.05, 0.005); + assertEquals(originalModel.apply(v), updatedOnSameDS.apply(v), 0.01); + assertEquals(originalModel.apply(v), 
updatedOnEmptyDS.apply(v), 0.01); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java index c4a4a75..fcc20bd 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java @@ -24,6 +24,7 @@ import java.util.Map; import org.apache.ignite.ml.composition.ModelsComposition; import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator; import org.apache.ignite.ml.dataset.feature.FeatureMeta; +import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Test; import org.junit.runner.RunWith; @@ -82,4 +83,34 @@ public class RandomForestRegressionTrainerTest { assertTrue(mdl.getPredictionsAggregator() instanceof MeanValuePredictionsAggregator); assertEquals(5, mdl.getModels().size()); } + + /** Fits a 100-tree random forest regressor on 1000 generated samples (four features keyed by the feature array, label {@code i % 2}), then calls {@code update} once with the same sample and once with an empty map; for the probe vector (5, 0.5, 0.05, 0.005), both updated compositions must return the original composition's prediction within 0.1. */ + @Test + public void testUpdate() { + int sampleSize = 1000; + Map<double[], Double> sample = new HashMap<>(); + for (int i = 0; i < sampleSize; i++) { + double x1 = i; + double x2 = x1 / 10.0; + double x3 = x2 / 10.0; + double x4 = x3 / 10.0; + + sample.put(new double[] {x1, x2, x3, x4}, (double)(i % 2)); + } + + ArrayList<FeatureMeta> meta = new ArrayList<>(); + for (int i = 0; i < 4; i++) + meta.add(new FeatureMeta("", i, false)); + RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(meta) + .withCountOfTrees(100) + .withFeaturesCountSelectionStrgy(x -> 2); + + ModelsComposition 
originalModel = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); + ModelsComposition updatedOnSameDS = trainer.update(originalModel, sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); + ModelsComposition updatedOnEmptyDS = trainer.update(originalModel, new HashMap<double[], Double>(), parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); + + Vector v = VectorUtils.of(5, 0.5, 0.05, 0.005); + assertEquals(originalModel.apply(v), updatedOnSameDS.apply(v), 0.1); + assertEquals(originalModel.apply(v), updatedOnEmptyDS.apply(v), 0.1); + } }
