http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java index b720695..b743a37 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java @@ -20,6 +20,7 @@ package org.apache.ignite.ml.math.isolve.lsqr; import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; @@ -47,6 +48,7 @@ public class LSQROnHeapTest extends TrainerTest { LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>( datasetBuilder, + TestUtils.testEnvBuilder(), new SimpleLabeledDatasetDataBuilder<>( (k, v) -> VectorUtils.of(Arrays.copyOf(v, v.length - 1)), (k, v) -> new double[]{v[3]} @@ -80,6 +82,7 @@ public class LSQROnHeapTest extends TrainerTest { LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>( datasetBuilder, + TestUtils.testEnvBuilder(), new SimpleLabeledDatasetDataBuilder<>( (k, v) -> VectorUtils.of(Arrays.copyOf(v, v.length - 1)), (k, v) -> new double[]{v[3]} @@ -113,6 +116,7 @@ public class LSQROnHeapTest extends TrainerTest { try (LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>( datasetBuilder, + TestUtils.testEnvBuilder(), new SimpleLabeledDatasetDataBuilder<>( (k, v) -> VectorUtils.of(Arrays.copyOf(v, v.length - 1)), (k, v) -> new double[]{v[4]}
http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java index 4b7fa33..b611104 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.preprocessing.binarization; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; @@ -51,6 +52,7 @@ public class BinarizationTrainerTest extends TrainerTest { assertEquals(10., binarizationTrainer.getThreshold(), 0); BinarizationPreprocessor<Integer, double[]> preprocessor = binarizationTrainer.fit( + TestUtils.testEnvBuilder(), datasetBuilder, (k, v) -> VectorUtils.of(v) ); @@ -75,6 +77,7 @@ public class BinarizationTrainerTest extends TrainerTest { assertEquals(10., binarizationTrainer.getThreshold(), 0); IgniteBiFunction<Integer, double[], Vector> preprocessor = binarizationTrainer.fit( + TestUtils.testEnvBuilder(), data, parts, (k, v) -> VectorUtils.of(v) http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java index 7c7eabe..f9d56a9 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.preprocessing.encoding; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; @@ -51,6 +52,7 @@ public class EncoderTrainerTest extends TrainerTest { .withEncodedFeature(1); EncoderPreprocessor<Integer, String[]> preprocessor = strEncoderTrainer.fit( + TestUtils.testEnvBuilder(), datasetBuilder, (k, v) -> v ); @@ -77,6 +79,7 @@ public class EncoderTrainerTest extends TrainerTest { .withEncodedFeature(1); EncoderPreprocessor<Integer, Object[]> preprocessor = strEncoderTrainer.fit( + TestUtils.testEnvBuilder(), datasetBuilder, (k, v) -> v ); @@ -103,6 +106,7 @@ public class EncoderTrainerTest extends TrainerTest { .withEncodedFeature(1); EncoderPreprocessor<Integer, Object[]> preprocessor = strEncoderTrainer.fit( + TestUtils.testEnvBuilder(), datasetBuilder, (k, v) -> v ); @@ -136,6 +140,7 @@ public class EncoderTrainerTest extends TrainerTest { .withEncodedFeature(1); EncoderPreprocessor<Integer, String[]> preprocessor = strEncoderTrainer.fit( + TestUtils.testEnvBuilder(), datasetBuilder, (k, v) -> v ); http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainerTest.java index 9c11d13..f8a5e78 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainerTest.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.preprocessing.imputing; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; @@ -47,6 +48,7 @@ public class ImputerTrainerTest extends TrainerTest { .withImputingStrategy(ImputingStrategy.MOST_FREQUENT); ImputerPreprocessor<Integer, Vector> preprocessor = imputerTrainer.fit( + TestUtils.testEnvBuilder(), datasetBuilder, (k, v) -> v ); http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainerTest.java index 844468e..fc3433b 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainerTest.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.preprocessing.maxabsscaling; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; @@ -46,6 +47,7 @@ public class MaxAbsScalerTrainerTest extends TrainerTest { MaxAbsScalerTrainer<Integer, Vector> standardizationTrainer = new MaxAbsScalerTrainer<>(); MaxAbsScalerPreprocessor<Integer, Vector> preprocessor = standardizationTrainer.fit( + TestUtils.testEnvBuilder(), datasetBuilder, (k, v) -> v ); http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainerTest.java index 4c0a99f..8716324 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainerTest.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.preprocessing.minmaxscaling; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; @@ -46,6 +47,7 @@ public class MinMaxScalerTrainerTest extends TrainerTest { MinMaxScalerTrainer<Integer, Vector> standardizationTrainer = new MinMaxScalerTrainer<>(); MinMaxScalerPreprocessor<Integer, Vector> preprocessor = standardizationTrainer.fit( + TestUtils.testEnvBuilder(), datasetBuilder, (k, v) -> v ); http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java index 9d39354..d8a8191 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.preprocessing.normalization; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; @@ -50,6 +51,7 @@ public class NormalizationTrainerTest extends TrainerTest { assertEquals(3., normalizationTrainer.p(), 0); NormalizationPreprocessor<Integer, double[]> preprocessor = normalizationTrainer.fit( + TestUtils.testEnvBuilder(), datasetBuilder, (k, v) -> VectorUtils.of(v) ); http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainerTest.java index 6f10b37..839cb20 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainerTest.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.preprocessing.standardscaling; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; @@ -62,6 +63,7 @@ public class StandardScalerTrainerTest extends TrainerTest { double[] expectedMeans = new double[] {0.5, 1.75, 4.5, 0.875}; StandardScalerPreprocessor<Integer, Vector> preprocessor = standardizationTrainer.fit( + TestUtils.testEnvBuilder(), datasetBuilder, (k, v) -> v ); @@ -75,6 +77,7 @@ public class StandardScalerTrainerTest extends TrainerTest { double[] expectedSigmas = new double[] {0.5, 1.47901995, 14.51723114, 0.93374247}; StandardScalerPreprocessor<Integer, Vector> preprocessor = standardizationTrainer.fit( + TestUtils.testEnvBuilder(), datasetBuilder, (k, v) -> v ); http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java index 6f7aa36..1abf7f0 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java @@ -48,6 +48,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; import org.apache.ignite.thread.IgniteThread; +import static org.apache.ignite.ml.TestUtils.testEnvBuilder; import static org.junit.Assert.assertArrayEquals; /** @@ -288,19 +289,24 @@ public class EvaluatorTest extends GridCommonAbstractTest { .withEncoderType(EncoderType.STRING_ENCODER) .withEncodedFeature(1) .withEncodedFeature(6) // <--- Changed index here - .fit(ignite, + .fit( + testEnvBuilder(123L), + ignite, cache, featureExtractor ); IgniteBiFunction<Integer, Object[], Vector> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>() - .fit(ignite, + .fit( + testEnvBuilder(124L), + ignite, cache, strEncoderPreprocessor ); IgniteBiFunction<Integer, Object[], Vector> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>() .fit( + testEnvBuilder(125L), ignite, cache, imputingPreprocessor @@ -309,6 +315,7 @@ public class EvaluatorTest extends GridCommonAbstractTest { return new NormalizationTrainer<Integer, Object[]>() .withP(2) .fit( + testEnvBuilder(126L), ignite, cache, minMaxScalerPreprocessor http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java index a82374b..1b96ce2 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java @@ -28,6 +28,8 @@ import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictio import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.environment.LearningEnvironment; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.functions.IgniteTriFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; @@ -75,11 +77,15 @@ public class BaggingTest extends TrainerTest { .withBatchSize(10) .withSeed(123L); + trainer.withEnvironmentBuilder(TestUtils.testEnvBuilder()); + DatasetTrainer<ModelsComposition, Double> baggedTrainer = TrainerTransformers.makeBagged( trainer, 10, 0.7, + 2, + 2, new OnMajorityPredictionsAggregator()); ModelsComposition mdl = baggedTrainer.fit( @@ -98,14 +104,20 @@ public class BaggingTest extends TrainerTest { * * @param counter Function specifying which data we should count. */ - protected void count(IgniteTriFunction<Long, CountData, Integer, Long> counter) { + protected void count(IgniteTriFunction<Long, CountData, LearningEnvironment, Long> counter) { Map<Integer, Double[]> cacheMock = getCacheMock(); CountTrainer countTrainer = new CountTrainer(counter); double subsampleRatio = 0.3; - ModelsComposition model = TrainerTransformers.makeBagged(countTrainer, 100, subsampleRatio, new MeanValuePredictionsAggregator()) + ModelsComposition model = TrainerTransformers.makeBagged( + countTrainer, + 100, + subsampleRatio, + 2, + 2, + new MeanValuePredictionsAggregator()) .fit(cacheMock, parts, null, null); Double res = model.apply(null); @@ -155,14 +167,14 @@ public class BaggingTest extends TrainerTest { /** * Function specifying which entries to count. */ - private final IgniteTriFunction<Long, CountData, Integer, Long> counter; + private final IgniteTriFunction<Long, CountData, LearningEnvironment, Long> counter; /** * Construct instance of this class. * * @param counter Function specifying which entries to count. */ - public CountTrainer(IgniteTriFunction<Long, CountData, Integer, Long> counter) { + public CountTrainer(IgniteTriFunction<Long, CountData, LearningEnvironment, Long> counter) { this.counter = counter; } @@ -172,8 +184,9 @@ public class BaggingTest extends TrainerTest { IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { Dataset<Long, CountData> dataset = datasetBuilder.build( - (upstreamData, upstreamDataSize) -> upstreamDataSize, - (upstreamData, upstreamDataSize, ctx) -> new CountData(upstreamDataSize) + TestUtils.testEnvBuilder(), + (env, upstreamData, upstreamDataSize) -> upstreamDataSize, + (env, upstreamData, upstreamDataSize, ctx) -> new CountData(upstreamDataSize) ); Long cnt = dataset.computeWithCtx(counter, BaggingTest::plusOfNullables); @@ -193,6 +206,11 @@ public class BaggingTest extends TrainerTest { IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { return fit(datasetBuilder, featureExtractor, lbExtractor); } + + /** {@inheritDoc} */ + @Override public CountTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { + return (CountTrainer)super.withEnvironmentBuilder(envBuilder); + } } /** Data for count trainer. */
