http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerGroupTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerGroupTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerGroupTrainer.java deleted file mode 100644 index bebfe3e..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerGroupTrainer.java +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group; - -import java.io.Serializable; -import java.util.List; -import java.util.UUID; -import java.util.stream.Stream; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.ml.Model; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.math.functions.IgniteSupplier; -import org.apache.ignite.ml.trainers.group.chain.Chains; -import org.apache.ignite.ml.trainers.group.chain.ComputationsChain; -import org.apache.ignite.ml.trainers.group.chain.EntryAndContext; -import org.apache.ignite.ml.trainers.group.chain.HasTrainingUUID; - -/** - * Group trainer using {@link Metaoptimizer}. - * Main purpose of this trainer is to extract various transformations (normalizations for example) of data which is processed - * in the training loop step into distinct entity called metaoptimizer and only fix the main part of logic in - * trainers extending this class. This way we'll be able to quickly switch between this transformations by using different metaoptimizers - * without touching main logic. - * - * @param <LC> Type of local context. - * @param <K> Type of data in {@link GroupTrainerCacheKey} keys on which the training is done. - * @param <V> Type of values of cache used in group training. - * @param <IN> Data type which is returned by distributed initializer. - * @param <R> Type of final result returned by nodes on which training is done. - * @param <I> Type of data which is fed into each training loop step and returned from it. - * @param <M> Type of model returned after training. - * @param <T> Type of input of this trainer. - * @param <G> Type of distributed context which is needed for forming final result which is send from each node to trainer for final model creation. - * @param <O> Type of output of postprocessor. - * @param <X> Type of data which is processed by dataProcessor. - * @param <Y> Type of data which is returned by postprocessor. - */ -public abstract class MetaoptimizerGroupTrainer<LC extends HasTrainingUUID, K, V, IN extends Serializable, - R extends Serializable, I extends Serializable, - M extends Model, T extends GroupTrainerInput<K>, - G, O extends Serializable, X, Y> extends - GroupTrainer<LC, K, V, IN, R, I, M, T, G> { - /** - * Metaoptimizer. - */ - private Metaoptimizer<LC, X, Y, I, IN, O> metaoptimizer; - - /** - * Construct instance of this class. - * - * @param cache Cache on which group trainer is done. - * @param ignite Ignite instance. - */ - public MetaoptimizerGroupTrainer(Metaoptimizer<LC, X, Y, I, IN, O> metaoptimizer, - IgniteCache<GroupTrainerCacheKey<K>, V> cache, - Ignite ignite) { - super(cache, ignite); - this.metaoptimizer = metaoptimizer; - } - - /** - * Get function used to map EntryAndContext to type which is processed by dataProcessor. - * - * @return Function used to map EntryAndContext to type which is processed by dataProcessor. - */ - protected abstract IgniteFunction<EntryAndContext<K, V, G>, X> trainingLoopStepDataExtractor(); - - /** - * Get supplier of keys which should be processed by training loop. - * - * @param locCtx Local text. - * @return Supplier of keys which should be processed by training loop. - */ - protected abstract IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> keysToProcessInTrainingLoop(LC locCtx); - - /** - * Get supplier of context used in training loop step. - * - * @param input Input. - * @param ctx Local context. - * @return Supplier of context used in training loop step. - */ - protected abstract IgniteSupplier<G> remoteContextExtractor(I input, LC ctx); - - /** {@inheritDoc} */ - @Override protected void init(T data, UUID trainingUUID) { - } - - /** - * Get function used to process data in training loop step. - * - * @return Function used to process data in training loop step. - */ - protected abstract IgniteFunction<X, ResultAndUpdates<Y>> dataProcessor(); - - /** {@inheritDoc} */ - @Override protected ComputationsChain<LC, K, V, I, I> trainingLoopStep() { - ComputationsChain<LC, K, V, I, O> chain = Chains.create(new MetaoptimizerDistributedStep<>(metaoptimizer, this)); - return chain.thenLocally(metaoptimizer::localProcessor); - } - - /** {@inheritDoc} */ - @Override protected I locallyProcessInitData(IN data, LC locCtx) { - return metaoptimizer.locallyProcessInitData(data, locCtx); - } - - /** {@inheritDoc} */ - @Override protected boolean shouldContinue(I data, LC locCtx) { - return metaoptimizer.shouldContinue(data, locCtx); - } - - /** {@inheritDoc} */ - @Override protected IgniteFunction<List<IN>, IN> reduceDistributedInitData() { - return metaoptimizer.initialReducer(); - } -}
http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ResultAndUpdates.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ResultAndUpdates.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ResultAndUpdates.java deleted file mode 100644 index 9ed18af..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ResultAndUpdates.java +++ /dev/null @@ -1,178 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group; - -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.concurrent.ConcurrentHashMap; -import java.util.stream.Collectors; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.ml.math.functions.IgniteFunction; - -/** - * Class containing result of computation and updates which should be made for caches. - * Purpose of this class is mainly performance optimization: suppose we have multiple computations which run in parallel - * and do some updates to caches. It is more efficient to collect all changes from all this computations and perform them - * in batch. - * - * @param <R> Type of computation result. - */ -public class ResultAndUpdates<R> { - /** - * Result of computation. - */ - private R res; - - /** - * Updates in the form cache name -> (key -> new value). - */ - private Map<String, Map> updates = new ConcurrentHashMap<>(); - - /** - * Construct an instance of this class. - * - * @param res Computation result. - */ - public ResultAndUpdates(R res) { - this.res = res; - } - - /** - * Construct an instance of this class. - * - * @param res Computation result. - * @param updates Map of updates in the form cache name -> (key -> new value). - */ - ResultAndUpdates(R res, Map<String, Map> updates) { - this.res = res; - this.updates = updates; - } - - /** - * Construct an empty result. - * - * @param <R> Result type. - * @return Empty result. - */ - public static <R> ResultAndUpdates<R> empty() { - return new ResultAndUpdates<>(null); - } - - /** - * Construct {@link ResultAndUpdates} object from given result. - * - * @param res Result of computation. - * @param <R> Type of result of computation. - * @return ResultAndUpdates object. - */ - public static <R> ResultAndUpdates<R> of(R res) { - return new ResultAndUpdates<>(res); - } - - /** - * Add a cache update to this object. - * - * @param cache Cache to be updated. - * @param key Key of cache to be updated. - * @param val New value. - * @param <K> Type of key of cache to be updated. - * @param <V> New value. - * @return This object. - */ - @SuppressWarnings("unchecked") - public <K, V> ResultAndUpdates<R> updateCache(IgniteCache<K, V> cache, K key, V val) { - String name = cache.getName(); - - updates.computeIfAbsent(name, s -> new ConcurrentHashMap()); - updates.get(name).put(key, val); - - return this; - } - - /** - * Get result of computation. - * - * @return Result of computation. - */ - public R result() { - return res; - } - - /** - * Sum collection of ResultAndUpdate into one: results are reduced by specified binary operator and updates are merged. - * - * @param reducer Reducer used to combine computation results. - * @param resultsAndUpdates ResultAndUpdates to be combined with. - * @param <R> Type of computation result. - * @return Sum of collection ResultAndUpdate objects. - */ - @SuppressWarnings("unchecked") - static <R> ResultAndUpdates<R> sum(IgniteFunction<List<R>, R> reducer, - Collection<ResultAndUpdates<R>> resultsAndUpdates) { - Map<String, Map> allUpdates = new HashMap<>(); - - for (ResultAndUpdates<R> ru : resultsAndUpdates) { - for (String cacheName : ru.updates.keySet()) { - allUpdates.computeIfAbsent(cacheName, s -> new HashMap()); - - allUpdates.get(cacheName).putAll(ru.updates.get(cacheName)); - } - } - - List<R> results = resultsAndUpdates.stream().map(ResultAndUpdates::result).filter(Objects::nonNull).collect(Collectors.toList()); - - return new ResultAndUpdates<>(reducer.apply(results), allUpdates); - } - - /** - * Get updates map. - * - * @return Updates map. - */ - public Map<String, Map> updates() { - return updates; - } - - /** - * Set updates map. - * - * @param updates New updates map. - * @return This object. - */ - ResultAndUpdates<R> setUpdates(Map<String, Map> updates) { - this.updates = updates; - return this; - } - - /** - * Apply updates to caches. - * - * @param ignite Ignite instance. - */ - void applyUpdates(Ignite ignite) { - for (Map.Entry<String, Map> entry : updates.entrySet()) { - IgniteCache<Object, Object> cache = ignite.getOrCreateCache(entry.getKey()); - - cache.putAll(entry.getValue()); - } - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/UpdateStrategies.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/UpdateStrategies.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/UpdateStrategies.java deleted file mode 100644 index 33ec96a..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/UpdateStrategies.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group; - -import org.apache.ignite.ml.optimization.SmoothParametrized; -import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; -import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; - -/** - * Holder class for various update strategies. - */ -public class UpdateStrategies { - /** - * Simple GD update strategy. - * - * @return GD update strategy. - */ - public static UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> GD() { - return new UpdatesStrategy<>(new SimpleGDUpdateCalculator(), SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg); - } - - /** - * RProp update strategy. - * - * @return RProp update strategy. - */ - public static UpdatesStrategy<SmoothParametrized, RPropParameterUpdate> RProp() { - return new UpdatesStrategy<>(new RPropUpdateCalculator(), RPropParameterUpdate::sumLocal, RPropParameterUpdate::avg); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/UpdatesStrategy.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/UpdatesStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/UpdatesStrategy.java deleted file mode 100644 index 5288dbf..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/UpdatesStrategy.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group; - -import java.io.Serializable; -import java.util.List; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator; - -/** - * Class encapsulating update strategies for group trainers based on updates. - * - * @param <M> Type of model to be optimized. - * @param <U> Type of update. - */ -public class UpdatesStrategy<M, U extends Serializable> { - /** - * {@link ParameterUpdateCalculator}. - */ - private ParameterUpdateCalculator<M, U> updatesCalculator; - - /** - * Function used to reduce updates in one training (for example, sum all sequential gradient updates to get one - * gradient update). - */ - private IgniteFunction<List<U>, U> locStepUpdatesReducer; - - /** - * Function used to reduce updates from different trainings (for example, averaging of gradients of all parallel trainings). - */ - private IgniteFunction<List<U>, U> allUpdatesReducer; - - /** - * Construct instance of this class with given parameters. - * - * @param updatesCalculator Parameter update calculator. - * @param locStepUpdatesReducer Function used to reduce updates in one training - * (for example, sum all sequential gradient updates to get one gradient update). - * @param allUpdatesReducer Function used to reduce updates from different trainings - * (for example, averaging of gradients of all parallel trainings). - */ - public UpdatesStrategy( - ParameterUpdateCalculator<M, U> updatesCalculator, - IgniteFunction<List<U>, U> locStepUpdatesReducer, - IgniteFunction<List<U>, U> allUpdatesReducer) { - this.updatesCalculator = updatesCalculator; - this.locStepUpdatesReducer = locStepUpdatesReducer; - this.allUpdatesReducer = allUpdatesReducer; - } - - /** - * Get parameter update calculator (see {@link ParameterUpdateCalculator}). - * - * @return Parameter update calculator. - */ - public ParameterUpdateCalculator<M, U> getUpdatesCalculator() { - return updatesCalculator; - } - - /** - * Get function used to reduce updates in one training - * (for example, sum all sequential gradient updates to get one gradient update). - * - * @return Function used to reduce updates in one training - * (for example, sum all sequential gradient updates to get on gradient update). - */ - public IgniteFunction<List<U>, U> locStepUpdatesReducer() { - return locStepUpdatesReducer; - } - - /** - * Get function used to reduce updates from different trainings - * (for example, averaging of gradients of all parallel trainings). - * - * @return Function used to reduce updates from different trainings. - */ - public IgniteFunction<List<U>, U> allUpdatesReducer() { - return allUpdatesReducer; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/Chains.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/Chains.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/Chains.java deleted file mode 100644 index db4f13f..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/Chains.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group.chain; - -import java.io.Serializable; - -/** - * Class containing methods creating {@link ComputationsChain}. - */ -public class Chains { - /** - * Create computation chain consisting of one returning its input as output. - * - * @param <L> Type of local context of created chain. - * @param <K> Type of keys of cache used in computation chain. - * @param <V> Type of values of cache used in computation chain. - * @param <I> Type of input to computation chain. - * @return Computation chain consisting of one returning its input as output. - */ - public static <L extends HasTrainingUUID, K, V, I> ComputationsChain<L, K, V, I, I> create() { - return (input, context) -> input; - } - - /** - * Create {@link ComputationsChain} from {@link DistributedEntryProcessingStep}. - * - * @param step Distributed chain step. - * @param <L> Type of local context of created chain. - * @param <K> Type of keys of cache used in computation chain. - * @param <V> Type of values of cache used in computation chain. - * @param <C> Type of context used by worker in {@link DistributedEntryProcessingStep}. - * @param <I> Type of input to computation chain. - * @param <O> Type of output of computation chain. - * @return Computation created from {@link DistributedEntryProcessingStep}. - */ - public static <L extends HasTrainingUUID, K, V, C, I, O extends Serializable> ComputationsChain<L, K, V, I, O> create( - DistributedEntryProcessingStep<L, K, V, C, I, O> step) { - ComputationsChain<L, K, V, I, I> chain = create(); - return chain.thenDistributedForEntries(step); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/ComputationsChain.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/ComputationsChain.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/ComputationsChain.java deleted file mode 100644 index 3c3bdab..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/ComputationsChain.java +++ /dev/null @@ -1,246 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group.chain; - -import java.io.Serializable; -import java.util.List; -import java.util.UUID; -import java.util.stream.Stream; -import javax.cache.processor.EntryProcessor; -import org.apache.ignite.Ignite; -import org.apache.ignite.cluster.ClusterGroup; -import org.apache.ignite.lang.IgniteBiPredicate; -import org.apache.ignite.ml.math.functions.Functions; -import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.math.functions.IgniteSupplier; -import org.apache.ignite.ml.trainers.group.GroupTrainerCacheKey; -import org.apache.ignite.ml.trainers.group.GroupTrainerEntriesProcessorTask; -import org.apache.ignite.ml.trainers.group.GroupTrainerKeysProcessorTask; -import org.apache.ignite.ml.trainers.group.GroupTrainingContext; -import org.apache.ignite.ml.trainers.group.ResultAndUpdates; - -/** - * This class encapsulates convenient way for creating computations chain for distributed model training. - * Chain is meant in the sense that output of each non-final computation is fed as input to next computation. - * Chain is basically a bi-function from context and input to output, context is separated from input - * because input is specific to each individual step and context is something which is convenient to have access to in each of steps. - * Context is separated into two parts: local context and remote context. - * There are two kinds of computations: local and distributed. - * Local steps are just functions from two arguments: input and local context. - * Distributed steps are more sophisticated, but basically can be thought as functions of form - * localContext -> (function of remote context -> output), locally we fix local context and get function - * (function of remote context -> output) which is executed distributed. - * Chains are composable through 'then' method. - * - * @param <L> Type of local context. - * @param <K> Type of cache keys. - * @param <V> Type of cache values. - * @param <I> Type of input of this chain. - * @param <O> Type of output of this chain. - * // TODO: IGNITE-7405 check if it is possible to integrate with {@link EntryProcessor}. - */ -@FunctionalInterface -public interface ComputationsChain<L extends HasTrainingUUID, K, V, I, O> { - /** - * Process given input and {@link GroupTrainingContext}. - * - * @param input Computation chain input. - * @param ctx {@link GroupTrainingContext}. - * @return Result of processing input and context. - */ - O process(I input, GroupTrainingContext<K, V, L> ctx); - - /** - * Add a local step to this chain. - * - * @param locStep Local step. - * @param <O1> Output of local step. - * @return Composition of this chain and local step. - */ - default <O1> ComputationsChain<L, K, V, I, O1> thenLocally(IgniteBiFunction<O, L, O1> locStep) { - ComputationsChain<L, K, V, O, O1> nextStep = (input, context) -> locStep.apply(input, context.localContext()); - return then(nextStep); - } - - /** - * Add a distributed step which works in the following way: - * 1. apply local context and input to local context extractor and keys supplier to get corresponding suppliers; - * 2. on each node_n - * 2.1. get context object. - * 2.2. for each entry_i e located on node_n with key_i from keys stream compute worker((context, entry_i)) and get - * (cachesUpdates_i, result_i). - * 2.3. for all i on node_n merge cacheUpdates_i and apply them. - * 2.4. for all i on node_n, reduce result_i into result_n. - * 3. get all result_n, reduce them into result and return result. - * - * @param <O1> Type of worker output. - * @param <G> Type of context used by worker. - * @param workerCtxExtractor Extractor of context for worker. - * @param worker Function computed on each entry of cache used for training. Second argument is context: - * common part of data which is independent from key. - * @param ks Function from chain input and local context to supplier of keys for worker. - * @param reducer Function used for reducing results of worker. - * @return Combination of this chain and distributed step specified by given parameters. - */ - default <O1 extends Serializable, G> ComputationsChain<L, K, V, I, O1> thenDistributedForEntries( - IgniteBiFunction<O, L, IgniteSupplier<G>> workerCtxExtractor, - IgniteFunction<EntryAndContext<K, V, G>, ResultAndUpdates<O1>> worker, - IgniteBiFunction<O, L, IgniteSupplier<Stream<GroupTrainerCacheKey<K>>>> ks, - IgniteFunction<List<O1>, O1> reducer) { - ComputationsChain<L, K, V, O, O1> nextStep = (input, context) -> { - L locCtx = context.localContext(); - IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> keysSupplier = ks.apply(input, locCtx); - - Ignite ignite = context.ignite(); - UUID trainingUUID = context.localContext().trainingUUID(); - String cacheName = context.cache().getName(); - ClusterGroup grp = ignite.cluster().forDataNodes(cacheName); - - // Apply first two arguments locally because it is common for all nodes. - IgniteSupplier<G> extractor = Functions.curry(workerCtxExtractor).apply(input).apply(locCtx); - - return ignite.compute(grp).execute(new GroupTrainerEntriesProcessorTask<>(trainingUUID, extractor, worker, keysSupplier, reducer, cacheName, ignite), null); - }; - return then(nextStep); - } - - /** - * Add a distributed step which works in the following way: - * 1. apply local context and input to local context extractor and keys supplier to get corresponding suppliers; - * 2. on each node_n - * 2.1. get context object. - * 2.2. for each key_i from keys stream such that key_i located on node_n compute worker((context, entry_i)) and get - * (cachesUpdates_i, result_i). - * 2.3. for all i on node_n merge cacheUpdates_i and apply them. - * 2.4. for all i on node_n, reduce result_i into result_n. - * 3. get all result_n, reduce them into result and return result. - * - * @param <O1> Type of worker output. - * @param <G> Type of context used by worker. - * @param workerCtxExtractor Extractor of context for worker. - * @param worker Function computed on each entry of cache used for training. Second argument is context: - * common part of data which is independent from key. - * @param keysSupplier Function from chain input and local context to supplier of keys for worker. - * @param reducer Function used for reducing results of worker. - * @return Combination of this chain and distributed step specified by given parameters. - */ - default <O1 extends Serializable, G> ComputationsChain<L, K, V, I, O1> thenDistributedForKeys( - IgniteBiFunction<O, L, IgniteSupplier<G>> workerCtxExtractor, - IgniteFunction<KeyAndContext<K, G>, ResultAndUpdates<O1>> worker, - IgniteBiFunction<O, L, IgniteSupplier<Stream<GroupTrainerCacheKey<K>>>> keysSupplier, - IgniteFunction<List<O1>, O1> reducer) { - ComputationsChain<L, K, V, O, O1> nextStep = (input, context) -> { - L locCtx = context.localContext(); - IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> ks = keysSupplier.apply(input, locCtx); - - Ignite ignite = context.ignite(); - UUID trainingUUID = context.localContext().trainingUUID(); - String cacheName = context.cache().getName(); - ClusterGroup grp = ignite.cluster().forDataNodes(cacheName); - - // Apply first argument locally because it is common for all nodes. - IgniteSupplier<G> extractor = Functions.curry(workerCtxExtractor).apply(input).apply(locCtx); - - return ignite.compute(grp).execute(new GroupTrainerKeysProcessorTask<>(trainingUUID, extractor, worker, ks, reducer, cacheName, ignite), null); - }; - return then(nextStep); - } - - /** - * Add a distributed step specified by {@link DistributedEntryProcessingStep}. - * - * @param step Distributed step. - * @param <O1> Type of output of distributed step. - * @param <G> Type of context of distributed step. - * @return Combination of this chain and distributed step specified by input. - */ - default <O1 extends Serializable, G> ComputationsChain<L, K, V, I, O1> thenDistributedForEntries( - DistributedEntryProcessingStep<L, K, V, G, O, O1> step) { - return thenDistributedForEntries(step::remoteContextSupplier, step.worker(), step::keys, step.reducer()); - } - - /** - * Add a distributed step specified by {@link DistributedKeyProcessingStep}. - * - * @param step Distributed step. - * @param <O1> Type of output of distributed step. - * @param <G> Type of context of distributed step. - * @return Combination of this chain and distributed step specified by input. - */ - default <O1 extends Serializable, G> ComputationsChain<L, K, V, I, O1> thenDistributedForKeys( - DistributedKeyProcessingStep<L, K, G, O, O1> step) { - return thenDistributedForKeys(step::remoteContextSupplier, step.worker(), step::keys, step.reducer()); - } - - /** - * Version of 'thenDistributedForKeys' where worker does not depend on context. - * - * @param worker Worker. - * @param kf Function providing supplier - * @param reducer Function from chain input and local context to supplier of keys for worker. - * @param <O1> Type of worker output. - * @return Combination of this chain and distributed step specified by given parameters. - */ - default <O1 extends Serializable> ComputationsChain<L, K, V, I, O1> thenDistributedForKeys( - IgniteFunction<GroupTrainerCacheKey<K>, ResultAndUpdates<O1>> worker, - IgniteBiFunction<O, L, IgniteSupplier<Stream<GroupTrainerCacheKey<K>>>> kf, - IgniteFunction<List<O1>, O1> reducer) { - - return thenDistributedForKeys((o, lc) -> () -> o, (context) -> worker.apply(context.key()), kf, reducer); - } - - /** - * Combine this computation chain with other computation chain in the following way: - * 1. perform this calculations chain and get result r. - * 2. while 'cond(r)' is true, r = otherChain(r, context) - * 3. return r. - * - * @param cond Condition checking if 'while' loop should continue. - * @param otherChain Chain to be combined with this chain. - * @return Combination of this chain and otherChain. - */ - default ComputationsChain<L, K, V, I, O> thenWhile(IgniteBiPredicate<O, L> cond, - ComputationsChain<L, K, V, O, O> otherChain) { - ComputationsChain<L, K, V, I, O> me = this; - return (input, context) -> { - O res = me.process(input, context); - - while (cond.apply(res, context.localContext())) - res = otherChain.process(res, context); - - return res; - }; - } - - /** - * Combine this chain with other: feed this chain as input to other, pass same context as second argument to both chains - * process method. - * - * @param next Next chain. - * @param <O1> Type of next chain output. - * @return Combined chain. - */ - default <O1> ComputationsChain<L, K, V, I, O1> then(ComputationsChain<L, K, V, O, O1> next) { - ComputationsChain<L, K, V, I, O> me = this; - return (input, context) -> { - O myRes = me.process(input, context); - return next.process(myRes, context); - }; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedEntryProcessingStep.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedEntryProcessingStep.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedEntryProcessingStep.java deleted file mode 100644 index 8fd1264..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedEntryProcessingStep.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group.chain; - -import java.io.Serializable; - -/** - * {@link DistributedStep} specialized to {@link EntryAndContext}. - * - * @param <L> Local context. - * @param <K> Type of keys of cache used for group training. - * @param <V> Type of values of cache used for group training. - * @param <C> Context used by worker. - * @param <I> Type of input to this step. - * @param <O> Type of output of this step. - */ -public interface DistributedEntryProcessingStep<L, K, V, C, I, O extends Serializable> extends - DistributedStep<EntryAndContext<K, V, C>, L, K, C, I, O> { -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedKeyProcessingStep.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedKeyProcessingStep.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedKeyProcessingStep.java deleted file mode 100644 index fb8d867..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedKeyProcessingStep.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group.chain; - -import java.io.Serializable; - -/** - * {@link DistributedStep} specialized to {@link KeyAndContext}. - * - * @param <L> Local context. - * @param <K> Type of keys of cache used for group training. - * @param <C> Context used by worker. - * @param <I> Type of input to this step. - * @param <O> Type of output of this step. - */ -public interface DistributedKeyProcessingStep<L, K, C, I, O extends Serializable> extends - DistributedStep<KeyAndContext<K, C>, L, K, C, I, O> { -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedStep.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedStep.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedStep.java deleted file mode 100644 index 804a886..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedStep.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group.chain; - -import java.io.Serializable; -import java.util.List; -import java.util.stream.Stream; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.math.functions.IgniteSupplier; -import org.apache.ignite.ml.trainers.group.GroupTrainerCacheKey; -import org.apache.ignite.ml.trainers.group.ResultAndUpdates; - -/** - * Class encapsulating logic of distributed step in {@link ComputationsChain}. - * - * @param <T> Type of elements to be processed by worker. - * @param <L> Local context. - * @param <K> Type of keys of cache used for group training. - * @param <C> Context used by worker. - * @param <I> Type of input to this step. - * @param <O> Type of output of this step. - */ -public interface DistributedStep<T, L, K, C, I, O extends Serializable> { - /** - * Create supplier of context used by worker. - * - * @param input Input. - * @param locCtx Local context. - * @return Context used by worker. - */ - IgniteSupplier<C> remoteContextSupplier(I input, L locCtx); - - /** - * Get function applied to each cache element specified by keys. - * - * @return Function applied to each cache entry specified by keys.. - */ - IgniteFunction<T, ResultAndUpdates<O>> worker(); - - /** - * Get supplier of keys for worker. - * - * @param input Input to this step. - * @param locCtx Local context. - * @return Keys for worker. - */ - IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> keys(I input, L locCtx); - - /** - * Get function used to reduce results returned by worker. - * - * @return Function used to reduce results returned by worker.. - */ - IgniteFunction<List<O>, O> reducer(); -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/EntryAndContext.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/EntryAndContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/EntryAndContext.java deleted file mode 100644 index 59c3b34..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/EntryAndContext.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group.chain; - -import java.util.Map; -import org.apache.ignite.ml.trainers.group.GroupTrainerCacheKey; - -/** - * Entry of cache used for group training and context. - * This class is used as input for workers of distributed steps of {@link ComputationsChain}. - * - * @param <K> Type of cache keys used for training. - * @param <V> Type of cache values used for training. - * @param <C> Type of context. - */ -public class EntryAndContext<K, V, C> { - /** - * Entry of cache used for training. - */ - private Map.Entry<GroupTrainerCacheKey<K>, V> entry; - - /** - * Context. - */ - private C ctx; - - /** - * Construct instance of this class. - * - * @param entry Entry of cache used for training. - * @param ctx Context. - */ - public EntryAndContext(Map.Entry<GroupTrainerCacheKey<K>, V> entry, C ctx) { - this.entry = entry; - this.ctx = ctx; - } - - /** - * Get entry of cache used for training. - * - * @return Entry of cache used for training. - */ - public Map.Entry<GroupTrainerCacheKey<K>, V> entry() { - return entry; - } - - /** - * Get context. - * - * @return Context. - */ - public C context() { - return ctx; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/HasTrainingUUID.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/HasTrainingUUID.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/HasTrainingUUID.java deleted file mode 100644 index d855adf..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/HasTrainingUUID.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group.chain; - -import java.util.UUID; - -/** - * Interface for classes which contain UUID of training. - */ -public interface HasTrainingUUID { - /** - * Get training UUID. - * - * @return Training UUID. - */ - UUID trainingUUID(); -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/KeyAndContext.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/KeyAndContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/KeyAndContext.java deleted file mode 100644 index ba36e0e..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/KeyAndContext.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group.chain; - -import org.apache.ignite.ml.trainers.group.GroupTrainerCacheKey; - -/** - * Class containing key and remote context (see explanation of remote context in {@link ComputationsChain}). - * - * @param <K> Cache key type. - * @param <C> Remote context. - */ -public class KeyAndContext<K, C> { - /** - * Key of group trainer. - */ - private GroupTrainerCacheKey<K> key; - - /** - * Remote context. - */ - private C ctx; - - /** - * Construct instance of this class. - * - * @param key Cache key. - * @param ctx Remote context. - */ - public KeyAndContext(GroupTrainerCacheKey<K> key, C ctx) { - this.key = key; - this.ctx = ctx; - } - - /** - * Get group trainer cache key. - * - * @return Group trainer cache key. - */ - public GroupTrainerCacheKey<K> key() { - return key; - } - - /** - * Get remote context. - * - * @return Remote context. - */ - public C context() { - return ctx; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/package-info.java deleted file mode 100644 index 46dcc6b..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * <!-- Package description. --> - * Contains classes related to computations chain. - */ -package org.apache.ignite.ml.trainers.group.chain; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/package-info.java deleted file mode 100644 index 9b7f7cd..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * <!-- Package description. --> - * Contains group trainers. - */ -package org.apache.ignite.ml.trainers.group; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java index 9900f85..0c3408e 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java @@ -23,11 +23,9 @@ import org.apache.ignite.ml.genetic.GAGridTestSuite; import org.apache.ignite.ml.knn.KNNTestSuite; import org.apache.ignite.ml.math.MathImplMainTestSuite; import org.apache.ignite.ml.nn.MLPTestSuite; -import org.apache.ignite.ml.optimization.OptimizationTestSuite; import org.apache.ignite.ml.preprocessing.PreprocessingTestSuite; import org.apache.ignite.ml.regressions.RegressionsTestSuite; import org.apache.ignite.ml.svm.SVMTestSuite; -import org.apache.ignite.ml.trainers.group.TrainersGroupTestSuite; import org.apache.ignite.ml.tree.DecisionTreeTestSuite; import org.junit.runner.RunWith; import org.junit.runners.Suite; @@ -45,8 +43,6 @@ import org.junit.runners.Suite; KNNTestSuite.class, LocalModelsTest.class, MLPTestSuite.class, - TrainersGroupTestSuite.class, - OptimizationTestSuite.class, DatasetTestSuite.class, PreprocessingTestSuite.class, GAGridTestSuite.class http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java index ec9fdaa..bdd1eea 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java @@ -22,7 +22,7 @@ import java.util.HashMap; import java.util.Map; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; -import org.apache.ignite.ml.math.isolve.LinSysPartitionDataBuilderOnHeap; +import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetDataBuilder; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -64,9 +64,9 @@ public class LSQROnHeapTest { LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>( datasetBuilder, - new LinSysPartitionDataBuilderOnHeap<>( + new SimpleLabeledDatasetDataBuilder<>( (k, v) -> Arrays.copyOf(v, v.length - 1), - (k, v) -> v[3] + (k, v) -> new double[]{v[3]} ) ); @@ -87,9 +87,9 @@ public class LSQROnHeapTest { LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>( datasetBuilder, - new LinSysPartitionDataBuilderOnHeap<>( + new SimpleLabeledDatasetDataBuilder<>( (k, v) -> Arrays.copyOf(v, v.length - 1), - (k, v) -> v[3] + (k, v) -> new double[]{v[3]} ) ); @@ -118,9 +118,9 @@ public class LSQROnHeapTest { try (LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>( datasetBuilder, - new LinSysPartitionDataBuilderOnHeap<>( + new SimpleLabeledDatasetDataBuilder<>( (k, v) -> Arrays.copyOf(v, v.length - 1), - (k, v) -> v[4] + (k, v) -> new double[]{v[4]} ) )) { LSQRResult res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, null); http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java index 038b880..654ebe0 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java @@ -31,7 +31,6 @@ import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.optimization.LossFunctions; import org.apache.ignite.ml.optimization.updatecalculators.*; -import org.apache.ignite.ml.trainers.group.UpdatesStrategy; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; import java.io.Serializable; http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java index c53f6f1..db14881 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java @@ -24,7 +24,6 @@ import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.optimization.LossFunctions; import org.apache.ignite.ml.optimization.updatecalculators.*; -import org.apache.ignite.ml.trainers.group.UpdatesStrategy; import org.junit.Before; import org.junit.Test; import org.junit.experimental.runners.Enclosed; http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java index a64af9b..3b65a28 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java @@ -32,7 +32,7 @@ import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.optimization.LossFunctions; import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; -import org.apache.ignite.ml.trainers.group.UpdatesStrategy; +import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.util.MnistUtils; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java index d966484..4063312 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java @@ -27,7 +27,7 @@ import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.optimization.LossFunctions; import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; -import org.apache.ignite.ml.trainers.group.UpdatesStrategy; +import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.util.MnistUtils; import org.junit.Test; http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/optimization/GradientDescentTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/optimization/GradientDescentTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/optimization/GradientDescentTest.java deleted file mode 100644 index f6f4775..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/optimization/GradientDescentTest.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.optimization; - -import org.apache.ignite.ml.TestUtils; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.junit.Test; - -/** - * Tests for {@link GradientDescent}. - */ -public class GradientDescentTest { - /** */ - private static final double PRECISION = 1e-6; - - /** - * Test gradient descent optimization on function y = x^2 with gradient function 2 * x. - */ - @Test - public void testOptimize() { - GradientDescent gradientDescent = new GradientDescent( - (inputs, groundTruth, point) -> point.times(2), - new SimpleUpdater(0.01) - ); - - Vector res = gradientDescent.optimize(new DenseLocalOnHeapMatrix(new double[1][1]), - new DenseLocalOnHeapVector(new double[]{ 2.0 })); - - TestUtils.assertEquals(0, res.get(0), PRECISION); - } - - /** - * Test gradient descent optimization on function y = (x - 2)^2 with gradient function 2 * (x - 2). - */ - @Test - public void testOptimizeWithOffset() { - GradientDescent gradientDescent = new GradientDescent( - (inputs, groundTruth, point) -> point.minus(new DenseLocalOnHeapVector(new double[]{ 2.0 })).times(2.0), - new SimpleUpdater(0.01) - ); - - Vector res = gradientDescent.optimize(new DenseLocalOnHeapMatrix(new double[1][1]), - new DenseLocalOnHeapVector(new double[]{ 2.0 })); - - TestUtils.assertEquals(2, res.get(0), PRECISION); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/optimization/OptimizationTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/optimization/OptimizationTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/optimization/OptimizationTestSuite.java deleted file mode 100644 index 0ae6e4c..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/optimization/OptimizationTestSuite.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.optimization; - -import org.apache.ignite.ml.optimization.util.SparseDistributedMatrixMapReducerTest; -import org.junit.runner.RunWith; -import org.junit.runners.Suite; - -/** - * Test suite for group trainer tests. - */ -@RunWith(Suite.class) -@Suite.SuiteClasses({ - GradientDescentTest.class, - SparseDistributedMatrixMapReducerTest.class -}) -public class OptimizationTestSuite { -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducerTest.java deleted file mode 100644 index 9017c43..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducerTest.java +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.optimization.util; - -import org.apache.ignite.Ignite; -import org.apache.ignite.internal.util.IgniteUtils; -import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; -import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; - -/** - * Tests for {@link SparseDistributedMatrixMapReducer}. - */ -public class SparseDistributedMatrixMapReducerTest extends GridCommonAbstractTest { - /** Number of nodes in grid */ - private static final int NODE_COUNT = 2; - - /** */ - private Ignite ignite; - - /** {@inheritDoc} */ - @Override protected void beforeTestsStarted() throws Exception { - for (int i = 1; i <= NODE_COUNT; i++) - startGrid(i); - } - - /** {@inheritDoc} */ - @Override protected void afterTestsStopped() { - stopAllGrids(); - } - - /** - * {@inheritDoc} - */ - @Override protected void beforeTest() throws Exception { - /* Grid instance. */ - ignite = grid(NODE_COUNT); - ignite.configuration().setPeerClassLoadingEnabled(true); - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - } - - /** - * Tests that matrix 100x100 filled by "1.0" and distributed across nodes successfully processed (calculate sum of - * all elements) via {@link SparseDistributedMatrixMapReducer}. - */ - public void testMapReduce() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(100, 100); - for (int i = 0; i < 100; i++) - for (int j = 0; j < 100; j++) - distributedMatrix.set(i, j, 1); - SparseDistributedMatrixMapReducer mapReducer = new SparseDistributedMatrixMapReducer(distributedMatrix); - double total = mapReducer.mapReduce( - (matrix, args) -> { - double partialSum = 0.0; - for (int i = 0; i < matrix.rowSize(); i++) - for (int j = 0; j < matrix.columnSize(); j++) - partialSum += matrix.get(i, j); - return partialSum; - }, - sums -> { - double totalSum = 0; - for (Double partialSum : sums) - if (partialSum != null) - totalSum += partialSum; - return totalSum; - }, 0.0); - assertEquals(100.0 * 100.0, total, 1e-18); - } - - /** - * Tests that matrix 100x100 filled by "1.0" and distributed across nodes successfully processed via - * {@link SparseDistributedMatrixMapReducer} even when mapping function returns {@code null}. - */ - public void testMapReduceWithNullValues() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(100, 100); - for (int i = 0; i < 100; i++) - for (int j = 0; j < 100; j++) - distributedMatrix.set(i, j, 1); - SparseDistributedMatrixMapReducer mapReducer = new SparseDistributedMatrixMapReducer(distributedMatrix); - double total = mapReducer.mapReduce( - (matrix, args) -> null, - sums -> { - double totalSum = 0; - for (Double partialSum : sums) - if (partialSum != null) - totalSum += partialSum; - return totalSum; - }, 0.0); - assertEquals(0, total, 1e-18); - } - - /** - * Tests that matrix 1x100 filled by "1.0" and distributed across nodes successfully processed (calculate sum of - * all elements) via {@link SparseDistributedMatrixMapReducer} even when not all nodes contains data. - */ - public void testMapReduceWithOneEmptyNode() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(1, 100); - for (int j = 0; j < 100; j++) - distributedMatrix.set(0, j, 1); - SparseDistributedMatrixMapReducer mapReducer = new SparseDistributedMatrixMapReducer(distributedMatrix); - double total = mapReducer.mapReduce( - (matrix, args) -> { - double partialSum = 0.0; - for (int i = 0; i < matrix.rowSize(); i++) - for (int j = 0; j < matrix.columnSize(); j++) - partialSum += matrix.get(i, j); - return partialSum; - }, - sums -> { - double totalSum = 0; - for (Double partialSum : sums) - if (partialSum != null) - totalSum += partialSum; - return totalSum; - }, 0.0); - assertEquals(100.0, total, 1e-18); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java index b3c9368..5005ef2 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java @@ -27,9 +27,6 @@ import org.junit.runners.Suite; @RunWith(Suite.class) @Suite.SuiteClasses({ LinearRegressionModelTest.class, - LocalLinearRegressionQRTrainerTest.class, - DistributedLinearRegressionQRTrainerTest.class, - BlockDistributedLinearRegressionQRTrainerTest.class, LinearRegressionLSQRTrainerTest.class, LinearRegressionSGDTrainerTest.class })