IGNITE-10272: [ML] Inject learning environment into scope of dataset compute task
This closes #5484 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/ff6b8eed Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/ff6b8eed Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/ff6b8eed Branch: refs/heads/ignite-10044 Commit: ff6b8eed8401590792b1d11b91b0039ddd03f958 Parents: c6a05f8 Author: Artem Malykh <[email protected]> Authored: Tue Dec 4 14:57:31 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Tue Dec 4 14:57:31 2018 +0300 ---------------------------------------------------------------------- .../AlgorithmSpecificDatasetExample.java | 2 +- ...ggedLogisticRegressionSGDTrainerExample.java | 3 +- .../RandomForestRegressionExample.java | 9 +- .../ml/clustering/kmeans/KMeansTrainer.java | 9 +- .../boosting/GDBBinaryClassifierTrainer.java | 12 +- .../boosting/GDBLearningStrategy.java | 21 ++- .../boosting/GDBRegressionTrainer.java | 6 + .../ml/composition/boosting/GDBTrainer.java | 20 ++- .../convergence/ConvergenceChecker.java | 8 +- .../simple/ConvergenceCheckerStub.java | 3 +- .../org/apache/ignite/ml/dataset/Dataset.java | 61 ++++--- .../ignite/ml/dataset/DatasetBuilder.java | 17 +- .../ignite/ml/dataset/DatasetFactory.java | 174 ++++++++++++++++-- .../ml/dataset/PartitionContextBuilder.java | 11 +- .../ignite/ml/dataset/PartitionDataBuilder.java | 12 +- .../ignite/ml/dataset/UpstreamTransformer.java | 22 ++- .../ml/dataset/UpstreamTransformerBuilder.java | 72 ++++++++ .../ml/dataset/UpstreamTransformerChain.java | 153 ---------------- .../BootstrappedDatasetBuilder.java | 15 +- .../dataset/impl/cache/CacheBasedDataset.java | 49 +++-- .../impl/cache/CacheBasedDatasetBuilder.java | 40 +++-- .../dataset/impl/cache/util/ComputeUtils.java | 128 ++++++++----- .../ml/dataset/impl/local/LocalDataset.java | 18 +- .../dataset/impl/local/LocalDatasetBuilder.java | 118 ++++++------ .../ml/dataset/primitive/DatasetWrapper.java | 5 +- 
...eatureMatrixWithLabelsOnHeapDataBuilder.java | 7 +- .../builder/context/EmptyContextBuilder.java | 3 +- .../builder/data/SimpleDatasetDataBuilder.java | 5 +- .../data/SimpleLabeledDatasetDataBuilder.java | 5 +- .../DefaultLearningEnvironmentBuilder.java | 178 +++++++++++++++++++ .../ml/environment/LearningEnvironment.java | 19 +- .../environment/LearningEnvironmentBuilder.java | 167 +++++++++-------- .../ml/environment/logging/ConsoleLogger.java | 3 + .../parallelism/ParallelismStrategy.java | 1 - .../java/org/apache/ignite/ml/knn/KNNUtils.java | 9 +- .../ml/knn/ann/ANNClassificationTrainer.java | 10 +- .../KNNClassificationTrainer.java | 8 +- .../ml/knn/regression/KNNRegressionTrainer.java | 9 +- .../ml/math/functions/IgniteFunction.java | 12 +- .../ignite/ml/math/isolve/lsqr/LSQROnHeap.java | 6 +- .../ml/math/primitives/vector/VectorUtils.java | 11 ++ .../ignite/ml/multiclass/OneVsRestTrainer.java | 3 +- .../gaussian/GaussianNaiveBayesTrainer.java | 11 +- .../org/apache/ignite/ml/nn/MLPTrainer.java | 1 + .../org/apache/ignite/ml/pipeline/Pipeline.java | 14 ++ .../ml/preprocessing/PreprocessingTrainer.java | 66 ++++++- .../binarization/BinarizationTrainer.java | 5 +- .../preprocessing/encoding/EncoderTrainer.java | 12 +- .../preprocessing/imputing/ImputerTrainer.java | 10 +- .../maxabsscaling/MaxAbsScalerTrainer.java | 10 +- .../minmaxscaling/MinMaxScalerTrainer.java | 12 +- .../normalization/NormalizationTrainer.java | 5 +- .../standardscaling/StandardScalerTrainer.java | 14 +- .../linear/LinearRegressionLSQRTrainer.java | 1 + .../LogRegressionMultiClassTrainer.java | 3 +- .../LabelPartitionDataBuilderOnHeap.java | 8 +- ...abeledDatasetPartitionDataBuilderOnHeap.java | 7 +- .../SVMLinearBinaryClassificationTrainer.java | 3 +- ...VMLinearMultiClassClassificationTrainer.java | 3 +- .../ignite/ml/trainers/DatasetTrainer.java | 28 ++- .../ignite/ml/trainers/TrainerTransformers.java | 68 +------ .../BaggingUpstreamTransformer.java | 29 ++- 
.../org/apache/ignite/ml/tree/DecisionTree.java | 7 + .../tree/DecisionTreeClassificationTrainer.java | 6 + .../ml/tree/DecisionTreeRegressionTrainer.java | 6 + .../GDBBinaryClassifierOnTreesTrainer.java | 6 + .../boosting/GDBOnTreesLearningStrategy.java | 3 +- .../boosting/GDBRegressionOnTreesTrainer.java | 6 + .../ml/tree/data/DecisionTreeDataBuilder.java | 7 +- .../tree/randomforest/RandomForestTrainer.java | 1 + .../java/org/apache/ignite/ml/TestUtils.java | 79 ++++++++ .../MeanAbsValueConvergenceCheckerTest.java | 10 +- .../MedianOfMedianConvergenceCheckerTest.java | 10 +- .../cache/CacheBasedDatasetBuilderTest.java | 11 +- .../impl/cache/CacheBasedDatasetTest.java | 11 +- .../impl/cache/util/ComputeUtilsTest.java | 16 +- .../impl/local/LocalDatasetBuilderTest.java | 14 +- .../ml/dataset/primitive/SimpleDatasetTest.java | 2 + .../primitive/SimpleLabeledDatasetTest.java | 11 +- .../LearningEnvironmentBuilderTest.java | 36 ++-- .../ml/environment/LearningEnvironmentTest.java | 130 +++++++++++++- .../ml/math/isolve/lsqr/LSQROnHeapTest.java | 4 + .../binarization/BinarizationTrainerTest.java | 3 + .../encoding/EncoderTrainerTest.java | 5 + .../imputing/ImputerTrainerTest.java | 2 + .../maxabsscaling/MaxAbsScalerTrainerTest.java | 2 + .../minmaxscaling/MinMaxScalerTrainerTest.java | 2 + .../normalization/NormalizationTrainerTest.java | 2 + .../StandardScalerTrainerTest.java | 3 + .../scoring/evaluator/EvaluatorTest.java | 11 +- .../apache/ignite/ml/trainers/BaggingTest.java | 30 +++- 91 files changed, 1553 insertions(+), 637 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/examples/src/main/java/org/apache/ignite/examples/ml/dataset/AlgorithmSpecificDatasetExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/dataset/AlgorithmSpecificDatasetExample.java 
b/examples/src/main/java/org/apache/ignite/examples/ml/dataset/AlgorithmSpecificDatasetExample.java index 4d42d19..5148d9a 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/dataset/AlgorithmSpecificDatasetExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/dataset/AlgorithmSpecificDatasetExample.java @@ -73,7 +73,7 @@ public class AlgorithmSpecificDatasetExample { try (AlgorithmSpecificDataset dataset = DatasetFactory.create( ignite, persons, - (upstream, upstreamSize) -> new AlgorithmSpecificPartitionContext(), + (env, upstream, upstreamSize) -> new AlgorithmSpecificPartitionContext(), new SimpleLabeledDatasetDataBuilder<Integer, Person, AlgorithmSpecificPartitionContext>( (k, v) -> VectorUtils.of(v.getAge()), (k, v) -> new double[] {v.getSalary()} http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java index baf513a..44fb77e 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java @@ -81,8 +81,7 @@ public class BaggedLogisticRegressionSGDTrainerExample { 0.6, 4, 3, - new OnMajorityPredictionsAggregator(), - 123L); + new OnMajorityPredictionsAggregator()); System.out.println(">>> Perform evaluation of the model."); 
http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java index 3bf2c8e..a3c33cb 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java @@ -31,7 +31,7 @@ import org.apache.ignite.examples.ml.util.MLSandboxDatasets; import org.apache.ignite.examples.ml.util.SandboxMLCache; import org.apache.ignite.ml.composition.ModelsComposition; import org.apache.ignite.ml.dataset.feature.FeatureMeta; -import org.apache.ignite.ml.environment.LearningEnvironment; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.environment.logging.ConsoleLogger; import org.apache.ignite.ml.environment.logging.MLLogger; import org.apache.ignite.ml.environment.parallelism.ParallelismStrategy; @@ -80,10 +80,9 @@ public class RandomForestRegressionExample { .withSubSampleSize(0.3) .withSeed(0); - trainer.setEnvironment(LearningEnvironment.builder() - .withParallelismStrategy(ParallelismStrategy.Type.ON_DEFAULT_POOL) - .withLoggingFactory(ConsoleLogger.factory(MLLogger.VerboseLevel.LOW)) - .build() + trainer.withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder() + .withParallelismStrategyTypeDependency(part -> ParallelismStrategy.Type.ON_DEFAULT_POOL) + .withLoggingFactoryDependency(part -> ConsoleLogger.factory(MLLogger.VerboseLevel.LOW)) ); System.out.println(">>> Configured trainer: " + trainer.getClass().getSimpleName()); 
http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java index a20d5da..88ea9b9 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java @@ -32,6 +32,7 @@ import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.distances.DistanceMeasure; import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.math.functions.IgniteBiFunction; @@ -78,6 +79,11 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> { } /** {@inheritDoc} */ + @Override public KMeansTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { + return (KMeansTrainer)super.withEnvironmentBuilder(envBuilder); + } + + /** {@inheritDoc} */ @Override protected <K, V> KMeansModel updateModel(KMeansModel mdl, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { @@ -91,7 +97,8 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> { Vector[] centers; try (Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset = datasetBuilder.build( - (upstream, upstreamSize) -> new EmptyContext(), + envBuilder, + (env, upstream, upstreamSize) -> new EmptyContext(), partDataBuilder )) { final Integer cols = 
dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> { http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java index 8682a46..3acca14 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java @@ -25,12 +25,14 @@ import org.apache.ignite.ml.composition.boosting.loss.LogLoss; import org.apache.ignite.ml.composition.boosting.loss.Loss; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.structures.LabeledVector; import org.apache.ignite.ml.structures.LabeledVectorSet; import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap; +import org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer; /** * Trainer for binary classifier using Gradient Boosting. 
As preparing stage this algorithm learn labels in dataset and @@ -69,7 +71,10 @@ public abstract class GDBBinaryClassifierTrainer extends GDBTrainer { IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lExtractor) { - Set<Double> uniqLabels = builder.build(new EmptyContextBuilder<>(), new LabeledDatasetPartitionDataBuilderOnHeap<>(featureExtractor, lExtractor)) + Set<Double> uniqLabels = builder.build( + envBuilder, + new EmptyContextBuilder<>(), + new LabeledDatasetPartitionDataBuilderOnHeap<>(featureExtractor, lExtractor)) .compute((IgniteFunction<LabeledVectorSet<Double, LabeledVector>, Set<Double>>)x -> Arrays.stream(x.labels()).boxed().collect(Collectors.toSet()), (a, b) -> { if (a == null) @@ -102,4 +107,9 @@ public abstract class GDBBinaryClassifierTrainer extends GDBTrainer { double internalCls = sigma < 0.5 ? 0.0 : 1.0; return internalCls == 0.0 ? externalFirstCls : externalSecondCls; } + + /** {@inheritDoc} */ + @Override public GDBBinaryClassifierOnTreesTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { + return (GDBBinaryClassifierOnTreesTrainer)super.withEnvironmentBuilder(envBuilder); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java index e689b91..0b87748 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java @@ -29,6 +29,7 @@ import org.apache.ignite.ml.composition.boosting.loss.Loss; import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator; import 
org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.environment.LearningEnvironment; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.environment.logging.MLLogger; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.functions.IgniteFunction; @@ -41,8 +42,11 @@ import org.jetbrains.annotations.NotNull; * Learning strategy for gradient boosting. */ public class GDBLearningStrategy { - /** Learning environment. */ - protected LearningEnvironment environment; + /** Learning environment builder. */ + protected LearningEnvironmentBuilder envBuilder; + + /** Learning environment used for trainer. */ + protected LearningEnvironment trainerEnvironment; /** Count of iterations. */ protected int cntOfIterations; @@ -101,6 +105,8 @@ public class GDBLearningStrategy { public <K,V> List<Model<Vector, Double>> update(GDBTrainer.GDBModel mdlToUpdate, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + if (trainerEnvironment == null) + throw new IllegalStateException("Learning environment builder is not set."); List<Model<Vector, Double>> models = initLearningState(mdlToUpdate); @@ -113,7 +119,7 @@ public class GDBLearningStrategy { WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLbVal); ModelsComposition currComposition = new ModelsComposition(models, aggregator); - if (convCheck.isConverged(datasetBuilder, currComposition)) + if (convCheck.isConverged(envBuilder, datasetBuilder, currComposition)) break; IgniteBiFunction<K, V, Double> lbExtractorWrap = (k, v) -> { @@ -125,7 +131,7 @@ public class GDBLearningStrategy { long startTs = System.currentTimeMillis(); models.add(trainer.fit(datasetBuilder, featureExtractor, lbExtractorWrap)); double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0; - 
environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime); + trainerEnvironment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime); } return models; @@ -157,10 +163,11 @@ public class GDBLearningStrategy { /** * Sets learning environment. * - * @param environment Learning Environment. + * @param envBuilder Learning Environment. */ - public GDBLearningStrategy withEnvironment(LearningEnvironment environment) { - this.environment = environment; + public GDBLearningStrategy withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { + this.envBuilder = envBuilder; + this.trainerEnvironment = envBuilder.buildForTrainer(); return this; } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBRegressionTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBRegressionTrainer.java index 8c1afd7..3dc95ee 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBRegressionTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBRegressionTrainer.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.composition.boosting; import org.apache.ignite.ml.composition.boosting.loss.SquaredError; import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; @@ -53,4 +54,9 @@ public abstract class GDBRegressionTrainer extends GDBTrainer { @Override protected double internalLabelToExternal(double x) { return x; } + + /** {@inheritDoc} */ + @Override public 
GDBRegressionTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { + return (GDBRegressionTrainer)super.withEnvironmentBuilder(envBuilder); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java index 89cc6b1..03772ec 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java @@ -30,6 +30,7 @@ import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.environment.logging.MLLogger; import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer; import org.apache.ignite.ml.math.functions.IgniteBiFunction; @@ -99,7 +100,11 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl if (!learnLabels(datasetBuilder, featureExtractor, lbExtractor)) return getLastTrainedModelOrThrowEmptyDatasetException(mdl); - IgniteBiTuple<Double, Long> initAndSampleSize = computeInitialValue(datasetBuilder, featureExtractor, lbExtractor); + IgniteBiTuple<Double, Long> initAndSampleSize = computeInitialValue( + envBuilder, + datasetBuilder, + featureExtractor, + lbExtractor); if(initAndSampleSize == null) return getLastTrainedModelOrThrowEmptyDatasetException(mdl); @@ -112,7 +117,7 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl 
.withBaseModelTrainerBuilder(this::buildBaseModelTrainer) .withExternalLabelToInternal(this::externalLabelToInternal) .withCntOfIterations(cntOfIterations) - .withEnvironment(environment) + .withEnvironmentBuilder(envBuilder) .withLossGradient(loss) .withSampleSize(sampleSize) .withMeanLabelValue(mean) @@ -140,6 +145,11 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl return mdl instanceof GDBModel; } + /** {@inheritDoc} */ + @Override public GDBTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { + return (GDBTrainer)super.withEnvironmentBuilder(envBuilder); + } + /** * Defines unique labels in dataset if need (useful in case of classification). * @@ -175,14 +185,18 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl * Compute mean value of label as first approximation. * * @param builder Dataset builder. + * @param envBuilder Learning environment builder. * @param featureExtractor Feature extractor. * @param lbExtractor Label extractor. 
*/ - protected <V, K> IgniteBiTuple<Double, Long> computeInitialValue(DatasetBuilder<K, V> builder, + protected <V, K> IgniteBiTuple<Double, Long> computeInitialValue( + LearningEnvironmentBuilder envBuilder, + DatasetBuilder<K, V> builder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { try (Dataset<EmptyContext, DecisionTreeData> dataset = builder.build( + envBuilder, new EmptyContextBuilder<>(), new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, false) )) { http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceChecker.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceChecker.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceChecker.java index 88841e2..e383e39 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceChecker.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceChecker.java @@ -26,6 +26,7 @@ import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData; import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapDataBuilder; import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; @@ -88,11 +89,16 @@ public abstract class ConvergenceChecker<K, V> implements Serializable { /** * Checks convergency on dataset. * + * @param envBuilder Learning environment builder. 
* @param currMdl Current model. * @return true if GDB is converged. */ - public boolean isConverged(DatasetBuilder<K, V> datasetBuilder, ModelsComposition currMdl) { + public boolean isConverged( + LearningEnvironmentBuilder envBuilder, + DatasetBuilder<K, V> datasetBuilder, + ModelsComposition currMdl) { try (Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build( + envBuilder, new EmptyContextBuilder<>(), new FeatureMatrixWithLabelsOnHeapDataBuilder<>(featureExtractor, lbExtractor) )) { http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStub.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStub.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStub.java index 98cfbe1..193afaf 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStub.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStub.java @@ -24,6 +24,7 @@ import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; @@ -60,7 +61,7 @@ public class ConvergenceCheckerStub<K,V> extends ConvergenceChecker<K,V> { } /** {@inheritDoc} */ - @Override public boolean isConverged(DatasetBuilder<K, V> datasetBuilder, 
ModelsComposition currMdl) { + @Override public boolean isConverged(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, ModelsComposition currMdl) { return false; } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/dataset/Dataset.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/Dataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/Dataset.java index 230a467..d821fe3 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/Dataset.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/Dataset.java @@ -20,6 +20,7 @@ package org.apache.ignite.ml.dataset; import java.io.Serializable; import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDataset; import org.apache.ignite.ml.dataset.impl.local.LocalDataset; +import org.apache.ignite.ml.environment.LearningEnvironment; import org.apache.ignite.ml.math.functions.IgniteBiConsumer; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.functions.IgniteBinaryOperator; @@ -56,49 +57,50 @@ public interface Dataset<C extends Serializable, D extends AutoCloseable> extend * Applies the specified {@code map} function to every partition {@code data}, {@code context} and partition * index in the dataset and then reduces {@code map} results to final result by using the {@code reduce} function. * - * @param map Function applied to every partition {@code data}, {@code context} and partition index. + * @param map Function applied to every partition {@code data}, {@code context} and {@link LearningEnvironment}. * @param reduce Function applied to results of {@code map} to get final result. * @param identity Identity. * @param <R> Type of a result. * @return Final result. 
*/ - public <R> R computeWithCtx(IgniteTriFunction<C, D, Integer, R> map, IgniteBinaryOperator<R> reduce, R identity); + public <R> R computeWithCtx(IgniteTriFunction<C, D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce, R identity); /** - * Applies the specified {@code map} function to every partition {@code data} and partition index in the dataset - * and then reduces {@code map} results to final result by using the {@code reduce} function. + * Applies the specified {@code map} function to every partition {@code data} and {@link LearningEnvironment} + * in the dataset and then reduces {@code map} results to final result by using the {@code reduce} function. * - * @param map Function applied to every partition {@code data} and partition index. + * @param map Function applied to every partition {@code data} and {@link LearningEnvironment}. * @param reduce Function applied to results of {@code map} to get final result. * @param identity Identity. * @param <R> Type of a result. * @return Final result. */ - public <R> R compute(IgniteBiFunction<D, Integer, R> map, IgniteBinaryOperator<R> reduce, R identity); + public <R> R compute(IgniteBiFunction<D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce, R identity); /** - * Applies the specified {@code map} function to every partition {@code data}, {@code context} and partition - * index in the dataset and then reduces {@code map} results to final result by using the {@code reduce} function. + * Applies the specified {@code map} function to every partition {@code data}, {@code context} and + * {@link LearningEnvironment} in the dataset and then reduces {@code map} results to final + * result by using the {@code reduce} function. * - * @param map Function applied to every partition {@code data}, {@code context} and partition index. + * @param map Function applied to every partition {@code data}, {@code context} and {@link LearningEnvironment}. 
* @param reduce Function applied to results of {@code map} to get final result. * @param <R> Type of a result. * @return Final result. */ - public default <R> R computeWithCtx(IgniteTriFunction<C, D, Integer, R> map, IgniteBinaryOperator<R> reduce) { + public default <R> R computeWithCtx(IgniteTriFunction<C, D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce) { return computeWithCtx(map, reduce, null); } /** - * Applies the specified {@code map} function to every partition {@code data} and partition index in the dataset - * and then reduces {@code map} results to final result by using the {@code reduce} function. + * Applies the specified {@code map} function to every partition {@code data} and {@link LearningEnvironment} + * in the dataset and then reduces {@code map} results to final result by using the {@code reduce} function. * - * @param map Function applied to every partition {@code data} and partition index. + * @param map Function applied to every partition {@code data} and {@link LearningEnvironment}. * @param reduce Function applied to results of {@code map} to get final result. * @param <R> Type of a result. * @return Final result. */ - public default <R> R compute(IgniteBiFunction<D, Integer, R> map, IgniteBinaryOperator<R> reduce) { + public default <R> R compute(IgniteBiFunction<D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce) { return compute(map, reduce, null); } @@ -113,7 +115,7 @@ public interface Dataset<C extends Serializable, D extends AutoCloseable> extend * @return Final result. */ public default <R> R computeWithCtx(IgniteBiFunction<C, D, R> map, IgniteBinaryOperator<R> reduce, R identity) { - return computeWithCtx((ctx, data, partIdx) -> map.apply(ctx, data), reduce, identity); + return computeWithCtx((ctx, data, env) -> map.apply(ctx, data), reduce, identity); } /** @@ -127,7 +129,7 @@ public interface Dataset<C extends Serializable, D extends AutoCloseable> extend * @return Final result. 
*/ public default <R> R compute(IgniteFunction<D, R> map, IgniteBinaryOperator<R> reduce, R identity) { - return compute((data, partIdx) -> map.apply(data), reduce, identity); + return compute((data, env) -> map.apply(data), reduce, identity); } /** @@ -140,7 +142,7 @@ public interface Dataset<C extends Serializable, D extends AutoCloseable> extend * @return Final result. */ public default <R> R computeWithCtx(IgniteBiFunction<C, D, R> map, IgniteBinaryOperator<R> reduce) { - return computeWithCtx((ctx, data, partIdx) -> map.apply(ctx, data), reduce); + return computeWithCtx((ctx, data, env) -> map.apply(ctx, data), reduce); } /** @@ -153,30 +155,31 @@ public interface Dataset<C extends Serializable, D extends AutoCloseable> extend * @return Final result. */ public default <R> R compute(IgniteFunction<D, R> map, IgniteBinaryOperator<R> reduce) { - return compute((data, partIdx) -> map.apply(data), reduce); + return compute((data, env) -> map.apply(data), reduce); } /** - * Applies the specified {@code map} function to every partition {@code data}, {@code context} and partition - * index in the dataset. + * Applies the specified {@code map} function to every partition {@code data}, {@code context} and + * {@link LearningEnvironment} in the dataset. * * @param map Function applied to every partition {@code data}, {@code context} and partition index. */ - public default void computeWithCtx(IgniteTriConsumer<C, D, Integer> map) { - computeWithCtx((ctx, data, partIdx) -> { - map.accept(ctx, data, partIdx); + public default void computeWithCtx(IgniteTriConsumer<C, D, LearningEnvironment> map) { + computeWithCtx((ctx, data, env) -> { + map.accept(ctx, data, env); return null; }, (a, b) -> null); } /** - * Applies the specified {@code map} function to every partition {@code data} in the dataset and partition index. + * Applies the specified {@code map} function to every partition {@code data} in the dataset and + * {@link LearningEnvironment}. 
* * @param map Function applied to every partition {@code data} and partition index. */ - public default void compute(IgniteBiConsumer<D, Integer> map) { - compute((data, partIdx) -> { - map.accept(data, partIdx); + public default void compute(IgniteBiConsumer<D, LearningEnvironment> map) { + compute((data, env) -> { + map.accept(data, env); return null; }, (a, b) -> null); } @@ -187,7 +190,7 @@ public interface Dataset<C extends Serializable, D extends AutoCloseable> extend * @param map Function applied to every partition {@code data} and {@code context}. */ public default void computeWithCtx(IgniteBiConsumer<C, D> map) { - computeWithCtx((ctx, data, partIdx) -> map.accept(ctx, data)); + computeWithCtx((ctx, data, env) -> map.accept(ctx, data)); } /** @@ -196,7 +199,7 @@ public interface Dataset<C extends Serializable, D extends AutoCloseable> extend * @param map Function applied to every partition {@code data}. */ public default void compute(IgniteConsumer<D> map) { - compute((data, partIdx) -> map.accept(data)); + compute((data, env) -> map.accept(data)); } /** http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java index 4dd0a96..9900659 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java @@ -21,6 +21,7 @@ import java.io.Serializable; import org.apache.ignite.lang.IgniteBiPredicate; import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import 
org.apache.ignite.ml.trainers.transformers.BaggingUpstreamTransformer; /** @@ -40,6 +41,7 @@ public interface DatasetBuilder<K, V> { * Constructs a new instance of {@link Dataset} that includes allocation required data structures and * initialization of {@code context} part of partitions. * + * @param envBuilder Learning environment builder. * @param partCtxBuilder Partition {@code context} builder. * @param partDataBuilder Partition {@code data} builder. * @param <C> Type of a partition {@code context}. @@ -47,18 +49,25 @@ public interface DatasetBuilder<K, V> { * @return Dataset. */ public <C extends Serializable, D extends AutoCloseable> Dataset<C, D> build( - PartitionContextBuilder<K, V, C> partCtxBuilder, PartitionDataBuilder<K, V, C, D> partDataBuilder); + LearningEnvironmentBuilder envBuilder, + PartitionContextBuilder<K, V, C> partCtxBuilder, + PartitionDataBuilder<K, V, C, D> partDataBuilder); /** - * Get upstream transformers chain. This chain is applied to upstream data before it is passed + * Returns new instance of {@link DatasetBuilder} with new {@link UpstreamTransformerBuilder} added + * to chain of upstream transformer builders. When needed, each builder in chain first transformed into + * {@link UpstreamTransformer}, those are in turn composed together one after another forming + * final {@link UpstreamTransformer}. + * This transformer is applied to upstream data before it is passed * to {@link PartitionDataBuilder} and {@link PartitionContextBuilder}. This is needed to allow * transformation to upstream data which are agnostic of any changes that happen after. * Such transformations may be used for deriving meta-algorithms such as bagging * (see {@link BaggingUpstreamTransformer}). * - * @return Upstream transformers chain. + * @return Returns new instance of {@link DatasetBuilder} with new {@link UpstreamTransformerBuilder} added + * to chain of upstream transformer builders. 
*/ - public UpstreamTransformerChain<K, V> upstreamTransformersChain(); + public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder<K, V> builder); /** * Returns new instance of DatasetBuilder using conjunction of internal filter and {@code filterToAdd}. http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetFactory.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetFactory.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetFactory.java index 1623a2b..ef8eb23 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetFactory.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetFactory.java @@ -31,6 +31,7 @@ import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetD import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; import org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData; import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; @@ -76,6 +77,7 @@ public class DatasetFactory { * {@code partDataBuilder}. This is the generic methods that allows to create any Ignite Cache based datasets with * any desired partition {@code context} and {@code data}. * + * @param envBuilder Learning environment builder. * @param datasetBuilder Dataset builder. * @param partCtxBuilder Partition {@code context} builder. * @param partDataBuilder Partition {@code data} builder. @@ -86,13 +88,42 @@ public class DatasetFactory { * @return Dataset. 
*/ public static <K, V, C extends Serializable, D extends AutoCloseable> Dataset<C, D> create( - DatasetBuilder<K, V> datasetBuilder, PartitionContextBuilder<K, V, C> partCtxBuilder, + DatasetBuilder<K, V> datasetBuilder, + LearningEnvironmentBuilder envBuilder, + PartitionContextBuilder<K, V, C> partCtxBuilder, PartitionDataBuilder<K, V, C, D> partDataBuilder) { return datasetBuilder.build( + envBuilder, partCtxBuilder, partDataBuilder ); } + + /** + * Creates a new instance of distributed dataset using the specified {@code partCtxBuilder} and + * {@code partDataBuilder}. This is the generic methods that allows to create any Ignite Cache based datasets with + * any desired partition {@code context} and {@code data}. + * + * @param datasetBuilder Dataset builder. + * @param partCtxBuilder Partition {@code context} builder. + * @param partDataBuilder Partition {@code data} builder. + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + * @param <C> Type of a partition {@code context}. + * @param <D> Type of a partition {@code data}. + * @return Dataset. + */ + public static <K, V, C extends Serializable, D extends AutoCloseable> Dataset<C, D> create( + DatasetBuilder<K, V> datasetBuilder, + PartitionContextBuilder<K, V, C> partCtxBuilder, + PartitionDataBuilder<K, V, C, D> partDataBuilder) { + return datasetBuilder.build( + LearningEnvironmentBuilder.defaultBuilder(), + partCtxBuilder, + partDataBuilder + ); + } + /** * Creates a new instance of distributed dataset using the specified {@code partCtxBuilder} and * {@code partDataBuilder}. This is the generic methods that allows to create any Ignite Cache based datasets with * any desired partition {@code context} and {@code data}. * * @param ignite Ignite instance. * @param upstreamCache Ignite Cache with {@code upstream} data. + * @param envBuilder Learning environment builder. * @param partCtxBuilder Partition {@code context} builder.
* @param partDataBuilder Partition {@code data} builder. * @param <K> Type of a key in {@code upstream} data. @@ -109,7 +141,36 @@ public class DatasetFactory { * @return Dataset. */ public static <K, V, C extends Serializable, D extends AutoCloseable> Dataset<C, D> create( - Ignite ignite, IgniteCache<K, V> upstreamCache, PartitionContextBuilder<K, V, C> partCtxBuilder, + Ignite ignite, IgniteCache<K, V> upstreamCache, + LearningEnvironmentBuilder envBuilder, + PartitionContextBuilder<K, V, C> partCtxBuilder, + PartitionDataBuilder<K, V, C, D> partDataBuilder) { + return create( + new CacheBasedDatasetBuilder<>(ignite, upstreamCache), + envBuilder, + partCtxBuilder, + partDataBuilder + ); + } + + /** + * Creates a new instance of distributed dataset using the specified {@code partCtxBuilder} and + * {@code partDataBuilder}. This is the generic methods that allows to create any Ignite Cache based datasets with + * any desired partition {@code context} and {@code data}. + * + * @param ignite Ignite instance. + * @param upstreamCache Ignite Cache with {@code upstream} data. + * @param partCtxBuilder Partition {@code context} builder. + * @param partDataBuilder Partition {@code data} builder. + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + * @param <C> Type of a partition {@code context}. + * @param <D> Type of a partition {@code data}. + * @return Dataset. + */ + public static <K, V, C extends Serializable, D extends AutoCloseable> Dataset<C, D> create( + Ignite ignite, IgniteCache<K, V> upstreamCache, + PartitionContextBuilder<K, V, C> partCtxBuilder, PartitionDataBuilder<K, V, C, D> partDataBuilder) { return create( new CacheBasedDatasetBuilder<>(ignite, upstreamCache), @@ -124,6 +185,7 @@ public class DatasetFactory { * allows to use any desired type of partition {@code context}. * * @param datasetBuilder Dataset builder. + * @param envBuilder Learning environment builder. 
* @param partCtxBuilder Partition {@code context} builder. * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}. * @param <K> Type of a key in {@code upstream} data. @@ -132,10 +194,13 @@ public class DatasetFactory { * @return Dataset. */ public static <K, V, C extends Serializable> SimpleDataset<C> createSimpleDataset( - DatasetBuilder<K, V> datasetBuilder, PartitionContextBuilder<K, V, C> partCtxBuilder, + DatasetBuilder<K, V> datasetBuilder, + LearningEnvironmentBuilder envBuilder, + PartitionContextBuilder<K, V, C> partCtxBuilder, IgniteBiFunction<K, V, Vector> featureExtractor) { return create( datasetBuilder, + envBuilder, partCtxBuilder, new SimpleDatasetDataBuilder<>(featureExtractor) ).wrap(SimpleDataset::new); @@ -148,6 +213,7 @@ public class DatasetFactory { * * @param ignite Ignite instance. * @param upstreamCache Ignite Cache with {@code upstream} data. + * @param envBuilder Learning environment builder. * @param partCtxBuilder Partition {@code context} builder. * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}. * @param <K> Type of a key in {@code upstream} data. @@ -156,10 +222,13 @@ public class DatasetFactory { * @return Dataset. */ public static <K, V, C extends Serializable> SimpleDataset<C> createSimpleDataset(Ignite ignite, - IgniteCache<K, V> upstreamCache, PartitionContextBuilder<K, V, C> partCtxBuilder, + IgniteCache<K, V> upstreamCache, + LearningEnvironmentBuilder envBuilder, + PartitionContextBuilder<K, V, C> partCtxBuilder, IgniteBiFunction<K, V, Vector> featureExtractor) { return createSimpleDataset( new CacheBasedDatasetBuilder<>(ignite, upstreamCache), + envBuilder, partCtxBuilder, featureExtractor ); @@ -171,6 +240,7 @@ public class DatasetFactory { * {@link SimpleLabeledDatasetData}, but allows to use any desired type of partition {@code context}. * * @param datasetBuilder Dataset builder. 
+ * @param envBuilder Learning environment builder. * @param partCtxBuilder Partition {@code context} builder. * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}. * @param lbExtractor Label extractor used to extract labels and build {@link SimpleLabeledDatasetData}. @@ -180,10 +250,14 @@ public class DatasetFactory { * @return Dataset. */ public static <K, V, C extends Serializable> SimpleLabeledDataset<C> createSimpleLabeledDataset( - DatasetBuilder<K, V> datasetBuilder, PartitionContextBuilder<K, V, C> partCtxBuilder, - IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) { + DatasetBuilder<K, V> datasetBuilder, + LearningEnvironmentBuilder envBuilder, + PartitionContextBuilder<K, V, C> partCtxBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, double[]> lbExtractor) { return create( datasetBuilder, + envBuilder, partCtxBuilder, new SimpleLabeledDatasetDataBuilder<>(featureExtractor, lbExtractor) ).wrap(SimpleLabeledDataset::new); @@ -196,6 +270,7 @@ public class DatasetFactory { * * @param ignite Ignite instance. * @param upstreamCache Ignite Cache with {@code upstream} data. + * @param envBuilder Learning environment builder. * @param partCtxBuilder Partition {@code context} builder. * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}. * @param lbExtractor Label extractor used to extract labels and build {@link SimpleLabeledDatasetData}. @@ -204,11 +279,16 @@ public class DatasetFactory { * @param <C> Type of a partition {@code context}. * @return Dataset. 
*/ - public static <K, V, C extends Serializable> SimpleLabeledDataset<C> createSimpleLabeledDataset(Ignite ignite, - IgniteCache<K, V> upstreamCache, PartitionContextBuilder<K, V, C> partCtxBuilder, - IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) { + public static <K, V, C extends Serializable> SimpleLabeledDataset<C> createSimpleLabeledDataset( + Ignite ignite, + IgniteCache<K, V> upstreamCache, + LearningEnvironmentBuilder envBuilder, + PartitionContextBuilder<K, V, C> partCtxBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, double[]> lbExtractor) { return createSimpleLabeledDataset( new CacheBasedDatasetBuilder<>(ignite, upstreamCache), + envBuilder, partCtxBuilder, featureExtractor, lbExtractor @@ -221,15 +301,19 @@ public class DatasetFactory { * {@link SimpleDatasetData}. * * @param datasetBuilder Dataset builder. + * @param envBuilder Learning environment builder. * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}. * @param <K> Type of a key in {@code upstream} data. * @param <V> Type of a value in {@code upstream} data. * @return Dataset. */ - public static <K, V> SimpleDataset<EmptyContext> createSimpleDataset(DatasetBuilder<K, V> datasetBuilder, + public static <K, V> SimpleDataset<EmptyContext> createSimpleDataset( + DatasetBuilder<K, V> datasetBuilder, + LearningEnvironmentBuilder envBuilder, IgniteBiFunction<K, V, Vector> featureExtractor) { return createSimpleDataset( datasetBuilder, + envBuilder, new EmptyContextBuilder<>(), featureExtractor ); @@ -242,15 +326,43 @@ public class DatasetFactory { * * @param ignite Ignite instance. * @param upstreamCache Ignite Cache with {@code upstream} data. + * @param envBuilder Learning environment builder. * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}. * @param <K> Type of a key in {@code upstream} data. 
* @param <V> Type of a value in {@code upstream} data. * @return Dataset. */ - public static <K, V> SimpleDataset<EmptyContext> createSimpleDataset(Ignite ignite, IgniteCache<K, V> upstreamCache, + public static <K, V> SimpleDataset<EmptyContext> createSimpleDataset( + Ignite ignite, + IgniteCache<K, V> upstreamCache, + LearningEnvironmentBuilder envBuilder, IgniteBiFunction<K, V, Vector> featureExtractor) { return createSimpleDataset( new CacheBasedDatasetBuilder<>(ignite, upstreamCache), + envBuilder, + featureExtractor + ); + } + + /** + * Creates a new instance of distributed {@link SimpleDataset} using the specified {@code featureExtractor}. This + * methods determines partition {@code context} to be {@link EmptyContext} and partition {@code data} to be + * {@link SimpleDatasetData}. + * + * @param ignite Ignite instance. + * @param upstreamCache Ignite Cache with {@code upstream} data. + * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}. + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + * @return Dataset. + */ + public static <K, V> SimpleDataset<EmptyContext> createSimpleDataset( + Ignite ignite, + IgniteCache<K, V> upstreamCache, + IgniteBiFunction<K, V, Vector> featureExtractor) { + return createSimpleDataset( + new CacheBasedDatasetBuilder<>(ignite, upstreamCache), + LearningEnvironmentBuilder.defaultBuilder(), featureExtractor ); } @@ -261,6 +373,7 @@ public class DatasetFactory { * partition {@code data} to be {@link SimpleLabeledDatasetData}. * * @param datasetBuilder Dataset builder. + * @param envBuilder Learning environment builder. * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}. * @param lbExtractor Label extractor used to extract labels and build {@link SimpleLabeledDatasetData}. * @param <K> Type of a key in {@code upstream} data. 
@@ -268,10 +381,13 @@ public class DatasetFactory { * @return Dataset. */ public static <K, V> SimpleLabeledDataset<EmptyContext> createSimpleLabeledDataset( - DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, + DatasetBuilder<K, V> datasetBuilder, + LearningEnvironmentBuilder envBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) { return createSimpleLabeledDataset( datasetBuilder, + envBuilder, new EmptyContextBuilder<>(), featureExtractor, lbExtractor @@ -285,17 +401,21 @@ public class DatasetFactory { * * @param ignite Ignite instance. * @param upstreamCache Ignite Cache with {@code upstream} data. + * @param envBuilder Learning environment builder. * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}. * @param lbExtractor Label extractor used to extract labels and build {@link SimpleLabeledDatasetData}. * @param <K> Type of a key in {@code upstream} data. * @param <V> Type of a value in {@code upstream} data. * @return Dataset. */ - public static <K, V> SimpleLabeledDataset<EmptyContext> createSimpleLabeledDataset(Ignite ignite, + public static <K, V> SimpleLabeledDataset<EmptyContext> createSimpleLabeledDataset( + Ignite ignite, + LearningEnvironmentBuilder envBuilder, IgniteCache<K, V> upstreamCache, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) { return createSimpleLabeledDataset( new CacheBasedDatasetBuilder<>(ignite, upstreamCache), + envBuilder, featureExtractor, lbExtractor ); @@ -309,6 +429,7 @@ public class DatasetFactory { * @param upstreamMap {@code Map} with {@code upstream} data. * @param partitions Number of partitions {@code upstream} {@code Map} will be divided on. * @param partCtxBuilder Partition {@code context} builder. + * @param envBuilder Learning environment builder. * @param partDataBuilder Partition {@code data} builder. 
* @param <K> Type of a key in {@code upstream} data. * @param <V> Type of a value in {@code upstream} data. @@ -317,10 +438,13 @@ public class DatasetFactory { * @return Dataset. */ public static <K, V, C extends Serializable, D extends AutoCloseable> Dataset<C, D> create( - Map<K, V> upstreamMap, int partitions, PartitionContextBuilder<K, V, C> partCtxBuilder, + Map<K, V> upstreamMap, + LearningEnvironmentBuilder envBuilder, + int partitions, PartitionContextBuilder<K, V, C> partCtxBuilder, PartitionDataBuilder<K, V, C, D> partDataBuilder) { return create( new LocalDatasetBuilder<>(upstreamMap, partitions), + envBuilder, partCtxBuilder, partDataBuilder ); @@ -333,6 +457,7 @@ public class DatasetFactory { * * @param upstreamMap {@code Map} with {@code upstream} data. * @param partitions Number of partitions {@code upstream} {@code Map} will be divided on. + * @param envBuilder Learning environment builder. * @param partCtxBuilder Partition {@code context} builder. * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}. * @param <K> Type of a key in {@code upstream} data. @@ -340,11 +465,15 @@ public class DatasetFactory { * @param <C> Type of a partition {@code context}. * @return Dataset. */ - public static <K, V, C extends Serializable> SimpleDataset<C> createSimpleDataset(Map<K, V> upstreamMap, - int partitions, PartitionContextBuilder<K, V, C> partCtxBuilder, + public static <K, V, C extends Serializable> SimpleDataset<C> createSimpleDataset( + Map<K, V> upstreamMap, + int partitions, + LearningEnvironmentBuilder envBuilder, + PartitionContextBuilder<K, V, C> partCtxBuilder, IgniteBiFunction<K, V, Vector> featureExtractor) { return createSimpleDataset( new LocalDatasetBuilder<>(upstreamMap, partitions), + envBuilder, partCtxBuilder, featureExtractor ); @@ -357,6 +486,7 @@ public class DatasetFactory { * * @param upstreamMap {@code Map} with {@code upstream} data. 
* @param partitions Number of partitions {@code upstream} {@code Map} will be divided on. + * @param envBuilder Learning environment builder. * @param partCtxBuilder Partition {@code context} builder. * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}. * @param lbExtractor Label extractor used to extract labels and build {@link SimpleLabeledDatasetData}. @@ -366,10 +496,14 @@ public class DatasetFactory { * @return Dataset. */ public static <K, V, C extends Serializable> SimpleLabeledDataset<C> createSimpleLabeledDataset( - Map<K, V> upstreamMap, int partitions, PartitionContextBuilder<K, V, C> partCtxBuilder, + Map<K, V> upstreamMap, + int partitions, + LearningEnvironmentBuilder envBuilder, + PartitionContextBuilder<K, V, C> partCtxBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) { return createSimpleLabeledDataset( new LocalDatasetBuilder<>(upstreamMap, partitions), + envBuilder, partCtxBuilder, featureExtractor, lbExtractor ); @@ -382,15 +516,18 @@ public class DatasetFactory { * * @param upstreamMap {@code Map} with {@code upstream} data. * @param partitions Number of partitions {@code upstream} {@code Map} will be divided on. + * @param envBuilder Learning environment builder. * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}. * @param <K> Type of a key in {@code upstream} data. * @param <V> Type of a value in {@code upstream} data. * @return Dataset. */ public static <K, V> SimpleDataset<EmptyContext> createSimpleDataset(Map<K, V> upstreamMap, int partitions, + LearningEnvironmentBuilder envBuilder, IgniteBiFunction<K, V, Vector> featureExtractor) { return createSimpleDataset( new LocalDatasetBuilder<>(upstreamMap, partitions), + envBuilder, featureExtractor ); } @@ -402,6 +539,7 @@ public class DatasetFactory { * * @param upstreamMap {@code Map} with {@code upstream} data. 
* @param partitions Number of partitions {@code upstream} {@code Map} will be divided on. + * @param envBuilder Learning environment builder. * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}. * @param lbExtractor Label extractor used to extract labels and build {@link SimpleLabeledDatasetData}. * @param <K> Type of a key in {@code upstream} data. @@ -409,10 +547,12 @@ public class DatasetFactory { * @return Dataset. */ public static <K, V> SimpleLabeledDataset<EmptyContext> createSimpleLabeledDataset(Map<K, V> upstreamMap, + LearningEnvironmentBuilder envBuilder, int partitions, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) { return createSimpleLabeledDataset( new LocalDatasetBuilder<>(upstreamMap, partitions), + envBuilder, featureExtractor, lbExtractor ); http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionContextBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionContextBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionContextBuilder.java index 6e1fec3..c5eac88 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionContextBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionContextBuilder.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.Iterator; import java.util.stream.Stream; import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder; +import org.apache.ignite.ml.environment.LearningEnvironment; import org.apache.ignite.ml.math.functions.IgniteFunction; /** @@ -43,11 +44,12 @@ public interface PartitionContextBuilder<K, V, C extends Serializable> extends S * constraint. 
This constraint is omitted to allow upstream data transformers in {@link DatasetBuilder} replicating * entries. For example it can be useful for bootstrapping. * + * @param env Learning environment. * @param upstreamData Partition {@code upstream} data. * @param upstreamDataSize Partition {@code upstream} data size. * @return Partition {@code context}. */ - public C build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize); + public C build(LearningEnvironment env, Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize); /** @@ -57,12 +59,13 @@ public interface PartitionContextBuilder<K, V, C extends Serializable> extends S * constraint. This constraint is omitted to allow upstream data transformers in {@link DatasetBuilder} replicating * entries. For example it can be useful for bootstrapping. * + * @param env Learning environment. * @param upstreamData Partition {@code upstream} data. * @param upstreamDataSize Partition {@code upstream} data size. * @return Partition {@code context}. */ - public default C build(Stream<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize) { - return build(upstreamData.iterator(), upstreamDataSize); + public default C build(LearningEnvironment env, Stream<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize) { + return build(env, upstreamData.iterator(), upstreamDataSize); } /** @@ -74,6 +77,6 @@ public interface PartitionContextBuilder<K, V, C extends Serializable> extends S * @return Composed partition {@code context} builder. 
*/ public default <C2 extends Serializable> PartitionContextBuilder<K, V, C2> andThen(IgniteFunction<C, C2> fun) { - return (upstreamData, upstreamDataSize) -> fun.apply(build(upstreamData, upstreamDataSize)); + return (env, upstreamData, upstreamDataSize) -> fun.apply(build(env, upstreamData, upstreamDataSize)); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionDataBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionDataBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionDataBuilder.java index 106084b..4a0e68e 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionDataBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionDataBuilder.java @@ -22,6 +22,7 @@ import java.util.Iterator; import java.util.stream.Stream; import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleDatasetDataBuilder; import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetDataBuilder; +import org.apache.ignite.ml.environment.LearningEnvironment; import org.apache.ignite.ml.math.functions.IgniteBiFunction; /** @@ -46,12 +47,13 @@ public interface PartitionDataBuilder<K, V, C extends Serializable, D extends Au * constraint. This constraint is omitted to allow upstream data transformers in {@link DatasetBuilder} replicating * entries. For example it can be useful for bootstrapping. * + * @param env Learning environment. * @param upstreamData Partition {@code upstream} data. * @param upstreamDataSize Partition {@code upstream} data size. * @param ctx Partition {@code context}. * @return Partition {@code data}. 
*/ - public D build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx); + public D build(LearningEnvironment env, Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx); /** * Builds a new partition {@code data} from a partition {@code upstream} data and partition {@code context}. @@ -60,13 +62,14 @@ public interface PartitionDataBuilder<K, V, C extends Serializable, D extends Au * constraint. This constraint is omitted to allow upstream data transformers in {@link DatasetBuilder} replicating * entries. For example it can be useful for bootstrapping. * + * @param env Learning environment. * @param upstreamData Partition {@code upstream} data. * @param upstreamDataSize Partition {@code upstream} data size. * @param ctx Partition {@code context}. * @return Partition {@code data}. */ - public default D build(Stream<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) { - return build(upstreamData.iterator(), upstreamDataSize, ctx); + public default D build(LearningEnvironment env, Stream<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) { + return build(env, upstreamData.iterator(), upstreamDataSize, ctx); } /** @@ -79,6 +82,7 @@ public interface PartitionDataBuilder<K, V, C extends Serializable, D extends Au */ public default <D2 extends AutoCloseable> PartitionDataBuilder<K, V, C, D2> andThen( IgniteBiFunction<D, C, D2> fun) { - return (upstreamData, upstreamDataSize, ctx) -> fun.apply(build(upstreamData, upstreamDataSize, ctx), ctx); + return (env, upstreamData, upstreamDataSize, ctx) -> + fun.apply(build(env, upstreamData, upstreamDataSize, ctx), ctx); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java index ba70e2e..11b250b 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java @@ -18,7 +18,6 @@ package org.apache.ignite.ml.dataset; import java.io.Serializable; -import java.util.Random; import java.util.stream.Stream; /** @@ -27,16 +26,25 @@ import java.util.stream.Stream; * @param <K> Type of keys in the upstream. * @param <V> Type of values in the upstream. */ +// TODO: IGNITE-10297: Investigate possibility of API change. @FunctionalInterface public interface UpstreamTransformer<K, V> extends Serializable { /** - * Perform transformation of upstream. + * Transform upstream. * - * @param rnd Random numbers generator. - * @param upstream Upstream. + * @param upstream Upstream to transform. * @return Transformed upstream. */ - // TODO: IGNITE-10296: Inject capabilities of randomization through learning environment. - // TODO: IGNITE-10297: Investigate possibility of API change. - public Stream<UpstreamEntry<K, V>> transform(Random rnd, Stream<UpstreamEntry<K, V>> upstream); + public Stream<UpstreamEntry<K, V>> transform(Stream<UpstreamEntry<K, V>> upstream); + + /** + * Get composition of this transformer and other transformer which + * itself is {@link UpstreamTransformer} applying this transformer and then other transformer. + * + * @param other Other transformer. + * @return Composition of this and other transformer.
+ */ + default UpstreamTransformer<K, V> andThen(UpstreamTransformer<K, V> other) { + return upstream -> other.transform(transform(upstream)); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerBuilder.java new file mode 100644 index 0000000..9adfab5 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerBuilder.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.dataset; + +import java.io.Serializable; +import org.apache.ignite.ml.environment.LearningEnvironment; + +/** + * Builder of {@link UpstreamTransformer}. + * @param <K> Type of keys in upstream. + * @param <V> Type of values in upstream. + */ +@FunctionalInterface +public interface UpstreamTransformerBuilder<K, V> extends Serializable { + /** + * Create {@link UpstreamTransformer} based on learning environment.
+ * + * @param env Learning environment. + * @return Upstream transformer. + */ + public UpstreamTransformer<K, V> build(LearningEnvironment env); + + /** + * Combines two builders (this and other respectively) + * <pre> + * env -> transformer1 + * env -> transformer2 + * </pre> + * into + * <pre> + * env -> transformer2 . transformer1 + * </pre> + * + * @param other Builder to combine with. + * @return Compositional builder. + */ + public default UpstreamTransformerBuilder<K, V> andThen(UpstreamTransformerBuilder<K, V> other) { + UpstreamTransformerBuilder<K, V> self = this; + return env -> { + UpstreamTransformer<K, V> transformer1 = self.build(env); + UpstreamTransformer<K, V> transformer2 = other.build(env); + + return upstream -> transformer2.transform(transformer1.transform(upstream)); + }; + } + + /** + * Returns identity upstream transformer. + * + * @param <K> Type of keys in upstream. + * @param <V> Type of values in upstream. + * @return Identity upstream transformer. + */ + public static <K, V> UpstreamTransformerBuilder<K, V> identity() { + return env -> upstream -> upstream; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerChain.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerChain.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerChain.java deleted file mode 100644 index 3ad6446..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerChain.java +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.dataset; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; -import java.util.Random; -import java.util.stream.Stream; -import org.apache.ignite.ml.math.functions.IgniteFunction; - -/** - * Class representing chain of transformers applied to upstream. - * - * @param <K> Type of upstream keys. - * @param <V> Type of upstream values. - */ -public class UpstreamTransformerChain<K, V> implements Serializable { - /** Seed used for transformations. */ - private Long seed; - - /** List of upstream transformations. */ - private List<UpstreamTransformer<K, V>> list; - - /** - * Creates empty upstream transformers chain (basically identity function). - * - * @param <K> Type of upstream keys. - * @param <V> Type of upstream values. - * @return Empty upstream transformers chain. - */ - public static <K, V> UpstreamTransformerChain<K, V> empty() { - return new UpstreamTransformerChain<>(); - } - - /** - * Creates upstream transformers chain consisting of one specified transformer. - * - * @param <K> Type of upstream keys. - * @param <V> Type of upstream values. - * @return Upstream transformers chain consisting of one specified transformer. 
- */ - public static <K, V> UpstreamTransformerChain<K, V> of(UpstreamTransformer<K, V> trans) { - UpstreamTransformerChain<K, V> res = new UpstreamTransformerChain<>(); - return res.addUpstreamTransformer(trans); - } - - /** - * Construct instance of this class. - */ - private UpstreamTransformerChain() { - list = new ArrayList<>(); - seed = new Random().nextLong(); - } - - /** - * Adds upstream transformer to this chain. - * - * @param next Transformer to add. - * @return This chain with added transformer. - */ - public UpstreamTransformerChain<K, V> addUpstreamTransformer(UpstreamTransformer<K, V> next) { - list.add(next); - - return this; - } - - /** - * Add upstream transformer based on given lambda. - * - * @param transformer Transformer. - * @return This object. - */ - public UpstreamTransformerChain<K, V> addUpstreamTransformer(IgniteFunction<Stream<UpstreamEntry<K, V>>, - Stream<UpstreamEntry<K, V>>> transformer) { - return addUpstreamTransformer((rnd, upstream) -> transformer.apply(upstream)); - } - - /** - * Performs stream transformation using RNG based on provided seed as pseudo-randomness source for all - * transformers in the chain. - * - * @param upstream Upstream. - * @return Transformed upstream. - */ - public Stream<UpstreamEntry<K, V>> transform(Stream<UpstreamEntry<K, V>> upstream) { - Random rnd = new Random(seed); - - Stream<UpstreamEntry<K, V>> res = upstream; - - for (UpstreamTransformer<K, V> kvUpstreamTransformer : list) - res = kvUpstreamTransformer.transform(rnd, res); - - return res; - } - - /** - * Checks if this chain is empty. - * - * @return Result of check if this chain is empty. - */ - public boolean isEmpty() { - return list.isEmpty(); - } - - /** - * Set seed for transformations. - * - * @param seed Seed. - * @return This object. - */ - public UpstreamTransformerChain<K, V> setSeed(long seed) { - this.seed = seed; - - return this; - } - - /** - * Modifies seed for transformations if it is present. 
- * - * @param f Modification function. - * @return This object. - */ - public UpstreamTransformerChain<K, V> modifySeed(IgniteFunction<Long, Long> f) { - seed = f.apply(seed); - - return this; - } - - /** - * Get seed used for RNG in transformations. - * - * @return Seed used for RNG in transformations. - */ - public Long seed() { - return seed; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedDatasetBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedDatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedDatasetBuilder.java index 8707e3a..c8d78dd 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedDatasetBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedDatasetBuilder.java @@ -20,9 +20,11 @@ package org.apache.ignite.ml.dataset.impl.bootstrapping; import java.util.Arrays; import java.util.Iterator; import org.apache.commons.math3.distribution.PoissonDistribution; +import org.apache.commons.math3.random.Well19937c; import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.dataset.UpstreamEntry; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironment; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; @@ -69,13 +71,22 @@ public class BootstrappedDatasetBuilder<K,V> implements PartitionDataBuilder<K,V } /** {@inheritDoc} */ - @Override public BootstrappedDatasetPartition build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, + @Override public BootstrappedDatasetPartition build( + LearningEnvironment 
env, + Iterator<UpstreamEntry<K, V>> upstreamData, + long upstreamDataSize, EmptyContext ctx) { BootstrappedVector[] dataset = new BootstrappedVector[Math.toIntExact(upstreamDataSize)]; int cntr = 0; - PoissonDistribution poissonDistribution = new PoissonDistribution(subsampleSize); + + PoissonDistribution poissonDistribution = new PoissonDistribution( + new Well19937c(env.randomNumbersGenerator().nextLong()), + subsampleSize, + PoissonDistribution.DEFAULT_EPSILON, + PoissonDistribution.DEFAULT_MAX_ITERATIONS); + while(upstreamData.hasNext()) { UpstreamEntry<K, V> nextRow = upstreamData.next(); Vector features = featureExtractor.apply(nextRow.getKey(), nextRow.getValue());
