http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetModel.java index 0e80325..1a251fa 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetModel.java @@ -17,7 +17,7 @@ package org.apache.ignite.ml.trainers; -import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.IgniteModel; import org.apache.ignite.ml.math.functions.IgniteFunction; /** @@ -29,7 +29,7 @@ import org.apache.ignite.ml.math.functions.IgniteFunction; * @param <OW> Type of output of inner model. * @param <M> Type of inner model. */ -public class AdaptableDatasetModel<I, O, IW, OW, M extends Model<IW, OW>> implements Model<I, O> { +public class AdaptableDatasetModel<I, O, IW, OW, M extends IgniteModel<IW, OW>> implements IgniteModel<I, O> { /** Function applied before inner model. */ private final IgniteFunction<I, IW> before; @@ -55,13 +55,13 @@ public class AdaptableDatasetModel<I, O, IW, OW, M extends Model<IW, OW>> implem /** * Result of this model application is a result of composition {@code before `andThen` inner mdl `andThen` after}. */ - @Override public O apply(I i) { - return before.andThen(mdl).andThen(after).apply(i); + @Override public O predict(I i) { + return before.andThen(mdl::predict).andThen(after).apply(i); } /** {@inheritDoc} */ - @Override public <O1> AdaptableDatasetModel<I, O1, IW, OW, M> andThen(IgniteFunction<O, O1> after) { - return new AdaptableDatasetModel<>(before, mdl, i -> after.apply(this.after.apply(i))); + @Override public <O1> AdaptableDatasetModel<I, O1, IW, OW, M> andThen(IgniteModel<O, O1> after) { + return new AdaptableDatasetModel<>(before, mdl, i -> after.predict(this.after.apply(i))); } /** @@ -92,7 +92,7 @@ public class AdaptableDatasetModel<I, O, IW, OW, M extends Model<IW, OW>> implem * @param <M1> Type of inner model. * @return New instance of this class with changed inner model. */ - public <M1 extends Model<IW, OW>> AdaptableDatasetModel<I, O, IW, OW, M1> withInnerModel(M1 mdl) { + public <M1 extends IgniteModel<IW, OW>> AdaptableDatasetModel<I, O, IW, OW, M1> withInnerModel(M1 mdl) { return new AdaptableDatasetModel<>(before, mdl, after); } }
http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java index 7e2e810..4205286 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java @@ -17,7 +17,7 @@ package org.apache.ignite.ml.trainers; -import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.IgniteModel; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.functions.IgniteFunction; @@ -35,7 +35,7 @@ import org.apache.ignite.ml.math.primitives.vector.Vector; * @param <M> Type of model produced by wrapped model. * @param <L> Type of labels. */ -public class AdaptableDatasetTrainer<I, O, IW, OW, M extends Model<IW, OW>, L> +public class AdaptableDatasetTrainer<I, O, IW, OW, M extends IgniteModel<IW, OW>, L> extends DatasetTrainer<AdaptableDatasetModel<I, O, IW, OW, M>, L> { /** Wrapped trainer. */ private final DatasetTrainer<M, L> wrapped; @@ -56,7 +56,7 @@ public class AdaptableDatasetTrainer<I, O, IW, OW, M extends Model<IW, OW>, L> * @param <L> Type of labels. * @return Instance of this class. */ - public static <I, O, M extends Model<I, O>, L> AdaptableDatasetTrainer<I, O, I, O, M, L> of(DatasetTrainer<M, L> wrapped) { + public static <I, O, M extends IgniteModel<I, O>, L> AdaptableDatasetTrainer<I, O, I, O, M, L> of(DatasetTrainer<M, L> wrapped) { return new AdaptableDatasetTrainer<>(IgniteFunction.identity(), wrapped, IgniteFunction.identity()); } http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java index 3f715dc..88c4bcd 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java @@ -21,7 +21,7 @@ import java.util.Map; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.lang.IgniteBiPredicate; -import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.IgniteModel; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; @@ -39,7 +39,7 @@ import org.jetbrains.annotations.NotNull; * @param <M> Type of a produced model. * @param <L> Type of a label. */ -public abstract class DatasetTrainer<M extends Model, L> { +public abstract class DatasetTrainer<M extends IgniteModel, L> { /** Learning environment builder. */ protected LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder(); http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/main/java/org/apache/ignite/ml/trainers/MultiLabelDatasetTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/MultiLabelDatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/MultiLabelDatasetTrainer.java index 5ae7de8..815bdd0 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/MultiLabelDatasetTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/MultiLabelDatasetTrainer.java @@ -17,12 +17,12 @@ package org.apache.ignite.ml.trainers; -import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.IgniteModel; /** * Interface for trainers that trains on dataset with multiple label per object. * * @param <M> Type of a produced model. */ -public abstract class MultiLabelDatasetTrainer<M extends Model> extends DatasetTrainer<M, double[]> { +public abstract class MultiLabelDatasetTrainer<M extends IgniteModel> extends DatasetTrainer<M, double[]> { } http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/main/java/org/apache/ignite/ml/trainers/SingleLabelDatasetTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/SingleLabelDatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/SingleLabelDatasetTrainer.java index 38dda93..c4a1fa4 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/SingleLabelDatasetTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/SingleLabelDatasetTrainer.java @@ -17,12 +17,12 @@ package org.apache.ignite.ml.trainers; -import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.IgniteModel; /** * Interface for trainers that trains on dataset with singe label per object. * * @param <M> Type of a produced model. */ -public abstract class SingleLabelDatasetTrainer<M extends Model> extends DatasetTrainer<M, Double> { +public abstract class SingleLabelDatasetTrainer<M extends IgniteModel> extends DatasetTrainer<M, Double> { } http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java index 80a57e0..43c1600 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java @@ -22,7 +22,7 @@ import java.util.List; import java.util.Random; import java.util.stream.Collectors; import java.util.stream.IntStream; -import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.IgniteModel; import org.apache.ignite.ml.composition.ModelsComposition; import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator; import org.apache.ignite.ml.dataset.DatasetBuilder; @@ -52,7 +52,7 @@ public class TrainerTransformers { * @param <L> Type of labels. * @return Bagged trainer. */ - public static <M extends Model<Vector, Double>, L> DatasetTrainer<ModelsComposition, L> makeBagged( + public static <M extends IgniteModel<Vector, Double>, L> DatasetTrainer<ModelsComposition, L> makeBagged( DatasetTrainer<M, L> trainer, int ensembleSize, double subsampleRatio, @@ -71,7 +71,7 @@ public class TrainerTransformers { * @param <L> Type of labels. * @return Bagged trainer. */ - public static <M extends Model<Vector, Double>, L> DatasetTrainer<ModelsComposition, L> makeBagged( + public static <M extends IgniteModel<Vector, Double>, L> DatasetTrainer<ModelsComposition, L> makeBagged( DatasetTrainer<M, L> trainer, int ensembleSize, double subsampleRatio, @@ -142,7 +142,7 @@ public class TrainerTransformers { * @param <M> Type of model. * @return Composition of models trained on bagged dataset. */ - private static <K, V, M extends Model<Vector, Double>> ModelsComposition runOnEnsemble( + private static <K, V, M extends IgniteModel<Vector, Double>> ModelsComposition runOnEnsemble( IgniteTriFunction<DatasetBuilder<K, V>, Integer, IgniteBiFunction<K, V, Vector>, IgniteSupplier<M>> trainingTaskGenerator, DatasetBuilder<K, V> datasetBuilder, int ensembleSize, @@ -257,7 +257,7 @@ public class TrainerTransformers { * @param <Y> Output space. * @param <M> Model. */ - private static class ModelWithMapping<X, Y, M extends Model<X, Y>> implements Model<X, Y> { + private static class ModelWithMapping<X, Y, M extends IgniteModel<X, Y>> implements IgniteModel<X, Y> { /** Model. */ private final M model; @@ -295,8 +295,8 @@ public class TrainerTransformers { } /** {@inheritDoc} */ - @Override public Y apply(X x) { - return model.apply(mapping.apply(x)); + @Override public Y predict(X x) { + return model.predict(mapping.apply(x)); } /** http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java index ef4d115..7eff06c 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java @@ -60,17 +60,17 @@ public class DecisionTreeConditionalNode implements DecisionTreeNode { } /** {@inheritDoc} */ - @Override public Double apply(Vector features) { + @Override public Double predict(Vector features) { double val = features.get(col); if (Double.isNaN(val)) { if (missingNode == null) throw new IllegalArgumentException("Feature must not be null or missing node should be specified"); - return missingNode.apply(features); + return missingNode.predict(features); } - return val > threshold ? thenNode.apply(features) : elseNode.apply(features); + return val > threshold ? thenNode.predict(features) : elseNode.predict(features); } /** */ http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java index 97cc3ee..43f0f05 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java @@ -39,7 +39,7 @@ public class DecisionTreeLeafNode implements DecisionTreeNode { } /** {@inheritDoc} */ - @Override public Double apply(Vector doubles) { + @Override public Double predict(Vector doubles) { return val; } http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java index bd065f0..80036ba 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java @@ -17,11 +17,11 @@ package org.apache.ignite.ml.tree; -import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.IgniteModel; import org.apache.ignite.ml.math.primitives.vector.Vector; /** * Base interface for decision tree nodes. */ -public interface DecisionTreeNode extends Model<Vector, Double> { +public interface DecisionTreeNode extends IgniteModel<Vector, Double> { } http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java index 71e840c..bc771eb 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java @@ -19,7 +19,7 @@ package org.apache.ignite.ml.tree.boosting; import java.util.Arrays; import java.util.List; -import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.IgniteModel; import org.apache.ignite.ml.composition.ModelsComposition; import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy; import org.apache.ignite.ml.composition.boosting.GDBTrainer; @@ -56,15 +56,15 @@ public class GDBOnTreesLearningStrategy extends GDBLearningStrategy { } /** {@inheritDoc} */ - @Override public <K, V> List<Model<Vector, Double>> update(GDBTrainer.GDBModel mdlToUpdate, + @Override public <K, V> List<IgniteModel<Vector, Double>> update(GDBTrainer.GDBModel mdlToUpdate, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { - DatasetTrainer<? extends Model<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get(); + DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get(); assert trainer instanceof DecisionTree; DecisionTree decisionTreeTrainer = (DecisionTree) trainer; - List<Model<Vector, Double>> models = initLearningState(mdlToUpdate); + List<IgniteModel<Vector, Double>> models = initLearningState(mdlToUpdate); ConvergenceChecker<K,V> convCheck = checkConvergenceStgyFactory.create(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, featureExtractor, lbExtractor); @@ -87,7 +87,7 @@ public class GDBOnTreesLearningStrategy extends GDBLearningStrategy { part.setCopiedOriginalLabels(Arrays.copyOf(part.getLabels(), part.getLabels().length)); for(int j = 0; j < part.getLabels().length; j++) { - double mdlAnswer = currComposition.apply(VectorUtils.of(part.getFeatures()[j])); + double mdlAnswer = currComposition.predict(VectorUtils.of(part.getFeatures()[j])); double originalLbVal = externalLbToInternalMapping.apply(part.getCopiedOriginalLabels()[j]); part.getLabels()[j] = -loss.gradient(sampleSize, originalLbVal, mdlAnswer); } http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java index 3ee90cb..d9b8e30 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java @@ -30,7 +30,7 @@ import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.IntStream; -import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.IgniteModel; import org.apache.ignite.ml.composition.ModelsComposition; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; @@ -248,7 +248,7 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra @Override protected <K, V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { - ArrayList<Model<Vector, Double>> oldModels = new ArrayList<>(mdl.getModels()); + ArrayList<IgniteModel<Vector, Double>> oldModels = new ArrayList<>(mdl.getModels()); ModelsComposition newModels = fit(datasetBuilder, featureExtractor, lbExtractor); oldModels.addAll(newModels.getModels()); http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeNode.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeNode.java index 528e31d..8c11bdb 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeNode.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeNode.java @@ -20,13 +20,13 @@ package org.apache.ignite.ml.tree.randomforest.data; import java.io.Serializable; import java.util.Arrays; import java.util.List; -import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.IgniteModel; import org.apache.ignite.ml.math.primitives.vector.Vector; /** * Decision tree node class. */ -public class TreeNode implements Model<Vector, Double>, Serializable { +public class TreeNode implements IgniteModel<Vector, Double>, Serializable { /** Serial version uid. */ private static final long serialVersionUID = -8546263332508653661L; @@ -83,16 +83,16 @@ public class TreeNode implements Model<Vector, Double>, Serializable { } /** {@inheritDoc} */ - public Double apply(Vector features) { + public Double predict(Vector features) { assert type != Type.UNKNOWN; if (type == Type.LEAF) return val; else { if (features.get(featureId) <= val) - return left.apply(features); + return left.predict(features); else - return right.apply(features); + return right.predict(features); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeRoot.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeRoot.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeRoot.java index e47868d..e9fcb74 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeRoot.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeRoot.java @@ -20,13 +20,13 @@ package org.apache.ignite.ml.tree.randomforest.data; import java.util.ArrayList; import java.util.List; import java.util.Set; -import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.IgniteModel; import org.apache.ignite.ml.math.primitives.vector.Vector; /** * Tree root class. */ -public class TreeRoot implements Model<Vector, Double> { +public class TreeRoot implements IgniteModel<Vector, Double> { /** Serial version uid. */ private static final long serialVersionUID = 531797299171329057L; @@ -48,8 +48,8 @@ public class TreeRoot implements Model<Vector, Double> { } /** {@inheritDoc} */ - @Override public Double apply(Vector vector) { - return node.apply(vector); + @Override public Double predict(Vector vector) { + return node.predict(vector); } /** */ http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java b/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java index b85a5c3..fc3bf5c 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java @@ -405,7 +405,7 @@ public class TestUtils { * @param <V> Type of output. * @return Model which returns given constant. */ - public static <T, V> Model<T, V> constantModel(V v) { + public static <T, V> IgniteModel<T, V> constantModel(V v) { return t -> v; } @@ -419,7 +419,7 @@ public class TestUtils { * @param <L> Type of dataset labels. * @return Trainer which independently of dataset outputs given model. */ - public static <I, O, M extends Model<I, O>, L> DatasetTrainer<M, L> constantTrainer(M ml) { + public static <I, O, M extends IgniteModel<I, O>, L> DatasetTrainer<M, L> constantTrainer(M ml) { return new DatasetTrainer<M, L>() { /** {@inheritDoc} */ @Override public <K, V> M fit(DatasetBuilder<K, V> datasetBuilder, http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansModelTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansModelTest.java index f71b7b3..cc652e8 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansModelTest.java @@ -48,10 +48,10 @@ public class KMeansModelTest { Assert.assertTrue(mdl.toString().contains("KMeansModel")); - Assert.assertEquals(mdl.apply(new DenseVector(new double[]{1.1, 1.1})), 0.0, PRECISION); - Assert.assertEquals(mdl.apply(new DenseVector(new double[]{-1.1, 1.1})), 1.0, PRECISION); - Assert.assertEquals(mdl.apply(new DenseVector(new double[]{1.1, -1.1})), 2.0, PRECISION); - Assert.assertEquals(mdl.apply(new DenseVector(new double[]{-1.1, -1.1})), 3.0, PRECISION); + Assert.assertEquals(mdl.predict(new DenseVector(new double[]{1.1, 1.1})), 0.0, PRECISION); + Assert.assertEquals(mdl.predict(new DenseVector(new double[]{-1.1, 1.1})), 1.0, PRECISION); + Assert.assertEquals(mdl.predict(new DenseVector(new double[]{1.1, -1.1})), 2.0, PRECISION); + Assert.assertEquals(mdl.predict(new DenseVector(new double[]{-1.1, -1.1})), 3.0, PRECISION); Assert.assertEquals(mdl.distanceMeasure(), distanceMeasure); Assert.assertEquals(mdl.getAmountOfClusters(), 4); http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java index 205f0ff..e33ad08 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java @@ -66,9 +66,9 @@ public class KMeansTrainerTest extends TrainerTest { ); Vector firstVector = new DenseVector(new double[] {2.0, 2.0}); - assertEquals(knnMdl.apply(firstVector), 0.0, PRECISION); + assertEquals(knnMdl.predict(firstVector), 0.0, PRECISION); Vector secondVector = new DenseVector(new double[] {-2.0, -2.0}); - assertEquals(knnMdl.apply(secondVector), 0.0, PRECISION); + assertEquals(knnMdl.predict(secondVector), 0.0, PRECISION); assertEquals(trainer.getMaxIterations(), 1); assertEquals(trainer.getEpsilon(), PRECISION, PRECISION); } @@ -97,10 +97,10 @@ public class KMeansTrainerTest extends TrainerTest { Vector firstVector = new DenseVector(new double[] {2.0, 2.0}); Vector secondVector = new DenseVector(new double[] {-2.0, -2.0}); - assertEquals(originalMdl.apply(firstVector), updatedMdlOnSameDataset.apply(firstVector), PRECISION); - assertEquals(originalMdl.apply(secondVector), updatedMdlOnSameDataset.apply(secondVector), PRECISION); - assertEquals(originalMdl.apply(firstVector), updatedMdlOnEmptyDataset.apply(firstVector), PRECISION); - assertEquals(originalMdl.apply(secondVector), updatedMdlOnEmptyDataset.apply(secondVector), PRECISION); + assertEquals(originalMdl.predict(firstVector), updatedMdlOnSameDataset.predict(firstVector), PRECISION); + assertEquals(originalMdl.predict(secondVector), updatedMdlOnSameDataset.predict(secondVector), PRECISION); + assertEquals(originalMdl.predict(firstVector), updatedMdlOnEmptyDataset.predict(firstVector), PRECISION); + assertEquals(originalMdl.predict(secondVector), updatedMdlOnEmptyDataset.predict(secondVector), PRECISION); } /** */ http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/common/ModelTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/common/ModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/common/ModelTest.java index cfc081b..66be960 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/common/ModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/common/ModelTest.java @@ -17,28 +17,28 @@ package org.apache.ignite.ml.common; -import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.IgniteModel; import org.junit.Test; import static org.junit.Assert.assertNotNull; /** - * Tests for {@link Model} functionality. + * Tests for {@link IgniteModel} functionality. */ public class ModelTest { /** */ @Test public void testCombine() { - Model<Object, Object> mdl = new TestModel<>().combine(new TestModel<>(), (x, y) -> x); + IgniteModel<Object, Object> mdl = new TestModel<>().combine(new TestModel<>(), (x, y) -> x); assertNotNull(mdl.toString(true)); assertNotNull(mdl.toString(false)); } /** */ - private static class TestModel<T, V> implements Model<T, V> { + private static class TestModel<T, V> implements IgniteModel<T, V> { /** {@inheritDoc} */ - @Override public V apply(T t) { + @Override public V predict(T t) { return null; } } http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java index 405c70b..dd4b11e 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java @@ -19,7 +19,7 @@ package org.apache.ignite.ml.composition; import java.util.Arrays; import java.util.Map; -import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.IgniteModel; import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator; @@ -95,8 +95,8 @@ public class BaggingTest extends TrainerTest { (k, v) -> v[0] ); - TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), PRECISION); - TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION); + TestUtils.assertEquals(0, mdl.predict(VectorUtils.of(100, 10)), PRECISION); + TestUtils.assertEquals(1, mdl.predict(VectorUtils.of(10, 100)), PRECISION); } /** @@ -120,7 +120,7 @@ public class BaggingTest extends TrainerTest { new MeanValuePredictionsAggregator()) .fit(cacheMock, parts, null, null); - Double res = mdl.apply(null); + Double res = mdl.predict(null); TestUtils.assertEquals(twoLinearlySeparableClasses.length * subsampleRatio, res, twoLinearlySeparableClasses.length / 10); } @@ -145,7 +145,7 @@ public class BaggingTest extends TrainerTest { /** * Trainer used to count entries in context or in data. */ - protected static class CountTrainer extends DatasetTrainer<Model<Vector, Double>, Double> { + protected static class CountTrainer extends DatasetTrainer<IgniteModel<Vector, Double>, Double> { /** * Function specifying which entries to count. */ @@ -161,7 +161,7 @@ public class BaggingTest extends TrainerTest { } /** {@inheritDoc} */ - @Override public <K, V> Model<Vector, Double> fit( + @Override public <K, V> IgniteModel<Vector, Double> fit( DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { @@ -177,13 +177,13 @@ public class BaggingTest extends TrainerTest { } /** {@inheritDoc} */ - @Override protected boolean checkState(Model<Vector, Double> mdl) { + @Override protected boolean checkState(IgniteModel<Vector, Double> mdl) { return true; } /** {@inheritDoc} */ - @Override protected <K, V> Model<Vector, Double> updateModel( - Model<Vector, Double> mdl, + @Override protected <K, V> IgniteModel<Vector, Double> updateModel( + IgniteModel<Vector, Double> mdl, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { return fit(datasetBuilder, featureExtractor, lbExtractor); http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java index 3336470..5cb2fe1 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java @@ -18,7 +18,7 @@ package org.apache.ignite.ml.composition; import java.util.Arrays; -import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.IgniteModel; import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.composition.stacking.StackedDatasetTrainer; @@ -107,10 +107,10 @@ public class StackingTest extends TrainerTest { (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[v.length - 1]); - assertEquals(0.0 * factor, mdl.apply(VectorUtils.of(0.0, 0.0)), 0.3); - assertEquals(1.0 * factor, mdl.apply(VectorUtils.of(0.0, 1.0)), 0.3); - assertEquals(1.0 * factor, mdl.apply(VectorUtils.of(1.0, 0.0)), 0.3); - assertEquals(0.0 * factor, mdl.apply(VectorUtils.of(1.0, 1.0)), 0.3); + assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(0.0, 0.0)), 0.3); + assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(0.0, 1.0)), 0.3); + assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(1.0, 0.0)), 0.3); + assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(1.0, 1.0)), 0.3); } /** @@ -152,10 +152,10 @@ public class StackingTest extends TrainerTest { (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[v.length - 1]); - assertEquals(0.0 * factor, mdl.apply(VectorUtils.of(0.0, 0.0)), 0.3); - assertEquals(1.0 * factor, mdl.apply(VectorUtils.of(0.0, 1.0)), 0.3); - assertEquals(1.0 * factor, mdl.apply(VectorUtils.of(1.0, 0.0)), 0.3); - assertEquals(0.0 * factor, mdl.apply(VectorUtils.of(1.0, 1.0)), 0.3); + assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(0.0, 0.0)), 0.3); + assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(0.0, 1.0)), 0.3); + assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(1.0, 0.0)), 0.3); + assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(1.0, 1.0)), 0.3); } /** @@ -164,7 +164,7 @@ public class StackingTest extends TrainerTest { */ @Test public void testINoWaysOfPropagation() { - StackedDatasetTrainer<Void, Void, Void, Model<Void, Void>, Void> trainer = + StackedDatasetTrainer<Void, Void, Void, IgniteModel<Void, Void>, Void> trainer = new StackedDatasetTrainer<>(); thrown.expect(IllegalStateException.class); trainer.fit(null, null, null); http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java index 4958b4b..f6fd0c4 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java @@ -20,7 +20,7 @@ package org.apache.ignite.ml.composition.boosting; import java.util.HashMap; import java.util.Map; import java.util.function.BiFunction; -import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.IgniteModel; import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.composition.ModelsComposition; import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory; @@ -60,7 +60,7 @@ public class GDBTrainerTest extends TrainerTest { GDBTrainer trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 3, 0.0) .withUsingIdx(true); - Model<Vector, Double> mdl = trainer.fit( + IgniteModel<Vector, Double> mdl = trainer.fit( learningSample, 1, (k, v) -> VectorUtils.of(v[0]), (k, v) -> v[1] @@ -70,7 +70,7 @@ public class GDBTrainerTest extends TrainerTest { for (int j = 0; j < size; j++) { double x = xs[j]; double y = ys[j]; - double p = mdl.apply(VectorUtils.of(x)); + double p = mdl.predict(VectorUtils.of(x)); mse += Math.pow(y - p, 2); } mse /= size; @@ -117,7 +117,7 @@ public class GDBTrainerTest extends TrainerTest { /** */ private void testClassifier(BiFunction<GDBTrainer, Map<Integer, double[]>, - Model<Vector, Double>> fitter) { + IgniteModel<Vector, Double>> fitter) { int sampleSize = 100; double[] xs = new double[sampleSize]; double[] ys = new double[sampleSize]; @@ -135,13 +135,13 @@ public class GDBTrainerTest extends TrainerTest { .withUsingIdx(true) .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.3)); - Model<Vector, Double> mdl = fitter.apply(trainer, learningSample); + IgniteModel<Vector, Double> mdl = fitter.apply(trainer, learningSample); int errorsCnt = 0; for (int j = 0; j < sampleSize; j++) { double x = xs[j]; double y = ys[j]; - double p = mdl.apply(VectorUtils.of(x)); + double p = mdl.predict(VectorUtils.of(x)); if (p != y) errorsCnt++; } @@ -201,9 +201,9 @@ public class GDBTrainerTest extends TrainerTest { dataset.forEach((k,v) -> { Vector features = fExtr.apply(k, v); - Double originalAnswer = originalMdl.apply(features); - Double updatedMdlAnswer1 = updatedOnSameDataset.apply(features); - Double updatedMdlAnswer2 = updatedOnEmptyDataset.apply(features); + Double originalAnswer = originalMdl.predict(features); + Double updatedMdlAnswer1 = updatedOnSameDataset.predict(features); + Double updatedMdlAnswer2 = updatedOnEmptyDataset.predict(features); assertEquals(originalAnswer, updatedMdlAnswer1, 0.01); assertEquals(originalAnswer, updatedMdlAnswer2, 0.01); http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerTest.java index 50fdf8b..abc24e7 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerTest.java @@ -32,14 +32,14 @@ import org.junit.Before; public abstract class ConvergenceCheckerTest { /** Not converged model. */ protected ModelsComposition notConvergedMdl = new ModelsComposition(Collections.emptyList(), null) { - @Override public Double apply(Vector features) { + @Override public Double predict(Vector features) { return 2.1 * features.get(0); } }; /** Converged model. */ protected ModelsComposition convergedMdl = new ModelsComposition(Collections.emptyList(), null) { - @Override public Double apply(Vector features) { + @Override public Double predict(Vector features) { return 2 * (features.get(0) + 1); } }; http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java index 4b44196..d253ea0 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java @@ -21,7 +21,7 @@ import java.util.Map; import java.util.Random; import java.util.stream.Collectors; import java.util.stream.IntStream; -import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.IgniteModel; import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; @@ -85,9 +85,9 @@ public class LearningEnvironmentTest { int partitions = 10; int iterations = 2; - DatasetTrainer<Model<Object, Vector>, Void> trainer = new DatasetTrainer<Model<Object, Vector>, Void>() { + DatasetTrainer<IgniteModel<Object, Vector>, Void> trainer = new DatasetTrainer<IgniteModel<Object, Vector>, Void>() { /** {@inheritDoc} */ - @Override public <K, V> Model<Object, Vector> fit(DatasetBuilder<K, V> datasetBuilder, + @Override public <K, V> IgniteModel<Object, Vector> fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Void> lbExtractor) { Dataset<EmptyContext, TestUtils.DataWrapper<Integer>> ds = datasetBuilder.build(envBuilder, new EmptyContextBuilder<>(), @@ -103,26 +103,26 @@ public class LearningEnvironmentTest { } /** {@inheritDoc} */ - @Override protected boolean checkState(Model<Object, Vector> mdl) { + @Override protected boolean checkState(IgniteModel<Object, Vector> mdl) { return false; } /** {@inheritDoc} */ - @Override protected <K, V> Model<Object, Vector> updateModel(Model<Object, Vector> mdl, + @Override protected <K, V> IgniteModel<Object, Vector> updateModel(IgniteModel<Object, Vector> mdl, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Void> lbExtractor) { return null; } }; trainer.withEnvironmentBuilder(envBuilder); - Model<Object, Vector> mdl = trainer.fit(getCacheMock(partitions), partitions, null, null); + IgniteModel<Object, Vector> mdl = trainer.fit(getCacheMock(partitions), partitions, null, null); Vector exp = VectorUtils.zeroes(partitions); for (int i = 0; i < partitions; i++) exp.set(i, i * iterations); - Vector res = mdl.apply(null); + Vector res = mdl.predict(null); assertEquals(exp, res); } http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/inference/InferenceTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/inference/InferenceTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/inference/InferenceTestSuite.java index c670629..7364216 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/inference/InferenceTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/inference/InferenceTestSuite.java @@ -17,9 +17,11 @@ package org.apache.ignite.ml.inference; -import org.apache.ignite.ml.inference.builder.IgniteDistributedInfModelBuilderTest; -import org.apache.ignite.ml.inference.builder.SingleInfModelBuilderTest; -import org.apache.ignite.ml.inference.builder.ThreadedInfModelBuilderTest; +import junit.framework.JUnit4TestAdapter; +import junit.framework.TestSuite; +import org.apache.ignite.ml.inference.builder.IgniteDistributedModelBuilderTest; +import org.apache.ignite.ml.inference.builder.SingleModelBuilderTest; +import org.apache.ignite.ml.inference.builder.ThreadedModelBuilderTest; import org.apache.ignite.ml.inference.storage.model.DefaultModelStorageTest; import org.apache.ignite.ml.inference.util.DirectorySerializerTest; import org.junit.runner.RunWith; @@ -30,11 +32,23 @@ import org.junit.runners.Suite; */ @RunWith(Suite.class) @Suite.SuiteClasses({ - SingleInfModelBuilderTest.class, - ThreadedInfModelBuilderTest.class, + SingleModelBuilderTest.class, + ThreadedModelBuilderTest.class, DirectorySerializerTest.class, DefaultModelStorageTest.class, - IgniteDistributedInfModelBuilderTest.class + IgniteDistributedModelBuilderTest.class }) public class InferenceTestSuite { + /** */ + public static TestSuite suite() { + TestSuite suite = new TestSuite(InferenceTestSuite.class.getSimpleName()); + + suite.addTest(new JUnit4TestAdapter(SingleModelBuilderTest.class)); + suite.addTest(new JUnit4TestAdapter(ThreadedModelBuilderTest.class)); + suite.addTest(new JUnit4TestAdapter(DirectorySerializerTest.class)); + suite.addTest(new JUnit4TestAdapter(DefaultModelStorageTest.class)); + suite.addTest(new JUnit4TestAdapter(IgniteDistributedModelBuilderTest.class)); + + return suite; + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/IgniteDistributedInfModelBuilderTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/IgniteDistributedInfModelBuilderTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/IgniteDistributedInfModelBuilderTest.java deleted file mode 100644 index b17a3fd..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/IgniteDistributedInfModelBuilderTest.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.inference.builder; - -import java.util.concurrent.Future; -import org.apache.ignite.Ignite; -import org.apache.ignite.internal.util.IgniteUtils; -import org.apache.ignite.ml.inference.InfModel; -import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; -import org.junit.Test; - -/** - * Tests for {@link IgniteDistributedInfModelBuilder} class. - */ -public class IgniteDistributedInfModelBuilderTest extends GridCommonAbstractTest { - /** Number of nodes in grid */ - private static final int NODE_COUNT = 3; - - /** Ignite instance. */ - private Ignite ignite; - - /** {@inheritDoc} */ - @Override protected void beforeTestsStarted() throws Exception { - for (int i = 1; i <= NODE_COUNT; i++) - startGrid(i); - } - - /** {@inheritDoc} */ - @Override protected void afterTestsStopped() { - stopAllGrids(); - } - - /** - * {@inheritDoc} - */ - @Override protected void beforeTest() { - /* Grid instance. */ - ignite = grid(NODE_COUNT); - ignite.configuration().setPeerClassLoadingEnabled(true); - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - } - - /** */ - @Test - public void testBuild() { - AsyncInfModelBuilder mdlBuilder = new IgniteDistributedInfModelBuilder(ignite, 1, 1); - - InfModel<Integer, Future<Integer>> infMdl = mdlBuilder.build( - InfModelBuilderTestUtil.getReader(), - InfModelBuilderTestUtil.getParser() - ); - - // TODO: IGNITE-10250: Test hangs sometimes because of Ignite queue issue. - // for (int i = 0; i < 100; i++) - // assertEquals(Integer.valueOf(i), infMdl.predict(i).get()); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/IgniteDistributedModelBuilderTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/IgniteDistributedModelBuilderTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/IgniteDistributedModelBuilderTest.java new file mode 100644 index 0000000..2755dcd --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/IgniteDistributedModelBuilderTest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.builder; + +import java.util.concurrent.Future; +import org.apache.ignite.Ignite; +import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.inference.Model; +import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; +import org.junit.Test; + +/** + * Tests for {@link IgniteDistributedModelBuilder} class. + */ +public class IgniteDistributedModelBuilderTest extends GridCommonAbstractTest { + /** Number of nodes in grid */ + private static final int NODE_COUNT = 3; + + /** Ignite instance. */ + private Ignite ignite; + + /** {@inheritDoc} */ + @Override protected void beforeTestsStarted() throws Exception { + for (int i = 1; i <= NODE_COUNT; i++) + startGrid(i); + } + + /** {@inheritDoc} */ + @Override protected void afterTestsStopped() { + stopAllGrids(); + } + + /** + * {@inheritDoc} + */ + @Override protected void beforeTest() { + /* Grid instance. */ + ignite = grid(NODE_COUNT); + ignite.configuration().setPeerClassLoadingEnabled(true); + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + } + + /** */ + @Test + public void testBuild() { + AsyncModelBuilder mdlBuilder = new IgniteDistributedModelBuilder(ignite, 1, 1); + + Model<Integer, Future<Integer>> infMdl = mdlBuilder.build( + ModelBuilderTestUtil.getReader(), + ModelBuilderTestUtil.getParser() + ); + + // TODO: IGNITE-10250: Test hangs sometimes because of Ignite queue issue. + // for (int i = 0; i < 100; i++) + // assertEquals(Integer.valueOf(i), infMdl.predict(i).get()); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/InfModelBuilderTestUtil.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/InfModelBuilderTestUtil.java b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/InfModelBuilderTestUtil.java deleted file mode 100644 index 6b20fc1..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/InfModelBuilderTestUtil.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.inference.builder; - -import org.apache.ignite.ml.inference.InfModel; -import org.apache.ignite.ml.inference.parser.InfModelParser; -import org.apache.ignite.ml.inference.reader.InfModelReader; - -/** - * Util class for model builder tests. - */ -class InfModelBuilderTestUtil { - /** - * Creates dummy model reader used in tests. - * - * @return Dummy model reader used in tests. - */ - static InfModelReader getReader() { - return () -> new byte[0]; - } - - /** - * Creates dummy model parser used in tests. - * - * @return Dummy model parser used in tests. - */ - static InfModelParser<Integer, Integer, InfModel<Integer, Integer>> getParser() { - return m -> new InfModel<Integer, Integer>() { - @Override public Integer apply(Integer input) { - return input; - } - - @Override public void close() { - // Do nothing. - } - }; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/ModelBuilderTestUtil.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/ModelBuilderTestUtil.java b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/ModelBuilderTestUtil.java new file mode 100644 index 0000000..4ff501a --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/ModelBuilderTestUtil.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.builder; + +import org.apache.ignite.ml.inference.Model; +import org.apache.ignite.ml.inference.parser.ModelParser; +import org.apache.ignite.ml.inference.reader.ModelReader; + +/** + * Util class for model builder tests. + */ +class ModelBuilderTestUtil { + /** + * Creates dummy model reader used in tests. + * + * @return Dummy model reader used in tests. + */ + static ModelReader getReader() { + return () -> new byte[0]; + } + + /** + * Creates dummy model parser used in tests. + * + * @return Dummy model parser used in tests. + */ + static ModelParser<Integer, Integer, Model<Integer, Integer>> getParser() { + return m -> new Model<Integer, Integer>() { + @Override public Integer predict(Integer input) { + return input; + } + + @Override public void close() { + // Do nothing. + } + }; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/SingleInfModelBuilderTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/SingleInfModelBuilderTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/SingleInfModelBuilderTest.java deleted file mode 100644 index b0bae25..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/SingleInfModelBuilderTest.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.inference.builder; - -import org.apache.ignite.ml.inference.InfModel; -import org.junit.Test; - -import static org.junit.Assert.assertEquals; - -/** - * Tests for {@link SingleInfModelBuilder}. - */ -public class SingleInfModelBuilderTest { - /** */ - @Test - public void testBuild() { - SyncInfModelBuilder mdlBuilder = new SingleInfModelBuilder(); - - InfModel<Integer, Integer> infMdl = mdlBuilder.build( - InfModelBuilderTestUtil.getReader(), - InfModelBuilderTestUtil.getParser() - ); - - for (int i = 0; i < 100; i++) - assertEquals(Integer.valueOf(i), infMdl.apply(i)); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/SingleModelBuilderTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/SingleModelBuilderTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/SingleModelBuilderTest.java new file mode 100644 index 0000000..c09e997 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/SingleModelBuilderTest.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.builder; + +import org.apache.ignite.ml.inference.Model; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link SingleModelBuilder}. + */ +public class SingleModelBuilderTest { + /** */ + @Test + public void testBuild() { + SyncModelBuilder mdlBuilder = new SingleModelBuilder(); + + Model<Integer, Integer> infMdl = mdlBuilder.build( + ModelBuilderTestUtil.getReader(), + ModelBuilderTestUtil.getParser() + ); + + for (int i = 0; i < 100; i++) + assertEquals(Integer.valueOf(i), infMdl.predict(i)); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/ThreadedInfModelBuilderTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/ThreadedInfModelBuilderTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/ThreadedInfModelBuilderTest.java deleted file mode 100644 index b4207f5..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/ThreadedInfModelBuilderTest.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.inference.builder; - -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Future; -import org.apache.ignite.ml.inference.InfModel; -import org.junit.Test; - -import static org.junit.Assert.assertEquals; - -/** - * Tests for {@link ThreadedInfModelBuilder} class. - */ -public class ThreadedInfModelBuilderTest { - /** */ - @Test - public void testBuild() throws ExecutionException, InterruptedException { - AsyncInfModelBuilder mdlBuilder = new ThreadedInfModelBuilder(10); - - InfModel<Integer, Future<Integer>> infMdl = mdlBuilder.build( - InfModelBuilderTestUtil.getReader(), - InfModelBuilderTestUtil.getParser() - ); - - for (int i = 0; i < 100; i++) - assertEquals(Integer.valueOf(i), infMdl.apply(i).get()); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/ThreadedModelBuilderTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/ThreadedModelBuilderTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/ThreadedModelBuilderTest.java new file mode 100644 index 0000000..46862cc --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/ThreadedModelBuilderTest.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.inference.builder; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import org.apache.ignite.ml.inference.Model; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link ThreadedModelBuilder} class. + */ +public class ThreadedModelBuilderTest { + /** */ + @Test + public void testBuild() throws ExecutionException, InterruptedException { + AsyncModelBuilder mdlBuilder = new ThreadedModelBuilder(10); + + Model<Integer, Future<Integer>> infMdl = mdlBuilder.build( + ModelBuilderTestUtil.getReader(), + ModelBuilderTestUtil.getParser() + ); + + for (int i = 0; i < 100; i++) + assertEquals(Integer.valueOf(i), infMdl.predict(i).get()); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java index 6fe8a63..3683fbb 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java @@ -60,7 +60,7 @@ public class KNNClassificationTest { /** */ @Test(expected = IllegalStateException.class) public void testNullDataset() { - new KNNClassificationModel(null).apply(null); + new KNNClassificationModel(null).predict(null); } /** */ @@ -90,9 +90,9 @@ public class KNNClassificationTest { assertTrue(!knnMdl.toString(false).isEmpty()); Vector firstVector = new DenseVector(new double[] {2.0, 2.0}); - assertEquals(knnMdl.apply(firstVector), 1.0); + assertEquals(knnMdl.predict(firstVector), 1.0); Vector secondVector = new DenseVector(new double[] {-2.0, -2.0}); - assertEquals(knnMdl.apply(secondVector), 2.0); + assertEquals(knnMdl.predict(secondVector), 2.0); } /** */ @@ -118,9 +118,9 @@ public class KNNClassificationTest { .withStrategy(NNStrategy.SIMPLE); Vector firstVector = new DenseVector(new double[] {2.0, 2.0}); - assertEquals(knnMdl.apply(firstVector), 1.0); + assertEquals(knnMdl.predict(firstVector), 1.0); Vector secondVector = new DenseVector(new double[] {-2.0, -2.0}); - assertEquals(knnMdl.apply(secondVector), 2.0); + assertEquals(knnMdl.predict(secondVector), 2.0); } /** */ @@ -146,7 +146,7 @@ public class KNNClassificationTest { .withStrategy(NNStrategy.SIMPLE); Vector vector = new DenseVector(new double[] {-1.01, -1.01}); - assertEquals(knnMdl.apply(vector), 2.0); + assertEquals(knnMdl.predict(vector), 2.0); } /** */ @@ -172,7 +172,7 @@ public class KNNClassificationTest { .withStrategy(NNStrategy.WEIGHTED); Vector vector = new DenseVector(new double[] {-1.01, -1.01}); - assertEquals(knnMdl.apply(vector), 1.0); + assertEquals(knnMdl.predict(vector), 1.0); } /** */ @@ -210,7 +210,7 @@ public class KNNClassificationTest { ); Vector vector = new DenseVector(new double[] {-1.01, -1.01}); - assertEquals(originalMdl.apply(vector), updatedOnSameDataset.apply(vector)); - assertEquals(originalMdl.apply(vector), updatedOnEmptyDataset.apply(vector)); + assertEquals(originalMdl.predict(vector), updatedOnSameDataset.predict(vector)); + assertEquals(originalMdl.predict(vector), updatedOnEmptyDataset.predict(vector)); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java index 9ff0bc2..75ab551 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java @@ -60,8 +60,8 @@ public class KNNRegressionTest extends TrainerTest { .withStrategy(NNStrategy.SIMPLE); Vector vector = new DenseVector(new double[] {0, 0, 0, 5.0, 0.0}); - System.out.println(knnMdl.apply(vector)); - Assert.assertEquals(15, knnMdl.apply(vector), 1E-12); + System.out.println(knnMdl.predict(vector)); + Assert.assertEquals(15, knnMdl.predict(vector), 1E-12); } /** */ @@ -107,9 +107,9 @@ public class KNNRegressionTest extends TrainerTest { Vector vector = new DenseVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956}); - Assert.assertNotNull(knnMdl.apply(vector)); + Assert.assertNotNull(knnMdl.predict(vector)); - Assert.assertEquals(67857, knnMdl.apply(vector), 2000); + Assert.assertEquals(67857, knnMdl.predict(vector), 2000); Assert.assertTrue(knnMdl.toString().contains(stgy.name())); Assert.assertTrue(knnMdl.toString(true).contains(stgy.name())); @@ -150,7 +150,7 @@ public class KNNRegressionTest extends TrainerTest { ); Vector vector = new DenseVector(new double[] {0, 0, 0, 5.0, 0.0}); - assertEquals(originalMdl.apply(vector), updatedOnSameDataset.apply(vector)); - assertEquals(originalMdl.apply(vector), updatedOnEmptyDataset.apply(vector)); + assertEquals(originalMdl.predict(vector), updatedOnSameDataset.predict(vector)); + assertEquals(originalMdl.predict(vector), updatedOnEmptyDataset.predict(vector)); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java index 74841a3..b2728c5 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java @@ -68,8 +68,8 @@ public class OneVsRestTrainerTest extends TrainerTest { Assert.assertTrue(mdl.toString(true).length() > 0); Assert.assertTrue(mdl.toString(false).length() > 0); - TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(-100, 0)), PRECISION); - TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 0)), PRECISION); + TestUtils.assertEquals(1, mdl.predict(VectorUtils.of(-100, 0)), PRECISION); + TestUtils.assertEquals(0, mdl.predict(VectorUtils.of(100, 0)), PRECISION); } /** */ @@ -119,8 +119,8 @@ public class OneVsRestTrainerTest extends TrainerTest { ); for (Vector vec : vectors) { - TestUtils.assertEquals(originalMdl.apply(vec), updatedOnSameDS.apply(vec), PRECISION); - TestUtils.assertEquals(originalMdl.apply(vec), updatedOnEmptyDS.apply(vec), PRECISION); + TestUtils.assertEquals(originalMdl.predict(vec), updatedOnSameDS.predict(vec), PRECISION); + TestUtils.assertEquals(originalMdl.predict(vec), updatedOnEmptyDS.predict(vec), PRECISION); } } } http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java index f6b947b..41d320d 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java @@ -39,7 +39,7 @@ public class DiscreteNaiveBayesModelTest { DiscreteNaiveBayesModel mdl = new DiscreteNaiveBayesModel(probabilities, classProbabilities, new double[] {first, second}, thresholds, new DiscreteNaiveBayesSumsHolder()); Vector observation = VectorUtils.of(2, 0, 1, 2, 0); - Assert.assertEquals(second, mdl.apply(observation), 0.0001); + Assert.assertEquals(second, mdl.predict(observation), 0.0001); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTest.java index 25fb37b..6c0e323 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTest.java @@ -62,6 +62,6 @@ public class DiscreteNaiveBayesTest { ); Vector observation = VectorUtils.of(1, 0, 1, 1, 0); - Assert.assertEquals(scottish, model.apply(observation), PRECISION); + Assert.assertEquals(scottish, model.predict(observation), PRECISION); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java index 7592811..d35ea3d 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java @@ -43,7 +43,7 @@ public class GaussianNaiveBayesModelTest { GaussianNaiveBayesModel mdl = new GaussianNaiveBayesModel(means, variances, probabilities, new double[] {first, second}, null); Vector observation = VectorUtils.of(6, 130, 8); - Assert.assertEquals(second, mdl.apply(observation), 0.0001); + Assert.assertEquals(second, mdl.predict(observation), 0.0001); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTest.java index 504b464..fd95a4e 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTest.java @@ -57,7 +57,7 @@ public class GaussianNaiveBayesTest { ); Vector observation = VectorUtils.of(6, 130, 8); - Assert.assertEquals(female, model.apply(observation), PRECISION); + Assert.assertEquals(female, model.predict(observation), PRECISION); } /** Dataset from Gaussian NB example in the scikit-learn documentation */ @@ -80,7 +80,7 @@ public class GaussianNaiveBayesTest { ); Vector observation = VectorUtils.of(-0.8, -1); - Assert.assertEquals(one, model.apply(observation), PRECISION); + Assert.assertEquals(one, model.predict(observation), PRECISION); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java index 64ea9d3..348b7c9 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java @@ -84,8 +84,8 @@ public class GaussianNaiveBayesTrainerTest extends TrainerTest { (k, v) -> v[0] ); - TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), PRECISION); - TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION); + TestUtils.assertEquals(0, mdl.predict(VectorUtils.of(100, 10)), PRECISION); + TestUtils.assertEquals(1, mdl.predict(VectorUtils.of(10, 100)), PRECISION); } /** */ http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java index 0f15dda..0a7d9a7 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java @@ -44,7 +44,7 @@ public class MLPTest { int input = 2; - Matrix predict = mlp.apply(new DenseMatrix(new double[][] {{input}})); + Matrix predict = mlp.predict(new DenseMatrix(new double[][] {{input}})); Assert.assertEquals(predict, new DenseMatrix(new double[][] {{Activators.SIGMOID.apply(input)}})); } @@ -68,7 +68,7 @@ public class MLPTest { Matrix input = new DenseMatrix(new double[][] {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}}); - Matrix predict = mlp.apply(input); + Matrix predict = mlp.predict(input); Matrix truth = new DenseMatrix(new double[][] {{0.0}, {1.0}, {1.0}, {0.0}}); TestUtils.checkIsInEpsilonNeighbourhood(predict.getRow(0), truth.getRow(0), 1E-4); @@ -99,8 +99,8 @@ public class MLPTest { MultilayerPerceptron stackedMLP = mlp1.add(mlp2); - Matrix predict = mlp.apply(new DenseMatrix(new double[][] {{1}, {2}, {3}, {4}}).transpose()); - Matrix stackedPredict = stackedMLP.apply(new DenseMatrix(new double[][] {{1}, {2}, {3}, {4}}).transpose()); + Matrix predict = mlp.predict(new DenseMatrix(new double[][] {{1}, {2}, {3}, {4}}).transpose()); + Matrix stackedPredict = stackedMLP.predict(new DenseMatrix(new double[][] {{1}, {2}, {3}, {4}}).transpose()); Assert.assertEquals(predict, stackedPredict); } http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java index ff6754a..3c2d64b 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java @@ -149,7 +149,7 @@ public class MLPTrainerIntegrationTest extends GridCommonAbstractTest { (k, v) -> new double[]{ v.lb} ); - Matrix predict = mlp.apply(new DenseMatrix(new double[][]{ + Matrix predict = mlp.predict(new DenseMatrix(new double[][]{ {0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java index 6a6555e..053fe33 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java @@ -146,7 +146,7 @@ public class MLPTrainerTest { (k, v) -> v[1] ); - Matrix predict = mlp.apply(new DenseMatrix(new double[][]{ + Matrix predict = mlp.predict(new DenseMatrix(new double[][]{ {0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, @@ -215,8 +215,8 @@ public class MLPTrainerTest { {1.0, 1.0} }); - TestUtils.checkIsInEpsilonNeighbourhood(originalMdl.apply(matrix).getRow(0), updatedOnSameDS.apply(matrix).getRow(0), 1E-1); - TestUtils.checkIsInEpsilonNeighbourhood(originalMdl.apply(matrix).getRow(0), updatedOnEmptyDS.apply(matrix).getRow(0), 1E-1); + TestUtils.checkIsInEpsilonNeighbourhood(originalMdl.predict(matrix).getRow(0), updatedOnSameDS.predict(matrix).getRow(0), 1E-1); + TestUtils.checkIsInEpsilonNeighbourhood(originalMdl.predict(matrix).getRow(0), updatedOnEmptyDS.predict(matrix).getRow(0), 1E-1); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java index e406c31..6a6fec2 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java @@ -120,7 +120,7 @@ public class MLPTrainerMnistIntegrationTest extends GridCommonAbstractTest { for (MnistUtils.MnistLabeledImage e : MnistMLPTestUtil.loadTestSet(1_000)) { Matrix input = new DenseMatrix(new double[][]{e.getPixels()}); - Matrix outputMatrix = mdl.apply(input); + Matrix outputMatrix = mdl.predict(input); int predicted = (int) VectorUtils.vec2Num(outputMatrix.getRow(0)); http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java index 9396009..a21320a 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java @@ -85,7 +85,7 @@ public class MLPTrainerMnistTest { for (MnistUtils.MnistLabeledImage e : MnistMLPTestUtil.loadTestSet(10_000)) { Matrix input = new DenseMatrix(new double[][]{e.getPixels()}); - Matrix outputMatrix = mdl.apply(input); + Matrix outputMatrix = mdl.predict(input); int predicted = (int) VectorUtils.vec2Num(outputMatrix.getRow(0)); http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java index 8445900..3438e4b 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java @@ -53,19 +53,19 @@ public class PipelineMdlTest { /** */ private void verifyPredict(PipelineMdl mdl) { Vector observation = new DenseVector(new double[] {1.0, 1.0}); - TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION); + TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 + 3.0 * 1.0), mdl.predict(observation), PRECISION); observation = new DenseVector(new double[] {2.0, 1.0}); - TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 2.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION); + TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 2.0 + 3.0 * 1.0), mdl.predict(observation), PRECISION); observation = new DenseVector(new double[] {1.0, 2.0}); - TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 + 3.0 * 2.0), mdl.apply(observation), PRECISION); + TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 + 3.0 * 2.0), mdl.predict(observation), PRECISION); observation = new DenseVector(new double[] {-2.0, 1.0}); - TestUtils.assertEquals(sigmoid(1.0 - 2.0 * 2.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION); + TestUtils.assertEquals(sigmoid(1.0 - 2.0 * 2.0 + 3.0 * 1.0), mdl.predict(observation), PRECISION); observation = new DenseVector(new double[] {1.0, -2.0}); - TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 - 3.0 * 2.0), mdl.apply(observation), PRECISION); + TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 - 3.0 * 2.0), mdl.predict(observation), PRECISION); } /** http://git-wip-us.apache.org/repos/asf/ignite/blob/2dc0d9f7/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java index 694dcd3..98de92b 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java @@ -71,8 +71,8 @@ public class PipelineTest extends TrainerTest { parts ); - TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), PRECISION); - TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION); + TestUtils.assertEquals(0, mdl.predict(VectorUtils.of(100, 10)), PRECISION); + TestUtils.assertEquals(1, mdl.predict(VectorUtils.of(10, 100)), PRECISION); } /** @@ -101,7 +101,7 @@ public class PipelineTest extends TrainerTest { parts ); - TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), PRECISION); - TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION); + TestUtils.assertEquals(0, mdl.predict(VectorUtils.of(100, 10)), PRECISION); + TestUtils.assertEquals(1, mdl.predict(VectorUtils.of(10, 100)), PRECISION); } }