Repository: ignite
Updated Branches:
  refs/heads/master 9b674ed9a -> 142648df5


http://git-wip-us.apache.org/repos/asf/ignite/blob/142648df/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java
new file mode 100644
index 0000000..405c70b
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition;
+
+import java.util.Arrays;
+import java.util.Map;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
+import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironment;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteTriFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
+import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+import org.apache.ignite.ml.trainers.TrainerTransformers;
+import org.junit.Test;
+
+/**
+ * Tests for bagging algorithm.
+ */
+public class BaggingTest extends TrainerTest {
+    /**
+     * Test that count of entries in context is equal to initial dataset size * subsampleRatio.
+     */
+    @Test
+    public void testBaggingContextCount() {
+        count((ctxCount, countData, integer) -> ctxCount);
+    }
+
+    /**
+     * Test that count of entries in data is equal to initial dataset size * subsampleRatio.
+     */
+    @Test
+    public void testBaggingDataCount() {
+        count((ctxCount, countData, integer) -> countData.cnt);
+    }
+
+    /**
+     * Test that bagged log regression makes correct predictions.
+     */
+    @Test
+    public void testNaiveBaggingLogRegression() {
+        Map<Integer, Double[]> cacheMock = getCacheMock(twoLinearlySeparableClasses);
+
+        DatasetTrainer<LogisticRegressionModel, Double> trainer =
+            new LogisticRegressionSGDTrainer()
+                .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
+                    SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
+                .withMaxIterations(30000)
+                .withLocIterations(100)
+                .withBatchSize(10)
+                .withSeed(123L);
+
+        trainer.withEnvironmentBuilder(TestUtils.testEnvBuilder());
+
+        DatasetTrainer<ModelsComposition, Double> baggedTrainer =
+            TrainerTransformers.makeBagged(
+                trainer,
+                10,
+                0.7,
+                2,
+                2,
+                new OnMajorityPredictionsAggregator());
+
+        ModelsComposition mdl = baggedTrainer.fit(
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), PRECISION);
+        TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION);
+    }
+
+    /**
+     * Method used to test counts of data passed in context and in data builders.
+     *
+     * @param cntr Function specifying which data we should count.
+     */
+    protected void count(IgniteTriFunction<Long, CountData, LearningEnvironment, Long> cntr) {
+        Map<Integer, Double[]> cacheMock = getCacheMock(twoLinearlySeparableClasses);
+
+        CountTrainer cntTrainer = new CountTrainer(cntr);
+
+        double subsampleRatio = 0.3;
+
+        ModelsComposition mdl = TrainerTransformers.makeBagged(
+            cntTrainer,
+            100,
+            subsampleRatio,
+            2,
+            2,
+            new MeanValuePredictionsAggregator())
+            .fit(cacheMock, parts, null, null);
+
+        Double res = mdl.apply(null);
+
+        TestUtils.assertEquals(twoLinearlySeparableClasses.length * subsampleRatio, res, twoLinearlySeparableClasses.length / 10);
+    }
+
+    /**
+     * Get sum of two Long values each of which can be null.
+     *
+     * @param a First value.
+     * @param b Second value.
+     * @return Sum of parameters.
+     */
+    protected static Long plusOfNullables(Long a, Long b) {
+        if (a == null)
+            return b;
+
+        if (b == null)
+            return a;
+
+        return a + b;
+    }
+
+    /**
+     * Trainer used to count entries in context or in data.
+     */
+    protected static class CountTrainer extends DatasetTrainer<Model<Vector, Double>, Double> {
+        /**
+         * Function specifying which entries to count.
+         */
+        private final IgniteTriFunction<Long, CountData, LearningEnvironment, Long> cntr;
+
+        /**
+         * Construct instance of this class.
+         *
+         * @param cntr Function specifying which entries to count.
+         */
+        public CountTrainer(IgniteTriFunction<Long, CountData, LearningEnvironment, Long> cntr) {
+            this.cntr = cntr;
+        }
+
+        /** {@inheritDoc} */
+        @Override public <K, V> Model<Vector, Double> fit(
+            DatasetBuilder<K, V> datasetBuilder,
+            IgniteBiFunction<K, V, Vector> featureExtractor,
+            IgniteBiFunction<K, V, Double> lbExtractor) {
+            Dataset<Long, CountData> dataset = datasetBuilder.build(
+                TestUtils.testEnvBuilder(),
+                (env, upstreamData, upstreamDataSize) -> upstreamDataSize,
+                (env, upstreamData, upstreamDataSize, ctx) -> new CountData(upstreamDataSize)
+            );
+
+            Long cnt = dataset.computeWithCtx(cntr, BaggingTest::plusOfNullables);
+
+            return x -> Double.valueOf(cnt);
+        }
+
+        /** {@inheritDoc} */
+        @Override protected boolean checkState(Model<Vector, Double> mdl) {
+            return true;
+        }
+
+        /** {@inheritDoc} */
+        @Override protected <K, V> Model<Vector, Double> updateModel(
+            Model<Vector, Double> mdl,
+            DatasetBuilder<K, V> datasetBuilder,
+            IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+            return fit(datasetBuilder, featureExtractor, lbExtractor);
+        }
+
+        /** {@inheritDoc} */
+        @Override public CountTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+            return (CountTrainer)super.withEnvironmentBuilder(envBuilder);
+        }
+    }
+
+    /** Data for count trainer. */
+    protected static class CountData implements AutoCloseable {
+        /** Counter. */
+        private long cnt;
+
+        /**
+         * Construct instance of this class.
+         *
+         * @param cnt Counter.
+         */
+        public CountData(long cnt) {
+            this.cnt = cnt;
+        }
+
+        /** {@inheritDoc} */
+        @Override public void close() throws Exception {
+            // No-op
+        }
+    }
+}
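
Note: condensed outside the test harness, the usage pattern exercised by testNaiveBaggingLogRegression above amounts to wrapping a base trainer with TrainerTransformers.makeBagged and fitting the result like any other DatasetTrainer. The sketch below is hypothetical: the class name is invented, the SGD update strategy configured in the test is left at its defaults, and the meaning of the two trailing integer arguments of makeBagged (copied verbatim from the test) is assumed to be the feature vector size and the per-model feature subspace dimension.

import java.util.Arrays;
import java.util.Map;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.trainers.TrainerTransformers;

/** Hypothetical condensed sketch of the bagging API exercised by BaggingTest. */
public class BaggingUsageSketch {
    /**
     * @param data Rows of the form [label, feature0, feature1] keyed by index.
     * @param parts Number of dataset partitions.
     * @return Bagged composition of logistic regression models.
     */
    public static ModelsComposition trainBagged(Map<Integer, Double[]> data, int parts) {
        // Base learner; the update strategy used in the test is omitted for brevity.
        DatasetTrainer<LogisticRegressionModel, Double> base = new LogisticRegressionSGDTrainer()
            .withMaxIterations(30000)
            .withSeed(123L);

        // Ensemble of 10 models, each trained on roughly 70% of the data; the two
        // trailing "2" arguments are copied from the test (assumed: feature vector
        // size and feature subspace dimension).
        DatasetTrainer<ModelsComposition, Double> bagged = TrainerTransformers.makeBagged(
            base, 10, 0.7, 2, 2, new OnMajorityPredictionsAggregator());

        return bagged.fit(
            data,
            parts,
            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), // features
            (k, v) -> v[0]);                                              // label
    }
}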

http://git-wip-us.apache.org/repos/asf/ignite/blob/142648df/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java
index 8714eb2..87d56cd 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java
@@ -25,13 +25,15 @@ import org.junit.runner.RunWith;
 import org.junit.runners.Suite;
 
 /**
- * Test suite for all tests located in org.apache.ignite.ml.composition package.
+ * Test suite for all ensemble model tests.
  */
 @RunWith(Suite.class)
 @Suite.SuiteClasses({
     GDBTrainerTest.class,
     MeanValuePredictionsAggregatorTest.class,
     OnMajorityPredictionsAggregatorTest.class,
+    BaggingTest.class,
+    StackingTest.class,
     WeightedPredictionsAggregatorTest.class
 })
 public class CompositionTestSuite {

http://git-wip-us.apache.org/repos/asf/ignite/blob/142648df/modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java
new file mode 100644
index 0000000..3336470
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition;
+
+import java.util.Arrays;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.composition.stacking.StackedDatasetTrainer;
+import org.apache.ignite.ml.composition.stacking.StackedModel;
+import org.apache.ignite.ml.composition.stacking.StackedVectorDatasetTrainer;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.matrix.Matrix;
+import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.nn.Activators;
+import org.apache.ignite.ml.nn.MLPTrainer;
+import org.apache.ignite.ml.nn.MultilayerPerceptron;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
+import org.apache.ignite.ml.optimization.LossFunctions;
+import org.apache.ignite.ml.optimization.SmoothParametrized;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
+import org.apache.ignite.ml.trainers.AdaptableDatasetModel;
+import org.apache.ignite.ml.trainers.AdaptableDatasetTrainer;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import static junit.framework.TestCase.assertEquals;
+
+/**
+ * Tests stacked trainers.
+ */
+public class StackingTest extends TrainerTest {
+    /** Rule to check exceptions. */
+    @Rule
+    public ExpectedException thrown = ExpectedException.none();
+
+    /**
+     * Tests simple stack training.
+     */
+    @Test
+    public void testSimpleStack() {
+        StackedDatasetTrainer<Vector, Vector, Double, LinearRegressionModel, Double> trainer =
+            new StackedDatasetTrainer<>();
+
+        UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>(
+            new SimpleGDUpdateCalculator(0.2),
+            SimpleGDParameterUpdate::sumLocal,
+            SimpleGDParameterUpdate::avg
+        );
+
+        MLPArchitecture arch = new MLPArchitecture(2).
+            withAddedLayer(10, true, Activators.RELU).
+            withAddedLayer(1, false, Activators.SIGMOID);
+
+        MLPTrainer<SimpleGDParameterUpdate> trainer1 = new MLPTrainer<>(
+            arch,
+            LossFunctions.MSE,
+            updatesStgy,
+            3000,
+            10,
+            50,
+            123L
+        );
+
+        // Convert model trainer to produce Vector -> Vector model
+        DatasetTrainer<AdaptableDatasetModel<Vector, Vector, Matrix, Matrix, MultilayerPerceptron>, Double> mlpTrainer =
+            AdaptableDatasetTrainer.of(trainer1)
+                .beforeTrainedModel((Vector v) -> new DenseMatrix(v.asArray(), 1))
+                .afterTrainedModel((Matrix mtx) -> mtx.getRow(0))
+                .withConvertedLabels(VectorUtils::num2Arr);
+
+        final double factor = 3;
+
+        StackedModel<Vector, Vector, Double, LinearRegressionModel> mdl = trainer
+            .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * factor))
+            .addTrainer(mlpTrainer)
+            .withAggregatorInputMerger(VectorUtils::concat)
+            .withSubmodelOutput2VectorConverter(IgniteFunction.identity())
+            .withVector2SubmodelInputConverter(IgniteFunction.identity())
+            .withOriginalFeaturesKept(IgniteFunction.identity())
+            .withEnvironmentBuilder(TestUtils.testEnvBuilder())
+            .fit(getCacheMock(xor),
+                parts,
+                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+                (k, v) -> v[v.length - 1]);
+
+        assertEquals(0.0 * factor, mdl.apply(VectorUtils.of(0.0, 0.0)), 0.3);
+        assertEquals(1.0 * factor, mdl.apply(VectorUtils.of(0.0, 1.0)), 0.3);
+        assertEquals(1.0 * factor, mdl.apply(VectorUtils.of(1.0, 0.0)), 0.3);
+        assertEquals(0.0 * factor, mdl.apply(VectorUtils.of(1.0, 1.0)), 0.3);
+    }
+
+    /**
+     * Tests simple stack training with the vector-based stacked trainer.
+     */
+    @Test
+    public void testSimpleVectorStack() {
+        StackedVectorDatasetTrainer<Double, LinearRegressionModel, Double> trainer =
+            new StackedVectorDatasetTrainer<>();
+
+        UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>(
+            new SimpleGDUpdateCalculator(0.2),
+            SimpleGDParameterUpdate::sumLocal,
+            SimpleGDParameterUpdate::avg
+        );
+
+        MLPArchitecture arch = new MLPArchitecture(2).
+            withAddedLayer(10, true, Activators.RELU).
+            withAddedLayer(1, false, Activators.SIGMOID);
+
+        DatasetTrainer<MultilayerPerceptron, Double> mlpTrainer = new MLPTrainer<>(
+            arch,
+            LossFunctions.MSE,
+            updatesStgy,
+            3000,
+            10,
+            50,
+            123L
+        ).withConvertedLabels(VectorUtils::num2Arr);
+
+        final double factor = 3;
+
+        StackedModel<Vector, Vector, Double, LinearRegressionModel> mdl = trainer
+            .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * factor))
+            .addMatrix2MatrixTrainer(mlpTrainer)
+            .withEnvironmentBuilder(TestUtils.testEnvBuilder())
+            .fit(getCacheMock(xor),
+                parts,
+                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+                (k, v) -> v[v.length - 1]);
+
+        assertEquals(0.0 * factor, mdl.apply(VectorUtils.of(0.0, 0.0)), 0.3);
+        assertEquals(1.0 * factor, mdl.apply(VectorUtils.of(0.0, 1.0)), 0.3);
+        assertEquals(1.0 * factor, mdl.apply(VectorUtils.of(1.0, 0.0)), 0.3);
+        assertEquals(0.0 * factor, mdl.apply(VectorUtils.of(1.0, 1.0)), 0.3);
+    }
+
+    /**
+     * Tests that an exception is thrown when there is no way for the input of the
+     * first layer to propagate to the second layer.
+     */
+    @Test
+    public void testINoWaysOfPropagation() {
+        StackedDatasetTrainer<Void, Void, Void, Model<Void, Void>, Void> trainer =
+            new StackedDatasetTrainer<>();
+        thrown.expect(IllegalStateException.class);
+        trainer.fit(null, null, null);
+    }
+}
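
Note: both stacking tests above are trained and checked against the same target, the XOR of the two inputs multiplied by factor = 3 (the label conversion x -> x * factor applied to the aggregator trainer), with a 0.3 tolerance. The standalone sketch below, with a hypothetical class name, simply tabulates those expected values.

/** Hypothetical sketch: expected outputs checked by the stacking tests above. */
public class XorTargetSketch {
    public static void main(String[] args) {
        double factor = 3;
        double[][] inputs = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};

        for (double[] in : inputs) {
            // XOR of the two binary inputs, scaled the same way the aggregator
            // trainer's labels are converted in the tests (x -> x * factor).
            double expected = ((int)in[0] ^ (int)in[1]) * factor;

            System.out.printf("f(%.0f, %.0f) ~= %.1f (tolerance 0.3)%n", in[0], in[1], expected);
        }
    }
}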

http://git-wip-us.apache.org/repos/asf/ignite/blob/142648df/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java
index 61f9fc4..74841a3 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java
@@ -47,7 +47,7 @@ public class OneVsRestTrainerTest extends TrainerTest {
         for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
             cacheMock.put(i, twoLinearlySeparableClasses[i]);
 
-        LogisticRegressionSGDTrainer<?> binaryTrainer = new LogisticRegressionSGDTrainer<>()
+        LogisticRegressionSGDTrainer binaryTrainer = new LogisticRegressionSGDTrainer()
             .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
                 SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
             .withMaxIterations(1000)
@@ -80,7 +80,7 @@ public class OneVsRestTrainerTest extends TrainerTest {
         for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
             cacheMock.put(i, twoLinearlySeparableClasses[i]);
 
-        LogisticRegressionSGDTrainer<?> binaryTrainer = new LogisticRegressionSGDTrainer<>()
+        LogisticRegressionSGDTrainer binaryTrainer = new LogisticRegressionSGDTrainer()
             .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
                 SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
             .withMaxIterations(1000)
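
Note: this file, PipelineTest and LogisticRegressionSGDTrainerTest below all switch from the parameterized LogisticRegressionSGDTrainer<?> to the raw LogisticRegressionSGDTrainer type. A minimal before/after sketch of the call-site change (the sketch class is hypothetical, and the commit does not state why the type parameter was removed):

import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;

/** Hypothetical sketch of the trainer instantiation change at the call sites. */
public class TrainerInstantiationSketch {
    public static void main(String[] args) {
        // Before this commit the trainer carried a type parameter:
        //   LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>();
        // After this commit it is instantiated as a plain type:
        LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer();

        System.out.println(trainer.getClass().getSimpleName());
    }
}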

http://git-wip-us.apache.org/repos/asf/ignite/blob/142648df/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java
index bd31b19..5ee50a6 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java
@@ -106,7 +106,7 @@ public class MLPTrainerMnistIntegrationTest extends GridCommonAbstractTest {
             ignite,
             trainingSet,
             (k, v) -> VectorUtils.of(v.getPixels()),
-            (k, v) -> VectorUtils.num2Vec(v.getLabel(), 10).getStorage().data()
+            (k, v) -> VectorUtils.oneHot(v.getLabel(), 10).getStorage().data()
         );
         System.out.println("Training completed in " + (System.currentTimeMillis() - start) + "ms");
 

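Note: this hunk and the matching one in MLPTrainerMnistTest below rename VectorUtils.num2Vec to VectorUtils.oneHot. A minimal sketch of the assumed semantics, with a hypothetical class name: the MNIST label is encoded as a vector of length 10 with 1.0 at the label index and 0.0 elsewhere.

import java.util.Arrays;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;

/** Hypothetical illustration of the renamed one-hot encoding helper. */
public class OneHotExample {
    public static void main(String[] args) {
        // Label 3 over 10 classes is expected to become [0, 0, 0, 1, 0, 0, 0, 0, 0, 0].
        Vector encoded = VectorUtils.oneHot(3, 10);
        double[] raw = encoded.getStorage().data();

        System.out.println(Arrays.toString(raw));
    }
}
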
http://git-wip-us.apache.org/repos/asf/ignite/blob/142648df/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java
index 6a17d18..9396009 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java
@@ -76,7 +76,7 @@ public class MLPTrainerMnistTest {
             trainingSet,
             1,
             (k, v) -> VectorUtils.of(v.getPixels()),
-            (k, v) -> VectorUtils.num2Vec(v.getLabel(), 10).getStorage().data()
+            (k, v) -> VectorUtils.oneHot(v.getLabel(), 10).getStorage().data()
         );
         System.out.println("Training completed in " + (System.currentTimeMillis() - start) + "ms");
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/142648df/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java
index fec6220..694dcd3 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java
@@ -51,7 +51,7 @@ public class PipelineTest extends TrainerTest {
             cacheMock.put(i, convertedRow);
         }
 
-        LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+        LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
             .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
                 SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
             .withMaxIterations(100000)

http://git-wip-us.apache.org/repos/asf/ignite/blob/142648df/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
index c343ab9..681cb72 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
@@ -43,7 +43,7 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest {
         for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
             cacheMock.put(i, twoLinearlySeparableClasses[i]);
 
-        LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+        LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
             .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
                 SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
             .withMaxIterations(100000)
@@ -70,7 +70,7 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest {
         for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
             cacheMock.put(i, twoLinearlySeparableClasses[i]);
 
-        LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+        LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
             .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
                 SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
             .withMaxIterations(100000)

http://git-wip-us.apache.org/repos/asf/ignite/blob/142648df/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java
deleted file mode 100644
index 31fe8b3..0000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java
+++ /dev/null
@@ -1,235 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trainers;
-
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.Map;
-import org.apache.ignite.ml.Model;
-import org.apache.ignite.ml.TestUtils;
-import org.apache.ignite.ml.common.TrainerTest;
-import org.apache.ignite.ml.composition.ModelsComposition;
-import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
-import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
-import org.apache.ignite.ml.dataset.Dataset;
-import org.apache.ignite.ml.dataset.DatasetBuilder;
-import org.apache.ignite.ml.environment.LearningEnvironment;
-import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.math.functions.IgniteTriFunction;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.nn.UpdatesStrategy;
-import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
-import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
-import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
-import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
-import org.junit.Test;
-
-/**
- * Tests for bagging algorithm.
- */
-public class BaggingTest extends TrainerTest {
-    /**
-     * Test that count of entries in context is equal to initial dataset size * subsampleRatio.
-     */
-    @Test
-    public void testBaggingContextCount() {
-        count((ctxCount, countData, integer) -> ctxCount);
-    }
-
-    /**
-     * Test that count of entries in data is equal to initial dataset size * subsampleRatio.
-     */
-    @Test
-    public void testBaggingDataCount() {
-        count((ctxCount, countData, integer) -> countData.cnt);
-    }
-
-    /**
-     * Test that bagged log regression makes correct predictions.
-     */
-    @Test
-    public void testNaiveBaggingLogRegression() {
-        Map<Integer, Double[]> cacheMock = getCacheMock();
-
-        DatasetTrainer<LogisticRegressionModel, Double> trainer =
-            (LogisticRegressionSGDTrainer<?>)new LogisticRegressionSGDTrainer<>()
-                .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
-                    SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
-                .withMaxIterations(30000)
-                .withLocIterations(100)
-                .withBatchSize(10)
-                .withSeed(123L);
-
-        trainer.withEnvironmentBuilder(TestUtils.testEnvBuilder());
-
-        DatasetTrainer<ModelsComposition, Double> baggedTrainer =
-            TrainerTransformers.makeBagged(
-                trainer,
-                10,
-                0.7,
-                2,
-                2,
-                new OnMajorityPredictionsAggregator());
-
-        ModelsComposition mdl = baggedTrainer.fit(
-            cacheMock,
-            parts,
-            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
-            (k, v) -> v[0]
-        );
-
-        TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), PRECISION);
-        TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION);
-    }
-
-    /**
-     * Method used to test counts of data passed in context and in data builders.
-     *
-     * @param counter Function specifying which data we should count.
-     */
-    protected void count(IgniteTriFunction<Long, CountData, LearningEnvironment, Long> counter) {
-        Map<Integer, Double[]> cacheMock = getCacheMock();
-
-        CountTrainer countTrainer = new CountTrainer(counter);
-
-        double subsampleRatio = 0.3;
-
-        ModelsComposition model = TrainerTransformers.makeBagged(
-            countTrainer,
-            100,
-            subsampleRatio,
-            2,
-            2,
-            new MeanValuePredictionsAggregator())
-            .fit(cacheMock, parts, null, null);
-
-        Double res = model.apply(null);
-
-        TestUtils.assertEquals(twoLinearlySeparableClasses.length * subsampleRatio, res, twoLinearlySeparableClasses.length / 10);
-    }
-
-    /**
-     * Create cache mock.
-     *
-     * @return Cache mock.
-     */
-    private Map<Integer, Double[]> getCacheMock() {
-        Map<Integer, Double[]> cacheMock = new HashMap<>();
-
-        for (int i = 0; i < twoLinearlySeparableClasses.length; i++) {
-            double[] row = twoLinearlySeparableClasses[i];
-            Double[] convertedRow = new Double[row.length];
-            for (int j = 0; j < row.length; j++)
-                convertedRow[j] = row[j];
-            cacheMock.put(i, convertedRow);
-        }
-        return cacheMock;
-    }
-
-    /**
-     * Get sum of two Long values each of which can be null.
-     *
-     * @param a First value.
-     * @param b Second value.
-     * @return Sum of parameters.
-     */
-    protected static Long plusOfNullables(Long a, Long b) {
-        if (a == null)
-            return b;
-
-        if (b == null)
-            return a;
-
-        return a + b;
-    }
-
-    /**
-     * Trainer used to count entries in context or in data.
-     */
-    protected static class CountTrainer extends DatasetTrainer<Model<Vector, Double>, Double> {
-        /**
-         * Function specifying which entries to count.
-         */
-        private final IgniteTriFunction<Long, CountData, LearningEnvironment, Long> counter;
-
-        /**
-         * Construct instance of this class.
-         *
-         * @param counter Function specifying which entries to count.
-         */
-        public CountTrainer(IgniteTriFunction<Long, CountData, LearningEnvironment, Long> counter) {
-            this.counter = counter;
-        }
-
-        /** {@inheritDoc} */
-        @Override public <K, V> Model<Vector, Double> fit(
-            DatasetBuilder<K, V> datasetBuilder,
-            IgniteBiFunction<K, V, Vector> featureExtractor,
-            IgniteBiFunction<K, V, Double> lbExtractor) {
-            Dataset<Long, CountData> dataset = datasetBuilder.build(
-                TestUtils.testEnvBuilder(),
-                (env, upstreamData, upstreamDataSize) -> upstreamDataSize,
-                (env, upstreamData, upstreamDataSize, ctx) -> new CountData(upstreamDataSize)
-            );
-
-            Long cnt = dataset.computeWithCtx(counter, BaggingTest::plusOfNullables);
-
-            return x -> Double.valueOf(cnt);
-        }
-
-        /** {@inheritDoc} */
-        @Override protected boolean checkState(Model<Vector, Double> mdl) {
-            return true;
-        }
-
-        /** {@inheritDoc} */
-        @Override protected <K, V> Model<Vector, Double> updateModel(
-            Model<Vector, Double> mdl,
-            DatasetBuilder<K, V> datasetBuilder,
-            IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
-            return fit(datasetBuilder, featureExtractor, lbExtractor);
-        }
-
-        /** {@inheritDoc} */
-        @Override public CountTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
-            return (CountTrainer)super.withEnvironmentBuilder(envBuilder);
-        }
-    }
-
-    /** Data for count trainer. */
-    protected static class CountData implements AutoCloseable {
-        /** Counter. */
-        private long cnt;
-
-        /**
-         * Construct instance of this class.
-         *
-         * @param cnt Counter.
-         */
-        public CountData(long cnt) {
-            this.cnt = cnt;
-        }
-
-        /** {@inheritDoc} */
-        @Override public void close() throws Exception {
-            // No-op
-        }
-    }
-}
