Repository: ignite Updated Branches: refs/heads/master 9b674ed9a -> 142648df5
http://git-wip-us.apache.org/repos/asf/ignite/blob/142648df/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java new file mode 100644 index 0000000..405c70b --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.composition; + +import java.util.Arrays; +import java.util.Map; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.TestUtils; +import org.apache.ignite.ml.common.TrainerTest; +import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator; +import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator; +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.environment.LearningEnvironment; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteTriFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.nn.UpdatesStrategy; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer; +import org.apache.ignite.ml.trainers.DatasetTrainer; +import org.apache.ignite.ml.trainers.TrainerTransformers; +import org.junit.Test; + +/** + * Tests for bagging algorithm. + */ +public class BaggingTest extends TrainerTest { + /** + * Test that count of entries in context is equal to initial dataset size * subsampleRatio. + */ + @Test + public void testBaggingContextCount() { + count((ctxCount, countData, integer) -> ctxCount); + } + + /** + * Test that count of entries in data is equal to initial dataset size * subsampleRatio. + */ + @Test + public void testBaggingDataCount() { + count((ctxCount, countData, integer) -> countData.cnt); + } + + /** + * Test that bagged log regression makes correct predictions. + */ + @Test + public void testNaiveBaggingLogRegression() { + Map<Integer, Double[]> cacheMock = getCacheMock(twoLinearlySeparableClasses); + + DatasetTrainer<LogisticRegressionModel, Double> trainer = + new LogisticRegressionSGDTrainer() + .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg)) + .withMaxIterations(30000) + .withLocIterations(100) + .withBatchSize(10) + .withSeed(123L); + + trainer.withEnvironmentBuilder(TestUtils.testEnvBuilder()); + + DatasetTrainer<ModelsComposition, Double> baggedTrainer = + TrainerTransformers.makeBagged( + trainer, + 10, + 0.7, + 2, + 2, + new OnMajorityPredictionsAggregator()); + + ModelsComposition mdl = baggedTrainer.fit( + cacheMock, + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), PRECISION); + TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION); + } + + /** + * Method used to test counts of data passed in context and in data builders. + * + * @param cntr Function specifying which data we should count. + */ + protected void count(IgniteTriFunction<Long, CountData, LearningEnvironment, Long> cntr) { + Map<Integer, Double[]> cacheMock = getCacheMock(twoLinearlySeparableClasses); + + CountTrainer cntTrainer = new CountTrainer(cntr); + + double subsampleRatio = 0.3; + + ModelsComposition mdl = TrainerTransformers.makeBagged( + cntTrainer, + 100, + subsampleRatio, + 2, + 2, + new MeanValuePredictionsAggregator()) + .fit(cacheMock, parts, null, null); + + Double res = mdl.apply(null); + + TestUtils.assertEquals(twoLinearlySeparableClasses.length * subsampleRatio, res, twoLinearlySeparableClasses.length / 10); + } + + /** + * Get sum of two Long values each of which can be null. + * + * @param a First value. + * @param b Second value. + * @return Sum of parameters. + */ + protected static Long plusOfNullables(Long a, Long b) { + if (a == null) + return b; + + if (b == null) + return a; + + return a + b; + } + + /** + * Trainer used to count entries in context or in data. + */ + protected static class CountTrainer extends DatasetTrainer<Model<Vector, Double>, Double> { + /** + * Function specifying which entries to count. + */ + private final IgniteTriFunction<Long, CountData, LearningEnvironment, Long> cntr; + + /** + * Construct instance of this class. + * + * @param cntr Function specifying which entries to count. + */ + public CountTrainer(IgniteTriFunction<Long, CountData, LearningEnvironment, Long> cntr) { + this.cntr = cntr; + } + + /** {@inheritDoc} */ + @Override public <K, V> Model<Vector, Double> fit( + DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + Dataset<Long, CountData> dataset = datasetBuilder.build( + TestUtils.testEnvBuilder(), + (env, upstreamData, upstreamDataSize) -> upstreamDataSize, + (env, upstreamData, upstreamDataSize, ctx) -> new CountData(upstreamDataSize) + ); + + Long cnt = dataset.computeWithCtx(cntr, BaggingTest::plusOfNullables); + + return x -> Double.valueOf(cnt); + } + + /** {@inheritDoc} */ + @Override protected boolean checkState(Model<Vector, Double> mdl) { + return true; + } + + /** {@inheritDoc} */ + @Override protected <K, V> Model<Vector, Double> updateModel( + Model<Vector, Double> mdl, + DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + return fit(datasetBuilder, featureExtractor, lbExtractor); + } + + /** {@inheritDoc} */ + @Override public CountTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { + return (CountTrainer)super.withEnvironmentBuilder(envBuilder); + } + } + + /** Data for count trainer. */ + protected static class CountData implements AutoCloseable { + /** Counter. */ + private long cnt; + + /** + * Construct instance of this class. + * + * @param cnt Counter. + */ + public CountData(long cnt) { + this.cnt = cnt; + } + + /** {@inheritDoc} */ + @Override public void close() throws Exception { + // No-op + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/142648df/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java index 8714eb2..87d56cd 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java @@ -25,13 +25,15 @@ import org.junit.runner.RunWith; import org.junit.runners.Suite; /** - * Test suite for all tests located in org.apache.ignite.ml.composition package. + * Test suite for all ensemble models tests. */ @RunWith(Suite.class) @Suite.SuiteClasses({ GDBTrainerTest.class, MeanValuePredictionsAggregatorTest.class, OnMajorityPredictionsAggregatorTest.class, + BaggingTest.class, + StackingTest.class, WeightedPredictionsAggregatorTest.class }) public class CompositionTestSuite { http://git-wip-us.apache.org/repos/asf/ignite/blob/142648df/modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java new file mode 100644 index 0000000..3336470 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.composition; + +import java.util.Arrays; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.TestUtils; +import org.apache.ignite.ml.common.TrainerTest; +import org.apache.ignite.ml.composition.stacking.StackedDatasetTrainer; +import org.apache.ignite.ml.composition.stacking.StackedModel; +import org.apache.ignite.ml.composition.stacking.StackedVectorDatasetTrainer; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.primitives.matrix.Matrix; +import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.nn.Activators; +import org.apache.ignite.ml.nn.MLPTrainer; +import org.apache.ignite.ml.nn.MultilayerPerceptron; +import org.apache.ignite.ml.nn.UpdatesStrategy; +import org.apache.ignite.ml.nn.architecture.MLPArchitecture; +import org.apache.ignite.ml.optimization.LossFunctions; +import org.apache.ignite.ml.optimization.SmoothParametrized; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; +import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer; +import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; +import org.apache.ignite.ml.trainers.AdaptableDatasetModel; +import org.apache.ignite.ml.trainers.AdaptableDatasetTrainer; +import org.apache.ignite.ml.trainers.DatasetTrainer; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import static junit.framework.TestCase.assertEquals; + +/** + * Tests stacked trainers. + */ +public class StackingTest extends TrainerTest { + /** Rule to check exceptions. */ + @Rule + public ExpectedException thrown = ExpectedException.none(); + + /** + * Tests simple stack training. + */ + @Test + public void testSimpleStack() { + StackedDatasetTrainer<Vector, Vector, Double, LinearRegressionModel, Double> trainer = + new StackedDatasetTrainer<>(); + + UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>( + new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, + SimpleGDParameterUpdate::avg + ); + + MLPArchitecture arch = new MLPArchitecture(2). + withAddedLayer(10, true, Activators.RELU). + withAddedLayer(1, false, Activators.SIGMOID); + + MLPTrainer<SimpleGDParameterUpdate> trainer1 = new MLPTrainer<>( + arch, + LossFunctions.MSE, + updatesStgy, + 3000, + 10, + 50, + 123L + ); + + // Convert model trainer to produce Vector -> Vector model + DatasetTrainer<AdaptableDatasetModel<Vector, Vector, Matrix, Matrix, MultilayerPerceptron>, Double> mlpTrainer = + AdaptableDatasetTrainer.of(trainer1) + .beforeTrainedModel((Vector v) -> new DenseMatrix(v.asArray(), 1)) + .afterTrainedModel((Matrix mtx) -> mtx.getRow(0)) + .withConvertedLabels(VectorUtils::num2Arr); + + final double factor = 3; + + StackedModel<Vector, Vector, Double, LinearRegressionModel> mdl = trainer + .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * factor)) + .addTrainer(mlpTrainer) + .withAggregatorInputMerger(VectorUtils::concat) + .withSubmodelOutput2VectorConverter(IgniteFunction.identity()) + .withVector2SubmodelInputConverter(IgniteFunction.identity()) + .withOriginalFeaturesKept(IgniteFunction.identity()) + .withEnvironmentBuilder(TestUtils.testEnvBuilder()) + .fit(getCacheMock(xor), + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[v.length - 1]); + + assertEquals(0.0 * factor, mdl.apply(VectorUtils.of(0.0, 0.0)), 0.3); + assertEquals(1.0 * factor, mdl.apply(VectorUtils.of(0.0, 1.0)), 0.3); + assertEquals(1.0 * factor, mdl.apply(VectorUtils.of(1.0, 0.0)), 0.3); + assertEquals(0.0 * factor, mdl.apply(VectorUtils.of(1.0, 1.0)), 0.3); + } + + /** + * Tests simple stack training. + */ + @Test + public void testSimpleVectorStack() { + StackedVectorDatasetTrainer<Double, LinearRegressionModel, Double> trainer = + new StackedVectorDatasetTrainer<>(); + + UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>( + new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, + SimpleGDParameterUpdate::avg + ); + + MLPArchitecture arch = new MLPArchitecture(2). + withAddedLayer(10, true, Activators.RELU). + withAddedLayer(1, false, Activators.SIGMOID); + + DatasetTrainer<MultilayerPerceptron, Double> mlpTrainer = new MLPTrainer<>( + arch, + LossFunctions.MSE, + updatesStgy, + 3000, + 10, + 50, + 123L + ).withConvertedLabels(VectorUtils::num2Arr); + + final double factor = 3; + + StackedModel<Vector, Vector, Double, LinearRegressionModel> mdl = trainer + .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * factor)) + .addMatrix2MatrixTrainer(mlpTrainer) + .withEnvironmentBuilder(TestUtils.testEnvBuilder()) + .fit(getCacheMock(xor), + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[v.length - 1]); + + assertEquals(0.0 * factor, mdl.apply(VectorUtils.of(0.0, 0.0)), 0.3); + assertEquals(1.0 * factor, mdl.apply(VectorUtils.of(0.0, 1.0)), 0.3); + assertEquals(1.0 * factor, mdl.apply(VectorUtils.of(1.0, 0.0)), 0.3); + assertEquals(0.0 * factor, mdl.apply(VectorUtils.of(1.0, 1.0)), 0.3); + } + + /** + * Tests that if there is no any way for input of first layer to propagate to second layer, + * exception will be thrown. + */ + @Test + public void testINoWaysOfPropagation() { + StackedDatasetTrainer<Void, Void, Void, Model<Void, Void>, Void> trainer = + new StackedDatasetTrainer<>(); + thrown.expect(IllegalStateException.class); + trainer.fit(null, null, null); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/142648df/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java index 61f9fc4..74841a3 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java @@ -47,7 +47,7 @@ public class OneVsRestTrainerTest extends TrainerTest { for (int i = 0; i < twoLinearlySeparableClasses.length; i++) cacheMock.put(i, twoLinearlySeparableClasses[i]); - LogisticRegressionSGDTrainer<?> binaryTrainer = new LogisticRegressionSGDTrainer<>() + LogisticRegressionSGDTrainer binaryTrainer = new LogisticRegressionSGDTrainer() .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg)) .withMaxIterations(1000) @@ -80,7 +80,7 @@ public class OneVsRestTrainerTest extends TrainerTest { for (int i = 0; i < twoLinearlySeparableClasses.length; i++) cacheMock.put(i, twoLinearlySeparableClasses[i]); - LogisticRegressionSGDTrainer<?> binaryTrainer = new LogisticRegressionSGDTrainer<>() + LogisticRegressionSGDTrainer binaryTrainer = new LogisticRegressionSGDTrainer() .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg)) .withMaxIterations(1000) http://git-wip-us.apache.org/repos/asf/ignite/blob/142648df/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java index bd31b19..5ee50a6 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java @@ -106,7 +106,7 @@ public class MLPTrainerMnistIntegrationTest extends GridCommonAbstractTest { ignite, trainingSet, (k, v) -> VectorUtils.of(v.getPixels()), - (k, v) -> VectorUtils.num2Vec(v.getLabel(), 10).getStorage().data() + (k, v) -> VectorUtils.oneHot(v.getLabel(), 10).getStorage().data() ); System.out.println("Training completed in " + (System.currentTimeMillis() - start) + "ms"); http://git-wip-us.apache.org/repos/asf/ignite/blob/142648df/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java index 6a17d18..9396009 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java @@ -76,7 +76,7 @@ public class MLPTrainerMnistTest { trainingSet, 1, (k, v) -> VectorUtils.of(v.getPixels()), - (k, v) -> VectorUtils.num2Vec(v.getLabel(), 10).getStorage().data() + (k, v) -> VectorUtils.oneHot(v.getLabel(), 10).getStorage().data() ); System.out.println("Training completed in " + (System.currentTimeMillis() - start) + "ms"); http://git-wip-us.apache.org/repos/asf/ignite/blob/142648df/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java index fec6220..694dcd3 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java @@ -51,7 +51,7 @@ public class PipelineTest extends TrainerTest { cacheMock.put(i, convertedRow); } - LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>() + LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer() .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg)) .withMaxIterations(100000) http://git-wip-us.apache.org/repos/asf/ignite/blob/142648df/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java index c343ab9..681cb72 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java @@ -43,7 +43,7 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest { for (int i = 0; i < twoLinearlySeparableClasses.length; i++) cacheMock.put(i, twoLinearlySeparableClasses[i]); - LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>() + LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer() .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg)) .withMaxIterations(100000) @@ -70,7 +70,7 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest { for (int i = 0; i < twoLinearlySeparableClasses.length; i++) cacheMock.put(i, twoLinearlySeparableClasses[i]); - LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>() + LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer() .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg)) .withMaxIterations(100000) http://git-wip-us.apache.org/repos/asf/ignite/blob/142648df/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java deleted file mode 100644 index 31fe8b3..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java +++ /dev/null @@ -1,235 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; -import org.apache.ignite.ml.Model; -import org.apache.ignite.ml.TestUtils; -import org.apache.ignite.ml.common.TrainerTest; -import org.apache.ignite.ml.composition.ModelsComposition; -import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator; -import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator; -import org.apache.ignite.ml.dataset.Dataset; -import org.apache.ignite.ml.dataset.DatasetBuilder; -import org.apache.ignite.ml.environment.LearningEnvironment; -import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; -import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.math.functions.IgniteTriFunction; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.math.primitives.vector.VectorUtils; -import org.apache.ignite.ml.nn.UpdatesStrategy; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; -import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer; -import org.junit.Test; - -/** - * Tests for bagging algorithm. - */ -public class BaggingTest extends TrainerTest { - /** - * Test that count of entries in context is equal to initial dataset size * subsampleRatio. - */ - @Test - public void testBaggingContextCount() { - count((ctxCount, countData, integer) -> ctxCount); - } - - /** - * Test that count of entries in data is equal to initial dataset size * subsampleRatio. - */ - @Test - public void testBaggingDataCount() { - count((ctxCount, countData, integer) -> countData.cnt); - } - - /** - * Test that bagged log regression makes correct predictions. - */ - @Test - public void testNaiveBaggingLogRegression() { - Map<Integer, Double[]> cacheMock = getCacheMock(); - - DatasetTrainer<LogisticRegressionModel, Double> trainer = - (LogisticRegressionSGDTrainer<?>)new LogisticRegressionSGDTrainer<>() - .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), - SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg)) - .withMaxIterations(30000) - .withLocIterations(100) - .withBatchSize(10) - .withSeed(123L); - - trainer.withEnvironmentBuilder(TestUtils.testEnvBuilder()); - - DatasetTrainer<ModelsComposition, Double> baggedTrainer = - TrainerTransformers.makeBagged( - trainer, - 10, - 0.7, - 2, - 2, - new OnMajorityPredictionsAggregator()); - - ModelsComposition mdl = baggedTrainer.fit( - cacheMock, - parts, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); - - TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), PRECISION); - TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION); - } - - /** - * Method used to test counts of data passed in context and in data builders. - * - * @param counter Function specifying which data we should count. - */ - protected void count(IgniteTriFunction<Long, CountData, LearningEnvironment, Long> counter) { - Map<Integer, Double[]> cacheMock = getCacheMock(); - - CountTrainer countTrainer = new CountTrainer(counter); - - double subsampleRatio = 0.3; - - ModelsComposition model = TrainerTransformers.makeBagged( - countTrainer, - 100, - subsampleRatio, - 2, - 2, - new MeanValuePredictionsAggregator()) - .fit(cacheMock, parts, null, null); - - Double res = model.apply(null); - - TestUtils.assertEquals(twoLinearlySeparableClasses.length * subsampleRatio, res, twoLinearlySeparableClasses.length / 10); - } - - /** - * Create cache mock. - * - * @return Cache mock. - */ - private Map<Integer, Double[]> getCacheMock() { - Map<Integer, Double[]> cacheMock = new HashMap<>(); - - for (int i = 0; i < twoLinearlySeparableClasses.length; i++) { - double[] row = twoLinearlySeparableClasses[i]; - Double[] convertedRow = new Double[row.length]; - for (int j = 0; j < row.length; j++) - convertedRow[j] = row[j]; - cacheMock.put(i, convertedRow); - } - return cacheMock; - } - - /** - * Get sum of two Long values each of which can be null. - * - * @param a First value. - * @param b Second value. - * @return Sum of parameters. - */ - protected static Long plusOfNullables(Long a, Long b) { - if (a == null) - return b; - - if (b == null) - return a; - - return a + b; - } - - /** - * Trainer used to count entries in context or in data. - */ - protected static class CountTrainer extends DatasetTrainer<Model<Vector, Double>, Double> { - /** - * Function specifying which entries to count. - */ - private final IgniteTriFunction<Long, CountData, LearningEnvironment, Long> counter; - - /** - * Construct instance of this class. - * - * @param counter Function specifying which entries to count. - */ - public CountTrainer(IgniteTriFunction<Long, CountData, LearningEnvironment, Long> counter) { - this.counter = counter; - } - - /** {@inheritDoc} */ - @Override public <K, V> Model<Vector, Double> fit( - DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, Double> lbExtractor) { - Dataset<Long, CountData> dataset = datasetBuilder.build( - TestUtils.testEnvBuilder(), - (env, upstreamData, upstreamDataSize) -> upstreamDataSize, - (env, upstreamData, upstreamDataSize, ctx) -> new CountData(upstreamDataSize) - ); - - Long cnt = dataset.computeWithCtx(counter, BaggingTest::plusOfNullables); - - return x -> Double.valueOf(cnt); - } - - /** {@inheritDoc} */ - @Override protected boolean checkState(Model<Vector, Double> mdl) { - return true; - } - - /** {@inheritDoc} */ - @Override protected <K, V> Model<Vector, Double> updateModel( - Model<Vector, Double> mdl, - DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { - return fit(datasetBuilder, featureExtractor, lbExtractor); - } - - /** {@inheritDoc} */ - @Override public CountTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { - return (CountTrainer)super.withEnvironmentBuilder(envBuilder); - } - } - - /** Data for count trainer. */ - protected static class CountData implements AutoCloseable { - /** Counter. */ - private long cnt; - - /** - * Construct instance of this class. - * - * @param cnt Counter. - */ - public CountData(long cnt) { - this.cnt = cnt; - } - - /** {@inheritDoc} */ - @Override public void close() throws Exception { - // No-op - } - } -}