http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java index f555e09..8918450 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java @@ -29,10 +29,10 @@ import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.selection.scoring.metric.Metric; import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursor; -import org.apache.ignite.ml.selection.scoring.cursor.LocalLabelPairCursor; import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor; +import org.apache.ignite.ml.selection.scoring.cursor.LocalLabelPairCursor; +import org.apache.ignite.ml.selection.scoring.metric.Metric; import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper; import org.apache.ignite.ml.selection.split.mapper.UniformMapper; import org.apache.ignite.ml.trainers.DatasetTrainer; @@ -66,7 +66,7 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> { * @return Array of scores of the estimator for each run of the cross validation. */ public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite, - IgniteCache<K, V> upstreamCache, IgniteBiFunction<K, V, double[]> featureExtractor, + IgniteCache<K, V> upstreamCache, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) { return score(trainer, scoreCalculator, ignite, upstreamCache, (k, v) -> true, featureExtractor, lbExtractor, new SHA256UniformMapper<>(), cv); @@ -87,7 +87,7 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> { */ public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite, IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter, - IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) { + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) { return score(trainer, scoreCalculator, ignite, upstreamCache, filter, featureExtractor, lbExtractor, new SHA256UniformMapper<>(), cv); } @@ -108,7 +108,7 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> { */ public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite, IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter, - IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, UniformMapper<K, V> mapper, int cv) { return score( @@ -146,7 +146,7 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> { * @return Array of scores of the estimator for each run of the cross validation. */ public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Map<K, V> upstreamMap, - int parts, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) { + int parts, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) { return score(trainer, scoreCalculator, upstreamMap, (k, v) -> true, parts, featureExtractor, lbExtractor, new SHA256UniformMapper<>(), cv); } @@ -165,7 +165,7 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> { * @return Array of scores of the estimator for each run of the cross validation. */ public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Map<K, V> upstreamMap, - IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, double[]> featureExtractor, + IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) { return score(trainer, scoreCalculator, upstreamMap, filter, parts, featureExtractor, lbExtractor, new SHA256UniformMapper<>(), cv); @@ -186,7 +186,7 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> { * @return Array of scores of the estimator for each run of the cross validation. */ public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Map<K, V> upstreamMap, - IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, double[]> featureExtractor, + IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, UniformMapper<K, V> mapper, int cv) { return score( trainer, @@ -226,7 +226,7 @@ public class CrossValidation<M extends Model<Vector, L>, L, K, V> { private double[] score(DatasetTrainer<M, L> trainer, Function<IgniteBiPredicate<K, V>, DatasetBuilder<K, V>> datasetBuilderSupplier, BiFunction<IgniteBiPredicate<K, V>, M, LabelPairCursor<L>> testDataIterSupplier, - IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, Metric<L> scoreCalculator, UniformMapper<K, V> mapper, int cv) { double[] scores = new double[cv];
http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursor.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursor.java index bc84743..589aecc 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursor.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursor.java @@ -26,7 +26,6 @@ import org.apache.ignite.lang.IgniteBiPredicate; import org.apache.ignite.ml.Model; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.selection.scoring.LabelPair; import org.jetbrains.annotations.NotNull; @@ -42,7 +41,7 @@ public class CacheBasedLabelPairCursor<L, K, V> implements LabelPairCursor<L> { private final QueryCursor<Cache.Entry<K, V>> cursor; /** Feature extractor. */ - private final IgniteBiFunction<K, V, double[]> featureExtractor; + private final IgniteBiFunction<K, V, Vector> featureExtractor; /** Label extractor. */ private final IgniteBiFunction<K, V, L> lbExtractor; @@ -60,7 +59,7 @@ public class CacheBasedLabelPairCursor<L, K, V> implements LabelPairCursor<L> { * @param mdl Model for inference. */ public CacheBasedLabelPairCursor(IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter, - IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, Model<Vector, L> mdl) { this.cursor = query(upstreamCache, filter); this.featureExtractor = featureExtractor; @@ -77,7 +76,7 @@ public class CacheBasedLabelPairCursor<L, K, V> implements LabelPairCursor<L> { * @param mdl Model for inference. */ public CacheBasedLabelPairCursor(IgniteCache<K, V> upstreamCache, - IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, Model<Vector, L> mdl) { this.cursor = query(upstreamCache); this.featureExtractor = featureExtractor; @@ -146,10 +145,10 @@ public class CacheBasedLabelPairCursor<L, K, V> implements LabelPairCursor<L> { @Override public LabelPair<L> next() { Cache.Entry<K, V> entry = iter.next(); - double[] features = featureExtractor.apply(entry.getKey(), entry.getValue()); + Vector features = featureExtractor.apply(entry.getKey(), entry.getValue()); L lb = lbExtractor.apply(entry.getKey(), entry.getValue()); - return new LabelPair<>(lb, mdl.apply(new DenseLocalOnHeapVector(features))); + return new LabelPair<>(lb, mdl.apply(features)); } } } http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursor.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursor.java index fbbe431..212dcd8 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursor.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursor.java @@ -24,7 +24,6 @@ import org.apache.ignite.lang.IgniteBiPredicate; import org.apache.ignite.ml.Model; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.selection.scoring.LabelPair; import org.jetbrains.annotations.NotNull; @@ -43,7 +42,7 @@ public class LocalLabelPairCursor<L, K, V, T> implements LabelPairCursor<L> { private final IgniteBiPredicate<K, V> filter; /** Feature extractor. */ - private final IgniteBiFunction<K, V, double[]> featureExtractor; + private final IgniteBiFunction<K, V, Vector> featureExtractor; /** Label extractor. */ private final IgniteBiFunction<K, V, L> lbExtractor; @@ -61,7 +60,7 @@ public class LocalLabelPairCursor<L, K, V, T> implements LabelPairCursor<L> { * @param mdl Model for inference. */ public LocalLabelPairCursor(Map<K, V> upstreamMap, IgniteBiPredicate<K, V> filter, - IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, Model<Vector, L> mdl) { this.upstreamMap = upstreamMap; this.filter = filter; @@ -114,12 +113,12 @@ public class LocalLabelPairCursor<L, K, V, T> implements LabelPairCursor<L> { K key = nextEntry.getKey(); V val = nextEntry.getValue(); - double[] features = featureExtractor.apply(key, val); + Vector features = featureExtractor.apply(key, val); L lb = lbExtractor.apply(key, val); nextEntry = null; - return new LabelPair<>(lb, mdl.apply(new DenseLocalOnHeapVector(features))); + return new LabelPair<>(lb, mdl.apply(features)); } /** http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/Evaluator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/Evaluator.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/Evaluator.java index 68eb5e6..2ee0b2d 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/Evaluator.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/Evaluator.java @@ -45,7 +45,7 @@ public class Evaluator { */ public static <L, K, V> double evaluate(IgniteCache<K, V> dataCache, Model<Vector, L> mdl, - IgniteBiFunction<K, V, double[]> featureExtractor, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, Accuracy<L> metric) { double metricRes; @@ -81,7 +81,7 @@ public class Evaluator { */ public static <L, K, V> double evaluate(IgniteCache<K, V> dataCache, IgniteBiPredicate<K, V> filter, Model<Vector, L> mdl, - IgniteBiFunction<K, V, double[]> featureExtractor, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, Accuracy<L> metric) { double metricRes; http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java index b7f62ac..00abde7 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.Iterator; import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.dataset.UpstreamEntry; +import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.structures.LabeledDataset; import org.apache.ignite.ml.structures.LabeledVector; @@ -38,7 +39,7 @@ public class LabeledDatasetPartitionDataBuilderOnHeap<K, V, C extends Serializab private static final long serialVersionUID = -7820760153954269227L; /** Extractor of X matrix row. */ - private final IgniteBiFunction<K, V, double[]> xExtractor; + private final IgniteBiFunction<K, V, Vector> xExtractor; /** Extractor of Y vector value. */ private final IgniteBiFunction<K, V, Double> yExtractor; @@ -49,7 +50,7 @@ public class LabeledDatasetPartitionDataBuilderOnHeap<K, V, C extends Serializab * @param xExtractor Extractor of X matrix row. * @param yExtractor Extractor of Y vector value. */ - public LabeledDatasetPartitionDataBuilderOnHeap(IgniteBiFunction<K, V, double[]> xExtractor, + public LabeledDatasetPartitionDataBuilderOnHeap(IgniteBiFunction<K, V, Vector> xExtractor, IgniteBiFunction<K, V, Double> yExtractor) { this.xExtractor = xExtractor; this.yExtractor = yExtractor; @@ -66,16 +67,16 @@ public class LabeledDatasetPartitionDataBuilderOnHeap<K, V, C extends Serializab while (upstreamData.hasNext()) { UpstreamEntry<K, V> entry = upstreamData.next(); - double[] row = xExtractor.apply(entry.getKey(), entry.getValue()); + Vector row = xExtractor.apply(entry.getKey(), entry.getValue()); if (xCols < 0) { - xCols = row.length; + xCols = row.size(); x = new double[Math.toIntExact(upstreamDataSize)][xCols]; } else - assert row.length == xCols : "X extractor must return exactly " + xCols + " columns"; + assert row.size() == xCols : "X extractor must return exactly " + xCols + " columns"; - x[ptr] = row; + x[ptr] = row.asArray(); y[ptr] = yExtractor.apply(entry.getKey(), entry.getValue()); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java index d56848c..10a339a 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java @@ -18,17 +18,17 @@ package org.apache.ignite.ml.svm; import java.util.concurrent.ThreadLocalRandom; -import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; -import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap; -import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.PartitionDataBuilder; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.structures.LabeledDataset; import org.apache.ignite.ml.structures.LabeledVector; +import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap; +import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; import org.jetbrains.annotations.NotNull; /** @@ -56,7 +56,7 @@ public class SVMLinearBinaryClassificationTrainer implements SingleLabelDatasetT * @return Model. */ @Override public <K, V> SVMLinearBinaryClassificationModel fit(DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { assert datasetBuilder != null; http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java index 4e081c6..8b3c9a2 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java @@ -24,14 +24,15 @@ import java.util.List; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; -import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; -import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.PartitionDataBuilder; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap; import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap; +import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; /** * Base class for a soft-margin SVM linear multiclass-classification trainer based on the communication-efficient @@ -59,7 +60,7 @@ public class SVMLinearMultiClassClassificationTrainer * @return Model. */ @Override public <K, V> SVMLinearMultiClassClassificationModel fit(DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, double[]> featureExtractor, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java index 4d7a262..f72c5ee 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java @@ -25,6 +25,7 @@ import org.apache.ignite.ml.Model; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.functions.IgniteBiFunction; /** @@ -44,7 +45,7 @@ public interface DatasetTrainer<M extends Model, L> { * @param <V> Type of a value in {@code upstream} data. * @return Model. */ - public <K, V> M fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor, + public <K, V> M fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor); /** @@ -59,7 +60,7 @@ public interface DatasetTrainer<M extends Model, L> { * @return Model. */ public default <K, V> M fit(Ignite ignite, IgniteCache<K, V> cache, - IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { return fit( new CacheBasedDatasetBuilder<>(ignite, cache), featureExtractor, @@ -80,7 +81,7 @@ public interface DatasetTrainer<M extends Model, L> { * @return Model. */ public default <K, V> M fit(Ignite ignite, IgniteCache<K, V> cache, IgniteBiPredicate<K, V> filter, - IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { return fit( new CacheBasedDatasetBuilder<>(ignite, cache, filter), featureExtractor, @@ -99,7 +100,7 @@ public interface DatasetTrainer<M extends Model, L> { * @param <V> Type of a value in {@code upstream} data. * @return Model. */ - public default <K, V> M fit(Map<K, V> data, int parts, IgniteBiFunction<K, V, double[]> featureExtractor, + public default <K, V> M fit(Map<K, V> data, int parts, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { return fit( new LocalDatasetBuilder<>(data, parts), @@ -121,7 +122,7 @@ public interface DatasetTrainer<M extends Model, L> { * @return Model. */ public default <K, V> M fit(Map<K, V> data, IgniteBiPredicate<K, V> filter, int parts, - IgniteBiFunction<K, V, double[]> featureExtractor, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { return fit( new LocalDatasetBuilder<>(data, filter, parts), http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java index a5d971f..b2dfd49 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java @@ -23,6 +23,7 @@ import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.trainers.DatasetTrainer; import org.apache.ignite.ml.tree.data.DecisionTreeData; @@ -68,7 +69,7 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> implements Data /** {@inheritDoc} */ @Override public <K, V> DecisionTreeNode fit(DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build( new EmptyContextBuilder<>(), new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor) http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java index 819af2b..eca6ac3 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.Iterator; import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.dataset.UpstreamEntry; +import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.functions.IgniteBiFunction; /** @@ -36,7 +37,7 @@ public class DecisionTreeDataBuilder<K, V, C extends Serializable> private static final long serialVersionUID = 3678784980215216039L; /** Function that extracts features from an {@code upstream} data. */ - private final IgniteBiFunction<K, V, double[]> featureExtractor; + private final IgniteBiFunction<K, V, Vector> featureExtractor; /** Function that extracts labels from an {@code upstream} data. */ private final IgniteBiFunction<K, V, Double> lbExtractor; @@ -47,7 +48,7 @@ public class DecisionTreeDataBuilder<K, V, C extends Serializable> * @param featureExtractor Function that extracts features from an {@code upstream} data. * @param lbExtractor Function that extracts labels from an {@code upstream} data. */ - public DecisionTreeDataBuilder(IgniteBiFunction<K, V, double[]> featureExtractor, + public DecisionTreeDataBuilder(IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { this.featureExtractor = featureExtractor; this.lbExtractor = lbExtractor; @@ -62,7 +63,7 @@ public class DecisionTreeDataBuilder<K, V, C extends Serializable> while (upstreamData.hasNext()) { UpstreamEntry<K, V> entry = upstreamData.next(); - features[ptr] = featureExtractor.apply(entry.getKey(), entry.getValue()); + features[ptr] = featureExtractor.apply(entry.getKey(), entry.getValue()).asArray(); labels[ptr] = lbExtractor.apply(entry.getKey(), entry.getValue()); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java index 353cc22..8a42fc0 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java @@ -31,6 +31,7 @@ import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.knn.classification.KNNClassificationModel; import org.apache.ignite.ml.knn.classification.KNNModelFormat; import org.apache.ignite.ml.knn.classification.KNNStrategy; +import org.apache.ignite.ml.math.VectorUtils; import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; @@ -153,7 +154,7 @@ public class LocalModelsTest { KMeansModel knnMdl = trainer.fit( new LocalDatasetBuilder<>(data, 2), - (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[2] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java index 846d0de..c21fbc8 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java @@ -24,6 +24,7 @@ import org.apache.ignite.ml.clustering.kmeans.KMeansModel; import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.VectorUtils; import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.junit.Test; @@ -59,7 +60,7 @@ public class KMeansTrainerTest { KMeansModel knnMdl = trainer.fit( new LocalDatasetBuilder<>(data, 2), - (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[2] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java index 40a416f..9363938 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java @@ -54,7 +54,7 @@ public class GDBTrainerTest { DatasetTrainer<Model<Vector, Double>, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 3, 0.0); Model<Vector, Double> model = trainer.fit( learningSample, 1, - (k, v) -> new double[] {v[0]}, + (k, v) -> VectorUtils.of(v[0]), (k, v) -> v[1] ); @@ -95,7 +95,7 @@ public class GDBTrainerTest { DatasetTrainer<Model<Vector, Double>, Double> trainer = new GDBBinaryClassifierOnTreesTrainer(0.3, 500, 3, 0.0); Model<Vector, Double> model = trainer.fit( learningSample, 1, - (k, v) -> new double[] {v[0]}, + (k, v) -> VectorUtils.of(v[0]), (k, v) -> v[1] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java index 004718e..f9a0c55 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java @@ -26,6 +26,7 @@ import org.apache.ignite.ml.knn.classification.KNNClassificationModel; import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer; import org.apache.ignite.ml.knn.classification.KNNStrategy; import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.VectorUtils; import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.junit.Test; @@ -71,7 +72,7 @@ public class KNNClassificationTest { KNNClassificationModel knnMdl = trainer.fit( data, parts, - (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[2] ).withK(3) .withDistanceMeasure(new EuclideanDistance()) @@ -99,7 +100,7 @@ public class KNNClassificationTest { KNNClassificationModel knnMdl = trainer.fit( data, parts, - (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[2] ).withK(1) .withDistanceMeasure(new EuclideanDistance()) @@ -127,7 +128,7 @@ public class KNNClassificationTest { KNNClassificationModel knnMdl = trainer.fit( data, parts, - (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[2] ).withK(3) .withDistanceMeasure(new EuclideanDistance()) @@ -153,7 +154,7 @@ public class KNNClassificationTest { KNNClassificationModel knnMdl = trainer.fit( data, parts, - (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[2] ).withK(3) .withDistanceMeasure(new EuclideanDistance()) http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java index 0c26ba9..d66f1f2 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java @@ -27,6 +27,7 @@ import org.apache.ignite.ml.knn.classification.KNNStrategy; import org.apache.ignite.ml.knn.regression.KNNRegressionModel; import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer; import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.VectorUtils; import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.junit.Assert; @@ -72,7 +73,7 @@ public class KNNRegressionTest { KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit( new LocalDatasetBuilder<>(data, parts), - (k, v) -> Arrays.copyOfRange(v, 1, v.length), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), (k, v) -> v[0] ).withK(1) .withDistanceMeasure(new EuclideanDistance()) @@ -107,7 +108,7 @@ public class KNNRegressionTest { KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit( new LocalDatasetBuilder<>(data, parts), - (k, v) -> Arrays.copyOfRange(v, 1, v.length), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), (k, v) -> v[0] ).withK(3) .withDistanceMeasure(new EuclideanDistance()) @@ -142,7 +143,7 @@ public class KNNRegressionTest { KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit( new LocalDatasetBuilder<>(data, parts), - (k, v) -> Arrays.copyOfRange(v, 1, v.length), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), (k, v) -> v[0] ).withK(3) .withDistanceMeasure(new EuclideanDistance()) @@ -152,4 +153,4 @@ public class KNNRegressionTest { System.out.println(knnMdl.apply(vector)); Assert.assertEquals(67857, knnMdl.apply(vector), 2000); } -} \ No newline at end of file +} http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java index bdd1eea..e64eda4 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java @@ -23,6 +23,7 @@ import java.util.Map; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetDataBuilder; +import org.apache.ignite.ml.math.VectorUtils; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -65,7 +66,7 @@ public class LSQROnHeapTest { LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>( datasetBuilder, new SimpleLabeledDatasetDataBuilder<>( - (k, v) -> Arrays.copyOf(v, v.length - 1), + (k, v) -> VectorUtils.of(Arrays.copyOf(v, v.length - 1)), (k, v) -> new double[]{v[3]} ) ); @@ -88,7 +89,7 @@ public class LSQROnHeapTest { LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>( datasetBuilder, new SimpleLabeledDatasetDataBuilder<>( - (k, v) -> Arrays.copyOf(v, v.length - 1), + (k, v) -> VectorUtils.of(Arrays.copyOf(v, v.length - 1)), (k, v) -> new double[]{v[3]} ) ); @@ -119,7 +120,7 @@ public class LSQROnHeapTest { try (LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>( datasetBuilder, new SimpleLabeledDatasetDataBuilder<>( - (k, v) -> Arrays.copyOf(v, v.length - 1), + (k, v) -> VectorUtils.of(Arrays.copyOf(v, v.length - 1)), (k, v) -> new double[]{v[4]} ) )) { http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java index 654ebe0..bac6e5f 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java @@ -17,6 +17,7 @@ package org.apache.ignite.ml.nn; +import java.io.Serializable; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; @@ -26,15 +27,19 @@ import org.apache.ignite.internal.util.typedef.X; import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Tracer; +import org.apache.ignite.ml.math.VectorUtils; import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.optimization.LossFunctions; -import org.apache.ignite.ml.optimization.updatecalculators.*; +import org.apache.ignite.ml.optimization.updatecalculators.NesterovParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.NesterovUpdateCalculator; +import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; -import java.io.Serializable; - /** * Tests for {@link MLPTrainer} that require to start the whole Ignite infrastructure. */ @@ -133,7 +138,7 @@ public class MLPTrainerIntegrationTest extends GridCommonAbstractTest { MultilayerPerceptron mlp = trainer.fit( ignite, xorCache, - (k, v) -> new double[]{ v.x, v.y }, + (k, v) -> VectorUtils.of(v.x, v.y ), (k, v) -> new double[]{ v.lb} ); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java index db14881..7f18465 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java @@ -17,25 +17,30 @@ package org.apache.ignite.ml.nn; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.VectorUtils; import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.optimization.LossFunctions; -import org.apache.ignite.ml.optimization.updatecalculators.*; +import org.apache.ignite.ml.optimization.updatecalculators.NesterovParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.NesterovUpdateCalculator; +import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; import org.junit.Before; import org.junit.Test; import org.junit.experimental.runners.Enclosed; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; -import java.io.Serializable; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - /** * Tests for {@link MLPTrainer} that don't require to start the whole Ignite infrastructure. */ @@ -136,7 +141,7 @@ public class MLPTrainerTest { MultilayerPerceptron mlp = trainer.fit( xorData, parts, - (k, v) -> v[0], + (k, v) -> VectorUtils.of(v[0]), (k, v) -> v[1] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java index 3b65a28..5a26171 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java @@ -17,6 +17,7 @@ package org.apache.ignite.ml.nn.performance; +import java.io.IOException; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; @@ -28,16 +29,14 @@ import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; import org.apache.ignite.ml.nn.Activators; import org.apache.ignite.ml.nn.MLPTrainer; import org.apache.ignite.ml.nn.MultilayerPerceptron; +import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.optimization.LossFunctions; import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; -import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.util.MnistUtils; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; -import java.io.IOException; - /** * Tests {@link MLPTrainer} on the MNIST dataset that require to start the whole Ignite infrastructure. */ @@ -106,7 +105,7 @@ public class MLPTrainerMnistIntegrationTest extends GridCommonAbstractTest { MultilayerPerceptron mdl = trainer.fit( ignite, trainingSet, - (k, v) -> v.getPixels(), + (k, v) -> VectorUtils.of(v.getPixels()), (k, v) -> VectorUtils.num2Vec(v.getLabel(), 10).getStorage().data() ); System.out.println("Training completed in " + (System.currentTimeMillis() - start) + "ms"); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java index 4063312..269082a 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java @@ -17,24 +17,23 @@ package org.apache.ignite.ml.nn.performance; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.VectorUtils; import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; import org.apache.ignite.ml.nn.Activators; import org.apache.ignite.ml.nn.MLPTrainer; import org.apache.ignite.ml.nn.MultilayerPerceptron; +import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.optimization.LossFunctions; import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; -import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.util.MnistUtils; import org.junit.Test; -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; - import static org.junit.Assert.assertTrue; /** @@ -76,7 +75,7 @@ public class MLPTrainerMnistTest { MultilayerPerceptron mdl = trainer.fit( trainingSet, 1, - (k, v) -> v.getPixels(), + (k, v) -> VectorUtils.of(v.getPixels()), (k, v) -> VectorUtils.num2Vec(v.getLabel(), 10).getStorage().data() ); System.out.println("Training completed in " + (System.currentTimeMillis() - start) + "ms"); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationPreprocessorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationPreprocessorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationPreprocessorTest.java index 2a4494a..a89b1aa 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationPreprocessorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationPreprocessorTest.java @@ -17,6 +17,7 @@ package org.apache.ignite.ml.preprocessing.binarization; +import org.apache.ignite.ml.math.VectorUtils; import org.junit.Test; import static org.junit.Assert.assertArrayEquals; @@ -36,7 +37,7 @@ public class BinarizationPreprocessorTest { BinarizationPreprocessor<Integer, double[]> preprocessor = new BinarizationPreprocessor<>( 7, - (k, v) -> v + (k, v) -> VectorUtils.of(v) ); double[][] postProcessedData = new double[][]{ @@ -46,6 +47,6 @@ public class BinarizationPreprocessorTest { }; for (int i = 0; i < data.length; i++) - assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]), 1e-8); + assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]).asArray(), 1e-8); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java index 1922307..a7317a5 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.Map; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.math.VectorUtils; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -67,9 +68,9 @@ public class BinarizationTrainerTest { BinarizationPreprocessor<Integer, double[]> preprocessor = binarizationTrainer.fit( datasetBuilder, - (k, v) -> v + (k, v) -> VectorUtils.of(v) ); - assertArrayEquals(new double[] {0, 0, 1}, preprocessor.apply(5, new double[] {1, 10, 100}), 1e-8); + assertArrayEquals(new double[] {0, 0, 1}, preprocessor.apply(5, new double[] {1, 10, 100}).asArray(), 1e-8); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java index d8c3aa0..f480209 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java @@ -69,6 +69,6 @@ public class StringEncoderPreprocessorTest { }; for (int i = 0; i < data.length; i++) - assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]), 1e-8); + assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]).asArray(), 1e-8); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java index cc79584..4f9d757 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java @@ -75,6 +75,6 @@ public class StringEncoderTrainerTest { (k, v) -> v ); - assertArrayEquals(new double[] {0.0, 2.0}, preprocessor.apply(7, new String[] {"Monday", "September"}), 1e-8); + assertArrayEquals(new double[] {0.0, 2.0}, preprocessor.apply(7, new String[] {"Monday", "September"}).asArray(), 1e-8); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerPreprocessorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerPreprocessorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerPreprocessorTest.java index f0f56d3..8482928 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerPreprocessorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerPreprocessorTest.java @@ -17,6 +17,8 @@ package org.apache.ignite.ml.preprocessing.imputing; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.VectorUtils; import org.junit.Test; import static org.junit.Assert.assertArrayEquals; @@ -34,8 +36,8 @@ public class ImputerPreprocessorTest { {Double.NaN, Double.NaN, Double.NaN}, }; - ImputerPreprocessor<Integer, double[]> preprocessor = new ImputerPreprocessor<>( - new double[]{1.1, 10.1, 100.1}, + ImputerPreprocessor<Integer, Vector> preprocessor = new ImputerPreprocessor<>( + VectorUtils.of(1.1, 10.1, 100.1), (k, v) -> v ); @@ -46,6 +48,6 @@ public class ImputerPreprocessorTest { }; for (int i = 0; i < data.length; i++) - assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]), 1e-8); + assertArrayEquals(postProcessedData[i], preprocessor.apply(i, VectorUtils.of(data[i])).asArray(), 1e-8); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainerTest.java index a4bb847..bbb9d07 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainerTest.java @@ -22,6 +22,8 @@ import java.util.HashMap; import java.util.Map; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.VectorUtils; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -54,22 +56,22 @@ public class ImputerTrainerTest { /** Tests {@code fit()} method. */ @Test public void testFit() { - Map<Integer, double[]> data = new HashMap<>(); - data.put(1, new double[] {1, 2, Double.NaN,}); - data.put(2, new double[] {1, Double.NaN, 22}); - data.put(3, new double[] {Double.NaN, 10, 100}); - data.put(4, new double[] {0, 2, 100}); + Map<Integer, Vector> data = new HashMap<>(); + data.put(1, VectorUtils.of(1, 2, Double.NaN)); + data.put(2, VectorUtils.of(1, Double.NaN, 22)); + data.put(3, VectorUtils.of(Double.NaN, 10, 100)); + data.put(4, VectorUtils.of(0, 2, 100)); - DatasetBuilder<Integer, double[]> datasetBuilder = new LocalDatasetBuilder<>(data, parts); + DatasetBuilder<Integer, Vector> datasetBuilder = new LocalDatasetBuilder<>(data, parts); - ImputerTrainer<Integer, double[]> imputerTrainer = new ImputerTrainer<Integer, double[]>() + ImputerTrainer<Integer, Vector> imputerTrainer = new ImputerTrainer<Integer, Vector>() .withImputingStrategy(ImputingStrategy.MOST_FREQUENT); - ImputerPreprocessor<Integer, double[]> preprocessor = imputerTrainer.fit( + ImputerPreprocessor<Integer, Vector> preprocessor = imputerTrainer.fit( datasetBuilder, (k, v) -> v ); - assertArrayEquals(new double[] {1, 0, 100}, preprocessor.apply(5, new double[] {Double.NaN, 0, Double.NaN}), 1e-8); + assertArrayEquals(new double[] {1, 0, 100}, preprocessor.apply(5, VectorUtils.of(Double.NaN, 0, Double.NaN)).asArray(), 1e-8); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerPreprocessorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerPreprocessorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerPreprocessorTest.java index 5ce21d4..aef1587 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerPreprocessorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerPreprocessorTest.java @@ -17,6 +17,8 @@ package org.apache.ignite.ml.preprocessing.minmaxscaling; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.VectorUtils; import org.junit.Test; import static org.junit.Assert.assertArrayEquals; @@ -35,7 +37,7 @@ public class MinMaxScalerPreprocessorTest { {0., 22., 300.} }; - MinMaxScalerPreprocessor<Integer, double[]> preprocessor = new MinMaxScalerPreprocessor<>( + MinMaxScalerPreprocessor<Integer, Vector> preprocessor = new MinMaxScalerPreprocessor<>( new double[] {0, 4, 1}, new double[] {4, 22, 300}, (k, v) -> v @@ -49,6 +51,6 @@ public class MinMaxScalerPreprocessorTest { }; for (int i = 0; i < data.length; i++) - assertArrayEquals(standardData[i], preprocessor.apply(i, data[i]), 1e-8); + assertArrayEquals(standardData[i], preprocessor.apply(i, VectorUtils.of(data[i])).asArray(), 1e-8); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainerTest.java index e411dca..8d3681b 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainerTest.java @@ -17,16 +17,17 @@ package org.apache.ignite.ml.preprocessing.minmaxscaling; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.VectorUtils; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; - import static org.junit.Assert.assertArrayEquals; /** @@ -55,17 +56,17 @@ public class MinMaxScalerTrainerTest { /** Tests {@code fit()} method. */ @Test public void testFit() { - Map<Integer, double[]> data = new HashMap<>(); - data.put(1, new double[] {2, 4, 1}); - data.put(2, new double[] {1, 8, 22}); - data.put(3, new double[] {4, 10, 100}); - data.put(4, new double[] {0, 22, 300}); + Map<Integer, Vector> data = new HashMap<>(); + data.put(1, VectorUtils.of(2, 4, 1)); + data.put(2, VectorUtils.of(1, 8, 22)); + data.put(3, VectorUtils.of(4, 10, 100)); + data.put(4, VectorUtils.of(0, 22, 300)); - DatasetBuilder<Integer, double[]> datasetBuilder = new LocalDatasetBuilder<>(data, parts); + DatasetBuilder<Integer, Vector> datasetBuilder = new LocalDatasetBuilder<>(data, parts); - MinMaxScalerTrainer<Integer, double[]> standardizationTrainer = new MinMaxScalerTrainer<>(); + MinMaxScalerTrainer<Integer, Vector> standardizationTrainer = new MinMaxScalerTrainer<>(); - MinMaxScalerPreprocessor<Integer, double[]> preprocessor = standardizationTrainer.fit( + MinMaxScalerPreprocessor<Integer, Vector> preprocessor = standardizationTrainer.fit( datasetBuilder, (k, v) -> v ); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationPreprocessorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationPreprocessorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationPreprocessorTest.java index f3bf81f..a8bfd28 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationPreprocessorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationPreprocessorTest.java @@ -17,6 +17,8 @@ package org.apache.ignite.ml.preprocessing.normalization; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.VectorUtils; import org.apache.ignite.ml.preprocessing.binarization.BinarizationPreprocessor; import org.junit.Test; @@ -35,7 +37,7 @@ public class NormalizationPreprocessorTest { {1, 0, 0}, }; - NormalizationPreprocessor<Integer, double[]> preprocessor = new NormalizationPreprocessor<>( + NormalizationPreprocessor<Integer, Vector> preprocessor = new NormalizationPreprocessor<>( 1, (k, v) -> v ); @@ -47,6 +49,6 @@ public class NormalizationPreprocessorTest { }; for (int i = 0; i < data.length; i++) - assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]), 1e-2); + assertArrayEquals(postProcessedData[i], preprocessor.apply(i, VectorUtils.of(data[i])).asArray(), 1e-2); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java index ef86b07..f6be0f5 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java @@ -22,7 +22,7 @@ import java.util.HashMap; import java.util.Map; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; -import org.apache.ignite.ml.preprocessing.binarization.BinarizationPreprocessor; +import org.apache.ignite.ml.math.VectorUtils; import org.apache.ignite.ml.preprocessing.binarization.BinarizationTrainer; import org.junit.Test; import org.junit.runner.RunWith; @@ -69,9 +69,9 @@ public class NormalizationTrainerTest { NormalizationPreprocessor<Integer, double[]> preprocessor = normalizationTrainer.fit( datasetBuilder, - (k, v) -> v + (k, v) -> VectorUtils.of(v) ); - assertArrayEquals(new double[] {0.125, 0.99, 0.125}, preprocessor.apply(5, new double[] {1, 8, 1}), 1e-2); + assertArrayEquals(new double[] {0.125, 0.99, 0.125}, preprocessor.apply(5, new double[]{1., 8., 1.}).asArray(), 1e-2); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java index ac0117d..f2f264b 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.Random; +import org.apache.ignite.ml.math.VectorUtils; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -73,7 +74,7 @@ public class LinearRegressionLSQRTrainerTest { LinearRegressionModel mdl = trainer.fit( data, parts, - (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[4] ); @@ -112,7 +113,7 @@ public class LinearRegressionLSQRTrainerTest { LinearRegressionModel mdl = trainer.fit( data, parts, - (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[coef.length] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java index c62cca5..7c3cef1 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java @@ -17,17 +17,17 @@ package org.apache.ignite.ml.regressions.linear; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import org.apache.ignite.ml.math.VectorUtils; +import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; -import org.apache.ignite.ml.nn.UpdatesStrategy; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; - import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @@ -79,7 +79,7 @@ public class LinearRegressionSGDTrainerTest { LinearRegressionModel mdl = trainer.fit( data, parts, - (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[4] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java index d26a4ca..b2d5e63 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.ThreadLocalRandom; import org.apache.ignite.ml.TestUtils; +import org.apache.ignite.ml.math.VectorUtils; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.optimization.SmoothParametrized; @@ -88,7 +89,7 @@ public class LogRegMultiClassTrainerTest { LogRegressionMultiClassModel mdl = trainer.fit( data, 10, - (k, v) -> Arrays.copyOfRange(v, 1, v.length), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), (k, v) -> v[0] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java index 27d3a30e..cbaab37 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java @@ -17,7 +17,12 @@ package org.apache.ignite.ml.regressions.logistic; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ThreadLocalRandom; import org.apache.ignite.ml.TestUtils; +import org.apache.ignite.ml.math.VectorUtils; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; @@ -28,11 +33,6 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; -import java.util.concurrent.ThreadLocalRandom; - /** * Tests for {@LogisticRegressionSGDTrainer}. */ @@ -93,7 +93,7 @@ public class LogisticRegressionSGDTrainerTest { LogisticRegressionModel mdl = trainer.fit( data, 10, - (k, v) -> Arrays.copyOfRange(v, 1, v.length), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), (k, v) -> v[0] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java index f0d9f41..cc69074 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java @@ -18,10 +18,9 @@ package org.apache.ignite.ml.selection; import org.apache.ignite.ml.selection.cv.CrossValidationTest; -import org.apache.ignite.ml.selection.scoring.metric.AccuracyTest; import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursorTest; import org.apache.ignite.ml.selection.scoring.cursor.LocalLabelPairCursorTest; -import org.apache.ignite.ml.selection.scoring.metric.Fmeasure; +import org.apache.ignite.ml.selection.scoring.metric.AccuracyTest; import org.apache.ignite.ml.selection.scoring.metric.FmeasureTest; import org.apache.ignite.ml.selection.scoring.metric.PrecisionTest; import org.apache.ignite.ml.selection.scoring.metric.RecallTest; http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java index f2fc76e..1980489 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.selection.cv; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.math.VectorUtils; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; import org.apache.ignite.ml.tree.DecisionTreeNode; @@ -51,7 +52,7 @@ public class CrossValidationTest { new Accuracy<>(), data, 1, - (k, v) -> new double[]{k}, + (k, v) -> VectorUtils.of(k), (k, v) -> v, folds ); @@ -82,7 +83,7 @@ public class CrossValidationTest { new Accuracy<>(), data, 1, - (k, v) -> new double[]{k}, + (k, v) -> VectorUtils.of(k), (k, v) -> v, folds ); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursorTest.java index 1ce10b1..7ad3998 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursorTest.java @@ -21,6 +21,7 @@ import java.util.UUID; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.math.VectorUtils; import org.apache.ignite.ml.selection.scoring.LabelPair; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; @@ -63,7 +64,7 @@ public class CacheBasedLabelPairCursorTest extends GridCommonAbstractTest { LabelPairCursor<Integer> cursor = new CacheBasedLabelPairCursor<>( data, (k, v) -> v % 2 == 0, - (k, v) -> new double[]{v}, + (k, v) -> VectorUtils.of(v), (k, v) -> v, vec -> (int)vec.get(0) ); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursorTest.java index a5a6321..f998dc9 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursorTest.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.selection.scoring.cursor; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.math.VectorUtils; import org.apache.ignite.ml.selection.scoring.LabelPair; import org.junit.Test; @@ -39,7 +40,7 @@ public class LocalLabelPairCursorTest { LabelPairCursor<Integer> cursor = new LocalLabelPairCursor<>( data, (k, v) -> v % 2 == 0, - (k, v) -> new double[]{v}, + (k, v) -> VectorUtils.of(v), (k, v) -> v, vec -> (int)vec.get(0) ); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/AccuracyTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/AccuracyTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/AccuracyTest.java index 7ebee1a..de7c68a 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/AccuracyTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/AccuracyTest.java @@ -19,8 +19,6 @@ package org.apache.ignite.ml.selection.scoring.metric; import java.util.Arrays; import org.apache.ignite.ml.selection.scoring.TestLabelPairCursor; -import org.apache.ignite.ml.selection.scoring.metric.Accuracy; -import org.apache.ignite.ml.selection.scoring.metric.Metric; import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor; import org.junit.Test; http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java index 0befd9b..d37bd47 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java @@ -17,14 +17,14 @@ package org.apache.ignite.ml.svm; -import org.apache.ignite.ml.TestUtils; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.junit.Test; - import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.concurrent.ThreadLocalRandom; +import org.apache.ignite.ml.TestUtils; +import org.apache.ignite.ml.math.VectorUtils; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.junit.Test; /** * Tests for {@link SVMLinearBinaryClassificationTrainer}. @@ -64,7 +64,7 @@ public class SVMBinaryTrainerTest { SVMLinearBinaryClassificationModel mdl = trainer.fit( data, 10, - (k, v) -> Arrays.copyOfRange(v, 1, v.length), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), (k, v) -> v[0] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java index 31ab4d7..27c0cd0 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java @@ -17,14 +17,14 @@ package org.apache.ignite.ml.svm; -import org.apache.ignite.ml.TestUtils; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.junit.Test; - import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.concurrent.ThreadLocalRandom; +import org.apache.ignite.ml.TestUtils; +import org.apache.ignite.ml.math.VectorUtils; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.junit.Test; /** * Tests for {@link SVMLinearBinaryClassificationTrainer}. @@ -67,7 +67,7 @@ public class SVMMultiClassTrainerTest { SVMLinearMultiClassClassificationModel mdl = trainer.fit( data, 10, - (k, v) -> Arrays.copyOfRange(v, 1, v.length), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), (k, v) -> v[0] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java index d5b0b86..da0a702 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java @@ -17,16 +17,16 @@ package org.apache.ignite.ml.tree; +import java.util.Arrays; +import java.util.Random; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.math.VectorUtils; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; -import java.util.Arrays; -import java.util.Random; - /** * Tests for {@link DecisionTreeClassificationTrainer} that require to start the whole Ignite infrastructure. */ @@ -79,7 +79,7 @@ public class DecisionTreeClassificationTrainerIntegrationTest extends GridCommon DecisionTreeNode tree = trainer.fit( ignite, data, - (k, v) -> Arrays.copyOf(v, v.length - 1), + (k, v) -> VectorUtils.of(Arrays.copyOf(v, v.length - 1)), (k, v) -> v[v.length - 1] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java index 12ef698..109fa6e 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java @@ -17,12 +17,17 @@ package org.apache.ignite.ml.tree; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import org.apache.ignite.ml.math.VectorUtils; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; -import java.util.*; - import static junit.framework.TestCase.assertEquals; import static junit.framework.TestCase.assertTrue; @@ -38,6 +43,7 @@ public class DecisionTreeClassificationTrainerTest { @Parameterized.Parameter public int parts; + @Parameterized.Parameters(name = "Data divided on {0} partitions") public static Iterable<Integer[]> data() { List<Integer[]> res = new ArrayList<>(); @@ -65,7 +71,7 @@ public class DecisionTreeClassificationTrainerTest { DecisionTreeNode tree = trainer.fit( data, parts, - (k, v) -> Arrays.copyOf(v, v.length - 1), + (k, v) -> VectorUtils.of(Arrays.copyOf(v, v.length - 1)), (k, v) -> v[v.length - 1] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java index c2a4638..11b75cd 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java @@ -17,16 +17,16 @@ package org.apache.ignite.ml.tree; +import java.util.Arrays; +import java.util.Random; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.math.VectorUtils; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; -import java.util.Arrays; -import java.util.Random; - /** * Tests for {@link DecisionTreeRegressionTrainer} that require to start the whole Ignite infrastructure. */ @@ -79,7 +79,7 @@ public class DecisionTreeRegressionTrainerIntegrationTest extends GridCommonAbst DecisionTreeNode tree = trainer.fit( ignite, data, - (k, v) -> Arrays.copyOf(v, v.length - 1), + (k, v) -> VectorUtils.of(Arrays.copyOf(v, v.length - 1)), (k, v) -> v[v.length - 1] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/fa56a584/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java index bcfb53f..a552f85 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java @@ -17,12 +17,17 @@ package org.apache.ignite.ml.tree; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import org.apache.ignite.ml.math.VectorUtils; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; -import java.util.*; - import static junit.framework.TestCase.assertEquals; import static junit.framework.TestCase.assertTrue; @@ -65,7 +70,7 @@ public class DecisionTreeRegressionTrainerTest { DecisionTreeNode tree = trainer.fit( data, parts, - (k, v) -> Arrays.copyOf(v, v.length - 1), + (k, v) -> VectorUtils.of(Arrays.copyOf(v, v.length - 1)), (k, v) -> v[v.length - 1] );