Repository: ignite Updated Branches: refs/heads/master ed6bf5ac2 -> 6225c56ea
http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java new file mode 100644 index 0000000..be1724c --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.dataset.primitive; + +import java.io.Serializable; +import java.util.Iterator; +import org.apache.ignite.ml.dataset.PartitionDataBuilder; +import org.apache.ignite.ml.dataset.UpstreamEntry; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.tree.data.DecisionTreeData; + +/** + * A partition {@code data} builder that makes {@link FeatureMatrixWithLabelsOnHeapData}. 
+ * + * @param <K> Type of a key in <tt>upstream</tt> data. + * @param <V> Type of a value in <tt>upstream</tt> data. + * @param <C> Type of a partition <tt>context</tt>. + */ +public class FeatureMatrixWithLabelsOnHeapDataBuilder<K, V, C extends Serializable> + implements PartitionDataBuilder<K, V, C, FeatureMatrixWithLabelsOnHeapData> { + /** Serial version uid. */ + private static final long serialVersionUID = 6273736987424171813L; + + /** Function that extracts features from an {@code upstream} data. */ + private final IgniteBiFunction<K, V, Vector> featureExtractor; + + /** Function that extracts labels from an {@code upstream} data. */ + private final IgniteBiFunction<K, V, Double> lbExtractor; + + /** + * Constructs a new instance of feature matrix with labels on-heap data builder. + * + * @param featureExtractor Function that extracts features from an {@code upstream} data. + * @param lbExtractor Function that extracts labels from an {@code upstream} data. + */ + public FeatureMatrixWithLabelsOnHeapDataBuilder(IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + this.featureExtractor = featureExtractor; + this.lbExtractor = lbExtractor; + } + + /** {@inheritDoc} */ + @Override public FeatureMatrixWithLabelsOnHeapData build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) { + double[][] features = new double[Math.toIntExact(upstreamDataSize)][]; + double[] labels = new double[Math.toIntExact(upstreamDataSize)]; + + int ptr = 0; + while (upstreamData.hasNext()) { + UpstreamEntry<K, V> entry = upstreamData.next(); + + features[ptr] = featureExtractor.apply(entry.getKey(), entry.getValue()).asArray(); + + labels[ptr] = lbExtractor.apply(entry.getKey(), entry.getValue()); + + ptr++; + } + + return new FeatureMatrixWithLabelsOnHeapData(features, labels); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java 
---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java index 8589a79..6ebbda1 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java @@ -17,12 +17,13 @@ package org.apache.ignite.ml.tree.boosting; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.apache.ignite.ml.Model; import org.apache.ignite.ml.composition.ModelsComposition; import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy; +import org.apache.ignite.ml.composition.boosting.GDBTrainer; +import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker; import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; @@ -54,22 +55,30 @@ public class GDBOnTreesLearningStrategy extends GDBLearningStrategy { } /** {@inheritDoc} */ - @Override public <K, V> List<Model<Vector, Double>> learnModels(DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + @Override public <K, V> List<Model<Vector, Double>> update(GDBTrainer.GDBModel mdlToUpdate, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { DatasetTrainer<? 
extends Model<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get(); assert trainer instanceof DecisionTree; DecisionTree decisionTreeTrainer = (DecisionTree) trainer; - List<Model<Vector, Double>> models = new ArrayList<>(); + List<Model<Vector, Double>> models = initLearningState(mdlToUpdate); + + ConvergenceChecker<K,V> convCheck = checkConvergenceStgyFactory.create(sampleSize, + externalLbToInternalMapping, loss, datasetBuilder, featureExtractor, lbExtractor); + try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build( new EmptyContextBuilder<>(), new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, useIndex) )) { for (int i = 0; i < cntOfIterations; i++) { - double[] weights = Arrays.copyOf(compositionWeights, i); + double[] weights = Arrays.copyOf(compositionWeights, models.size()); WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLabelValue); - Model<Vector, Double> currComposition = new ModelsComposition(models, aggregator); + ModelsComposition currComposition = new ModelsComposition(models, aggregator); + + if(convCheck.isConverged(dataset, currComposition)) + break; dataset.compute(part -> { if(part.getCopyOfOriginalLabels() == null) @@ -78,7 +87,7 @@ public class GDBOnTreesLearningStrategy extends GDBLearningStrategy { for(int j = 0; j < part.getLabels().length; j++) { double mdlAnswer = currComposition.apply(VectorUtils.of(part.getFeatures()[j])); double originalLbVal = externalLbToInternalMapping.apply(part.getCopyOfOriginalLabels()[j]); - part.getLabels()[j] = -lossGradient.apply(sampleSize, originalLbVal, mdlAnswer); + part.getLabels()[j] = -loss.gradient(sampleSize, originalLbVal, mdlAnswer); } }); @@ -92,6 +101,7 @@ public class GDBOnTreesLearningStrategy extends GDBLearningStrategy { throw new RuntimeException(e); } + compositionWeights = Arrays.copyOf(compositionWeights, models.size()); return models; } } 
http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java index d5750ea..b8a16dc 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java @@ -19,18 +19,14 @@ package org.apache.ignite.ml.tree.data; import java.util.ArrayList; import java.util.List; +import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData; import org.apache.ignite.ml.tree.TreeFilter; /** - * A partition {@code data} of the containing matrix of features and vector of labels stored in heap. + * A partition {@code data} of the containing matrix of features and vector of labels stored in heap + * with index on features. */ -public class DecisionTreeData implements AutoCloseable { - /** Matrix with features. */ - private final double[][] features; - - /** Vector with labels. */ - private final double[] labels; - +public class DecisionTreeData extends FeatureMatrixWithLabelsOnHeapData implements AutoCloseable { /** Copy of vector with original labels. Auxiliary for Gradient Boosting on Trees.*/ private double[] copyOfOriginalLabels; @@ -48,10 +44,7 @@ public class DecisionTreeData implements AutoCloseable { * @param buildIdx Build index. 
*/ public DecisionTreeData(double[][] features, double[] labels, boolean buildIdx) { - assert features.length == labels.length : "Features and labels have to be the same length"; - - this.features = features; - this.labels = labels; + super(features, labels); this.buildIndex = buildIdx; indexesCache = new ArrayList<>(); @@ -68,6 +61,8 @@ public class DecisionTreeData implements AutoCloseable { public DecisionTreeData filter(TreeFilter filter) { int size = 0; + double[][] features = getFeatures(); + double[] labels = getLabels(); for (int i = 0; i < features.length; i++) if (filter.test(features[i])) size++; @@ -95,12 +90,15 @@ public class DecisionTreeData implements AutoCloseable { * @param col Column. */ public void sort(int col) { - sort(col, 0, features.length - 1); + sort(col, 0, getFeatures().length - 1); } /** */ private void sort(int col, int from, int to) { if (from < to) { + double[][] features = getFeatures(); + double[] labels = getLabels(); + double pivot = features[(from + to) / 2][col]; int i = from, j = to; @@ -131,19 +129,11 @@ public class DecisionTreeData implements AutoCloseable { } /** */ - public double[][] getFeatures() { - return features; - } - - /** */ - public double[] getLabels() { - return labels; - } - public double[] getCopyOfOriginalLabels() { return copyOfOriginalLabels; } + /** */ public void setCopyOfOriginalLabels(double[] copyOfOriginalLabels) { this.copyOfOriginalLabels = copyOfOriginalLabels; } @@ -170,7 +160,7 @@ public class DecisionTreeData implements AutoCloseable { if (depth == indexesCache.size()) { if (depth == 0) - indexesCache.add(new TreeDataIndex(features, labels)); + indexesCache.add(new TreeDataIndex(getFeatures(), getLabels())); else { TreeDataIndex lastIndex = indexesCache.get(depth - 1); indexesCache.add(lastIndex.filter(filter)); http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java 
---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java index 3e340f6..89b8c9c 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java @@ -22,11 +22,13 @@ import java.util.Map; import java.util.function.BiFunction; import org.apache.ignite.ml.Model; import org.apache.ignite.ml.composition.ModelsComposition; +import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory; +import org.apache.ignite.ml.composition.boosting.convergence.simple.ConvergenceCheckerStubFactory; import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; -import org.apache.ignite.ml.trainers.DatasetTrainer; import org.apache.ignite.ml.tree.DecisionTreeConditionalNode; import org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer; import org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer; @@ -54,8 +56,8 @@ public class GDBTrainerTest { learningSample.put(i, new double[] {xs[i], ys[i]}); } - DatasetTrainer<ModelsComposition, Double> trainer - = new GDBRegressionOnTreesTrainer(1.0, 2000, 3, 0.0).withUseIndex(true); + GDBTrainer trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 3, 0.0) + .withUseIndex(true); Model<Vector, Double> mdl = trainer.fit( learningSample, 1, @@ -74,7 +76,6 @@ public class GDBTrainerTest { assertEquals(0.0, mse, 0.0001); - assertTrue(mdl instanceof ModelsComposition); ModelsComposition 
composition = (ModelsComposition)mdl; assertTrue(composition.toString().length() > 0); assertTrue(composition.toString(true).length() > 0); @@ -84,6 +85,13 @@ public class GDBTrainerTest { assertEquals(2000, composition.getModels().size()); assertTrue(composition.getPredictionsAggregator() instanceof WeightedPredictionsAggregator); + + trainer = trainer.withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.1)); + assertTrue(trainer.fit( + learningSample, 1, + (k, v) -> VectorUtils.of(v[0]), + (k, v) -> v[1] + ).getModels().size() < 2000); } /** */ @@ -107,7 +115,7 @@ public class GDBTrainerTest { } /** */ - private void testClassifier(BiFunction<GDBBinaryClassifierOnTreesTrainer, Map<Integer, double[]>, + private void testClassifier(BiFunction<GDBTrainer, Map<Integer, double[]>, Model<Vector, Double>> fitter) { int sampleSize = 100; double[] xs = new double[sampleSize]; @@ -122,8 +130,9 @@ public class GDBTrainerTest { for (int i = 0; i < sampleSize; i++) learningSample.put(i, new double[] {xs[i], ys[i]}); - GDBBinaryClassifierOnTreesTrainer trainer - = new GDBBinaryClassifierOnTreesTrainer(0.3, 500, 3, 0.0).withUseIndex(true); + GDBTrainer trainer = new GDBBinaryClassifierOnTreesTrainer(0.3, 500, 3, 0.0) + .withUseIndex(true) + .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.3)); Model<Vector, Double> mdl = fitter.apply(trainer, learningSample); @@ -132,7 +141,7 @@ public class GDBTrainerTest { double x = xs[j]; double y = ys[j]; double p = mdl.apply(VectorUtils.of(x)); - if(p != y) + if (p != y) errorsCnt++; } @@ -142,7 +151,61 @@ public class GDBTrainerTest { ModelsComposition composition = (ModelsComposition)mdl; composition.getModels().forEach(m -> assertTrue(m instanceof DecisionTreeConditionalNode)); - assertEquals(500, composition.getModels().size()); + assertTrue(composition.getModels().size() < 500); assertTrue(composition.getPredictionsAggregator() instanceof WeightedPredictionsAggregator); + + trainer = 
trainer.withCheckConvergenceStgyFactory(new ConvergenceCheckerStubFactory()); + assertEquals(500, ((ModelsComposition)fitter.apply(trainer, learningSample)).getModels().size()); + } + + /** */ + @Test + public void testUpdate() { + int sampleSize = 100; + double[] xs = new double[sampleSize]; + double[] ys = new double[sampleSize]; + + for (int i = 0; i < sampleSize; i++) { + xs[i] = i; + ys[i] = ((int)(xs[i] / 10.0) % 2) == 0 ? -1.0 : 1.0; + } + + Map<Integer, double[]> learningSample = new HashMap<>(); + for (int i = 0; i < sampleSize; i++) + learningSample.put(i, new double[] {xs[i], ys[i]}); + IgniteBiFunction<Integer, double[], Vector> fExtr = (k, v) -> VectorUtils.of(v[0]); + IgniteBiFunction<Integer, double[], Double> lExtr = (k, v) -> v[1]; + + GDBTrainer classifTrainer = new GDBBinaryClassifierOnTreesTrainer(0.3, 500, 3, 0.0) + .withUseIndex(true) + .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.3)); + GDBTrainer regressTrainer = new GDBRegressionOnTreesTrainer(0.3, 500, 3, 0.0) + .withUseIndex(true) + .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.3)); + + testUpdate(learningSample, fExtr, lExtr, classifTrainer); + testUpdate(learningSample, fExtr, lExtr, regressTrainer); + } + + /** */ + private void testUpdate(Map<Integer, double[]> dataset, IgniteBiFunction<Integer, double[], Vector> fExtr, + IgniteBiFunction<Integer, double[], Double> lExtr, GDBTrainer trainer) { + + ModelsComposition originalMdl = trainer.fit(dataset, 1, fExtr, lExtr); + ModelsComposition updatedOnSameDataset = trainer.update(originalMdl, dataset, 1, fExtr, lExtr); + + LocalDatasetBuilder<Integer, double[]> epmtyDataset = new LocalDatasetBuilder<>(new HashMap<>(), 1); + ModelsComposition updatedOnEmptyDataset = trainer.updateModel(originalMdl, epmtyDataset, fExtr, lExtr); + + dataset.forEach((k,v) -> { + Vector features = fExtr.apply(k, v); + + Double originalAnswer = originalMdl.apply(features); + Double updatedMdlAnswer1 = 
updatedOnSameDataset.apply(features); + Double updatedMdlAnswer2 = updatedOnEmptyDataset.apply(features); + + assertEquals(originalAnswer, updatedMdlAnswer1, 0.01); + assertEquals(originalAnswer, updatedMdlAnswer2, 0.01); + }); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerTest.java new file mode 100644 index 0000000..50fdf8b --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerTest.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.composition.boosting.convergence; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import org.apache.ignite.ml.composition.ModelsComposition; +import org.apache.ignite.ml.composition.boosting.loss.Loss; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.junit.Before; + +/** */ +public abstract class ConvergenceCheckerTest { + /** Not converged model. */ + protected ModelsComposition notConvergedMdl = new ModelsComposition(Collections.emptyList(), null) { + @Override public Double apply(Vector features) { + return 2.1 * features.get(0); + } + }; + + /** Converged model. */ + protected ModelsComposition convergedMdl = new ModelsComposition(Collections.emptyList(), null) { + @Override public Double apply(Vector features) { + return 2 * (features.get(0) + 1); + } + }; + + /** Features extractor. */ + protected IgniteBiFunction<double[], Double, Vector> fExtr = (x, y) -> VectorUtils.of(x); + + /** Label extractor. */ + protected IgniteBiFunction<double[], Double, Double> lbExtr = (x, y) -> y; + + /** Data. 
*/ + protected Map<double[], Double> data; + + /** */ + @Before + public void setUp() throws Exception { + data = new HashMap<>(); + for(int i = 0; i < 10; i ++) + data.put(new double[]{i, i + 1}, (double)(2 * (i + 1))); + } + + /** */ + public ConvergenceChecker<double[], Double> createChecker(ConvergenceCheckerFactory factory, + LocalDatasetBuilder<double[], Double> datasetBuilder) { + + return factory.create(data.size(), + x -> x, + new Loss() { + @Override public double error(long sampleSize, double lb, double mdlAnswer) { + return mdlAnswer - lb; + } + + @Override public double gradient(long sampleSize, double lb, double mdlAnswer) { + return mdlAnswer - lb; + } + }, + datasetBuilder, fExtr, lbExtr + ); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerTest.java new file mode 100644 index 0000000..0b42db8 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerTest.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.composition.boosting.convergence.mean; + +import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker; +import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerTest; +import org.apache.ignite.ml.dataset.impl.local.LocalDataset; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData; +import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapDataBuilder; +import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.junit.Assert; +import org.junit.Test; + +/** */ +public class MeanAbsValueConvergenceCheckerTest extends ConvergenceCheckerTest { + /** */ + @Test + public void testConvergenceChecking() { + LocalDatasetBuilder<double[], Double> datasetBuilder = new LocalDatasetBuilder<>(data, 1); + ConvergenceChecker<double[], Double> checker = createChecker( + new MeanAbsValueConvergenceCheckerFactory(0.1), datasetBuilder); + + double error = checker.computeError(VectorUtils.of(1, 2), 4.0, notConvergedMdl); + Assert.assertEquals(1.9, error, 0.01); + Assert.assertFalse(checker.isConverged(datasetBuilder, notConvergedMdl)); + Assert.assertTrue(checker.isConverged(datasetBuilder, convergedMdl)); + + try(LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build( + new 
EmptyContextBuilder<>(), new FeatureMatrixWithLabelsOnHeapDataBuilder<>(fExtr, lbExtr))) { + + double onDSError = checker.computeMeanErrorOnDataset(dataset, notConvergedMdl); + Assert.assertEquals(1.55, onDSError, 0.01); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** Mean error is more sensitive to anomalies in data. */ + @Test + public void testConvergenceCheckingWithAnomaliesInData() { + data.put(new double[]{10, 11}, 100000.0); + LocalDatasetBuilder<double[], Double> datasetBuilder = new LocalDatasetBuilder<>(data, 1); + ConvergenceChecker<double[], Double> checker = createChecker( + new MeanAbsValueConvergenceCheckerFactory(0.1), datasetBuilder); + + try(LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build( + new EmptyContextBuilder<>(), new FeatureMatrixWithLabelsOnHeapDataBuilder<>(fExtr, lbExtr))) { + + double onDSError = checker.computeMeanErrorOnDataset(dataset, notConvergedMdl); + Assert.assertEquals(9090.41, onDSError, 0.01); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerTest.java new file mode 100644 index 0000000..d6880b4 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerTest.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.composition.boosting.convergence.median; + +import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker; +import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerTest; +import org.apache.ignite.ml.dataset.impl.local.LocalDataset; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData; +import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapDataBuilder; +import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.junit.Assert; +import org.junit.Test; + +/** */ +public class MedianOfMedianConvergenceCheckerTest extends ConvergenceCheckerTest { + /** */ + @Test + public void testConvergenceChecking() { + data.put(new double[]{10, 11}, 100000.0); + LocalDatasetBuilder<double[], Double> datasetBuilder = new LocalDatasetBuilder<>(data, 1); + + ConvergenceChecker<double[], Double> checker = createChecker( + new MedianOfMedianConvergenceCheckerFactory(0.1), datasetBuilder); + + double error = checker.computeError(VectorUtils.of(1, 
2), 4.0, notConvergedMdl); + Assert.assertEquals(1.9, error, 0.01); + Assert.assertFalse(checker.isConverged(datasetBuilder, notConvergedMdl)); + Assert.assertTrue(checker.isConverged(datasetBuilder, convergedMdl)); + + try(LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build( + new EmptyContextBuilder<>(), new FeatureMatrixWithLabelsOnHeapDataBuilder<>(fExtr, lbExtr))) { + + double onDSError = checker.computeMeanErrorOnDataset(dataset, notConvergedMdl); + Assert.assertEquals(1.6, onDSError, 0.01); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java index f88fd3e..b06fd67 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java @@ -17,185 +17,44 @@ package org.apache.ignite.ml.environment; -import java.util.Arrays; -import java.util.UUID; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import java.util.stream.IntStream; -import javax.cache.Cache; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; -import org.apache.ignite.cache.query.QueryCursor; -import org.apache.ignite.cache.query.ScanQuery; -import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.internal.util.IgniteUtils; -import org.apache.ignite.ml.composition.ModelsComposition; import 
org.apache.ignite.ml.dataset.feature.FeatureMeta; import org.apache.ignite.ml.environment.logging.ConsoleLogger; import org.apache.ignite.ml.environment.logging.MLLogger; +import org.apache.ignite.ml.environment.parallelism.DefaultParallelismStrategy; import org.apache.ignite.ml.environment.parallelism.ParallelismStrategy; -import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.tree.randomforest.RandomForestRegressionTrainer; import org.apache.ignite.ml.tree.randomforest.data.FeaturesCountSelectionStrategies; -import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; -import org.apache.ignite.thread.IgniteThread; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; /** * Tests for {@link LearningEnvironment} that require to start the whole Ignite infrastructure. IMPL NOTE based on * RandomForestRegressionExample example. */ -public class LearningEnvironmentTest extends GridCommonAbstractTest { - /** Number of nodes in grid */ - private static final int NODE_COUNT = 1; - - /** Ignite instance. */ - private Ignite ignite; - - /** {@inheritDoc} */ - @Override protected void beforeTestsStarted() throws Exception { - for (int i = 1; i <= NODE_COUNT; i++) - startGrid(i); - } - - /** {@inheritDoc} */ - @Override protected void afterTestsStopped() { - stopAllGrids(); - } - - /** - * {@inheritDoc} - */ - @Override protected void beforeTest() { - /* Grid instance. 
*/ - ignite = grid(NODE_COUNT); - ignite.configuration().setPeerClassLoadingEnabled(true); - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - } - +public class LearningEnvironmentTest { /** */ + @Test public void testBasic() throws InterruptedException { - AtomicReference<Integer> actualAmount = new AtomicReference<>(null); - AtomicReference<Double> actualMse = new AtomicReference<>(null); - AtomicReference<Double> actualMae = new AtomicReference<>(null); - - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - LearningEnvironmentTest.class.getSimpleName(), () -> { - IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); - - AtomicInteger idx = new AtomicInteger(0); - RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer( - IntStream.range(0, data[0].length - 1).mapToObj( - x -> new FeatureMeta("", idx.getAndIncrement(), false)).collect(Collectors.toList()) - ).withCountOfTrees(101) - .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD) - .withMaxDepth(4) - .withMinImpurityDelta(0.) 
- .withSubsampleSize(0.3) - .withSeed(0); - - trainer.setEnvironment(LearningEnvironment.builder() - .withParallelismStrategy(ParallelismStrategy.Type.ON_DEFAULT_POOL) - .withLoggingFactory(ConsoleLogger.factory(MLLogger.VerboseLevel.LOW)) - .build() - ); - - ModelsComposition randomForest = trainer.fit(ignite, dataCache, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), - (k, v) -> v[v.length - 1] - ); - - double mse = 0.0; - double mae = 0.0; - int totalAmount = 0; - - try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry<Integer, double[]> observation : observations) { - double difference = estimatePrediction(randomForest, observation); - - mse += Math.pow(difference, 2.0); - mae += Math.abs(difference); - - totalAmount++; - } - } - - actualAmount.set(totalAmount); - - mse = mse / totalAmount; - actualMse.set(mse); - - mae = mae / totalAmount; - actualMae.set(mae); - }); - - igniteThread.start(); - igniteThread.join(); - - assertEquals("Total amount", 23, (int)actualAmount.get()); - assertTrue("Mean squared error (MSE)", actualMse.get() > 0); - assertTrue("Mean absolute error (MAE)", actualMae.get() > 0); + RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer( + IntStream.range(0, 0).mapToObj( + x -> new FeatureMeta("", 0, false)).collect(Collectors.toList()) + ).withCountOfTrees(101) + .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD) + .withMaxDepth(4) + .withMinImpurityDelta(0.) 
+ .withSubsampleSize(0.3) + .withSeed(0); + + LearningEnvironment environment = LearningEnvironment.builder() + .withParallelismStrategy(ParallelismStrategy.Type.ON_DEFAULT_POOL) + .withLoggingFactory(ConsoleLogger.factory(MLLogger.VerboseLevel.LOW)) + .build(); + trainer.setEnvironment(environment); + assertEquals(DefaultParallelismStrategy.class, environment.parallelismStrategy().getClass()); + assertEquals(ConsoleLogger.class, environment.logger().getClass()); } - - /** */ - private double estimatePrediction(ModelsComposition randomForest, Cache.Entry<Integer, double[]> observation) { - double[] val = observation.getValue(); - double[] inputs = Arrays.copyOfRange(val, 0, val.length - 1); - double groundTruth = val[val.length - 1]; - - double prediction = randomForest.apply(VectorUtils.of(inputs)); - - return prediction - groundTruth; - } - - /** - * Fills cache with data and returns it. - * - * @param ignite Ignite instance. - * @return Filled Ignite Cache. - */ - private IgniteCache<Integer, double[]> getTestCache(Ignite ignite) { - CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>(); - cacheConfiguration.setName(UUID.randomUUID().toString()); - cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10)); - - IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration); - - for (int i = 0; i < data.length; i++) - cache.put(i, data[i]); - - return cache; - } - - /** - * Part of the Boston housing dataset. 
- */ - private static final double[][] data = { - {0.02731,0.00,7.070,0,0.4690,6.4210,78.90,4.9671,2,242.0,17.80,396.90,9.14,21.60}, - {0.02729,0.00,7.070,0,0.4690,7.1850,61.10,4.9671,2,242.0,17.80,392.83,4.03,34.70}, - {0.03237,0.00,2.180,0,0.4580,6.9980,45.80,6.0622,3,222.0,18.70,394.63,2.94,33.40}, - {0.06905,0.00,2.180,0,0.4580,7.1470,54.20,6.0622,3,222.0,18.70,396.90,5.33,36.20}, - {0.02985,0.00,2.180,0,0.4580,6.4300,58.70,6.0622,3,222.0,18.70,394.12,5.21,28.70}, - {0.08829,12.50,7.870,0,0.5240,6.0120,66.60,5.5605,5,311.0,15.20,395.60,12.43,22.90}, - {0.14455,12.50,7.870,0,0.5240,6.1720,96.10,5.9505,5,311.0,15.20,396.90,19.15,27.10}, - {0.21124,12.50,7.870,0,0.5240,5.6310,100.00,6.0821,5,311.0,15.20,386.63,29.93,16.50}, - {0.17004,12.50,7.870,0,0.5240,6.0040,85.90,6.5921,5,311.0,15.20,386.71,17.10,18.90}, - {0.22489,12.50,7.870,0,0.5240,6.3770,94.30,6.3467,5,311.0,15.20,392.52,20.45,15.00}, - {0.11747,12.50,7.870,0,0.5240,6.0090,82.90,6.2267,5,311.0,15.20,396.90,13.27,18.90}, - {0.09378,12.50,7.870,0,0.5240,5.8890,39.00,5.4509,5,311.0,15.20,390.50,15.71,21.70}, - {0.62976,0.00,8.140,0,0.5380,5.9490,61.80,4.7075,4,307.0,21.00,396.90,8.26,20.40}, - {0.63796,0.00,8.140,0,0.5380,6.0960,84.50,4.4619,4,307.0,21.00,380.02,10.26,18.20}, - {0.62739,0.00,8.140,0,0.5380,5.8340,56.50,4.4986,4,307.0,21.00,395.62,8.47,19.90}, - {1.05393,0.00,8.140,0,0.5380,5.9350,29.30,4.4986,4,307.0,21.00,386.85,6.58,23.10}, - {0.78420,0.00,8.140,0,0.5380,5.9900,81.70,4.2579,4,307.0,21.00,386.75,14.67,17.50}, - {0.80271,0.00,8.140,0,0.5380,5.4560,36.60,3.7965,4,307.0,21.00,288.99,11.69,20.20}, - {0.72580,0.00,8.140,0,0.5380,5.7270,69.50,3.7965,4,307.0,21.00,390.95,11.28,18.20}, - {1.25179,0.00,8.140,0,0.5380,5.5700,98.10,3.7979,4,307.0,21.00,376.57,21.02,13.60}, - {0.85204,0.00,8.140,0,0.5380,5.9650,89.20,4.0123,4,307.0,21.00,392.53,13.83,19.60}, - {1.23247,0.00,8.140,0,0.5380,6.1420,91.70,3.9769,4,307.0,21.00,396.90,18.72,15.20}, - 
{0.98843,0.00,8.140,0,0.5380,5.8130,100.00,4.0952,4,307.0,21.00,394.54,19.88,14.50} - }; - } http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java index d8fb620..199644b 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java @@ -93,17 +93,21 @@ public class ANNClassificationTest extends TrainerTest { .withDistanceMeasure(new EuclideanDistance()) .withStrategy(NNStrategy.SIMPLE); - ANNClassificationModel updatedOnSameDataset = trainer.withSeed(1234L).update(originalMdl, + ANNClassificationModel updatedOnSameDataset = (ANNClassificationModel) trainer.withSeed(1234L).update(originalMdl, cacheMock, parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[2] - ); + ).withK(3) + .withDistanceMeasure(new EuclideanDistance()) + .withStrategy(NNStrategy.SIMPLE); - ANNClassificationModel updatedOnEmptyDataset = trainer.withSeed(1234L).update(originalMdl, + ANNClassificationModel updatedOnEmptyDataset = (ANNClassificationModel) trainer.withSeed(1234L).update(originalMdl, new HashMap<Integer, double[]>(), parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[2] - ); + ).withK(3) + .withDistanceMeasure(new EuclideanDistance()) + .withStrategy(NNStrategy.SIMPLE); Vector v1 = VectorUtils.of(550, 550); Vector v2 = VectorUtils.of(-550, -550);
