http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java index 0736906..bde4bb6 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java @@ -28,8 +28,10 @@ import org.apache.ignite.lang.IgniteBiPredicate; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.PartitionDataBuilder; -import org.apache.ignite.ml.dataset.UpstreamTransformerChain; +import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder; import org.apache.ignite.ml.dataset.impl.cache.util.ComputeUtils; +import org.apache.ignite.ml.environment.LearningEnvironment; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.functions.IgniteBinaryOperator; import org.apache.ignite.ml.math.functions.IgniteFunction; @@ -61,8 +63,8 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose /** Filter for {@code upstream} data. */ private final IgniteBiPredicate<K, V> filter; - /** Chain of transformers applied to upstream. */ - private final UpstreamTransformerChain<K, V> upstreamTransformers; + /** Builder of transformation applied to upstream. */ + private final UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder; /** Ignite Cache with partition {@code context}. 
*/ private final IgniteCache<Integer, C> datasetCache; @@ -73,6 +75,9 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose /** Dataset ID that is used to identify dataset in local storage on the node where computation is performed. */ private final UUID datasetId; + /** Learning environment builder. */ + private final LearningEnvironmentBuilder envBuilder; + /** * Constructs a new instance of dataset based on Ignite Cache, which is used as {@code upstream} and as reliable storage for * partition {@code context} as well. @@ -80,7 +85,7 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose * @param ignite Ignite instance. * @param upstreamCache Ignite Cache with {@code upstream} data. * @param filter Filter for {@code upstream} data. - * @param upstreamTransformers Transformers of upstream data (see description in {@link DatasetBuilder}). + * @param upstreamTransformerBuilder Transformer of upstream data (see description in {@link DatasetBuilder}). * @param datasetCache Ignite Cache with partition {@code context}. * @param partDataBuilder Partition {@code data} builder. * @param datasetId Dataset ID. 
@@ -89,39 +94,45 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose Ignite ignite, IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter, - UpstreamTransformerChain<K, V> upstreamTransformers, - IgniteCache<Integer, C> datasetCache, PartitionDataBuilder<K, V, C, D> partDataBuilder, + UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder, + IgniteCache<Integer, C> datasetCache, + LearningEnvironmentBuilder envBuilder, + PartitionDataBuilder<K, V, C, D> partDataBuilder, UUID datasetId) { this.ignite = ignite; this.upstreamCache = upstreamCache; this.filter = filter; - this.upstreamTransformers = upstreamTransformers; + this.upstreamTransformerBuilder = upstreamTransformerBuilder; this.datasetCache = datasetCache; this.partDataBuilder = partDataBuilder; + this.envBuilder = envBuilder; this.datasetId = datasetId; } /** {@inheritDoc} */ - @Override public <R> R computeWithCtx(IgniteTriFunction<C, D, Integer, R> map, IgniteBinaryOperator<R> reduce, R identity) { + @Override public <R> R computeWithCtx(IgniteTriFunction<C, D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce, R identity) { String upstreamCacheName = upstreamCache.getName(); String datasetCacheName = datasetCache.getName(); return computeForAllPartitions(part -> { + LearningEnvironment env = ComputeUtils.getLearningEnvironment(ignite, datasetId, part, envBuilder); + C ctx = ComputeUtils.getContext(Ignition.localIgnite(), datasetCacheName, part); D data = ComputeUtils.getData( Ignition.localIgnite(), upstreamCacheName, filter, - upstreamTransformers, + upstreamTransformerBuilder, datasetCacheName, datasetId, - part, - partDataBuilder + partDataBuilder, + env ); + if (data != null) { - R res = map.apply(ctx, data, part); + R res = map.apply(ctx, data, env); // Saves partition context after update. 
ComputeUtils.saveContext(Ignition.localIgnite(), datasetCacheName, part, ctx); @@ -134,23 +145,24 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose } /** {@inheritDoc} */ - @Override public <R> R compute(IgniteBiFunction<D, Integer, R> map, IgniteBinaryOperator<R> reduce, R identity) { + @Override public <R> R compute(IgniteBiFunction<D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce, R identity) { String upstreamCacheName = upstreamCache.getName(); String datasetCacheName = datasetCache.getName(); return computeForAllPartitions(part -> { + LearningEnvironment env = ComputeUtils.getLearningEnvironment(Ignition.localIgnite(), datasetId, part, envBuilder); + D data = ComputeUtils.getData( Ignition.localIgnite(), upstreamCacheName, filter, - upstreamTransformers, + upstreamTransformerBuilder, datasetCacheName, datasetId, - part, - partDataBuilder + partDataBuilder, + env ); - - return data != null ? map.apply(data, part) : null; + return data != null ? map.apply(data, env) : null; }, reduce, identity); } @@ -158,6 +170,7 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose @Override public void close() { datasetCache.destroy(); ComputeUtils.removeData(ignite, datasetId); + ComputeUtils.removeLearningEnv(ignite, datasetId); } /**
http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java index 1d00875..be40158 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java @@ -27,9 +27,10 @@ import org.apache.ignite.lang.IgniteBiPredicate; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.PartitionContextBuilder; import org.apache.ignite.ml.dataset.PartitionDataBuilder; -import org.apache.ignite.ml.dataset.UpstreamTransformerChain; +import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder; import org.apache.ignite.ml.dataset.impl.cache.util.ComputeUtils; import org.apache.ignite.ml.dataset.impl.cache.util.DatasetAffinityFunctionWrapper; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; /** * A dataset builder that makes {@link CacheBasedDataset}. Encapsulate logic of building cache based dataset such as @@ -57,8 +58,8 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> { /** Filter for {@code upstream} data. */ private final IgniteBiPredicate<K, V> filter; - /** Chain of upstream transformers. */ - private final UpstreamTransformerChain<K, V> transformersChain; + /** Upstream transformer builder. 
*/ + private final UpstreamTransformerBuilder<K, V> transformerBuilder; /** * Constructs a new instance of cache based dataset builder that makes {@link CacheBasedDataset} with default @@ -79,16 +80,32 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> { * @param filter Filter for {@code upstream} data. */ public CacheBasedDatasetBuilder(Ignite ignite, IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter) { + this(ignite, upstreamCache, filter, UpstreamTransformerBuilder.identity()); + } + + /** + * Constructs a new instance of cache based dataset builder that makes {@link CacheBasedDataset}. + * + * @param ignite Ignite instance. + * @param upstreamCache Ignite Cache with {@code upstream} data. + * @param filter Filter for {@code upstream} data. + */ + public CacheBasedDatasetBuilder(Ignite ignite, + IgniteCache<K, V> upstreamCache, + IgniteBiPredicate<K, V> filter, + UpstreamTransformerBuilder<K, V> transformerBuilder) { this.ignite = ignite; this.upstreamCache = upstreamCache; this.filter = filter; - transformersChain = UpstreamTransformerChain.empty(); + this.transformerBuilder = transformerBuilder; } /** {@inheritDoc} */ @SuppressWarnings("unchecked") @Override public <C extends Serializable, D extends AutoCloseable> CacheBasedDataset<K, V, C, D> build( - PartitionContextBuilder<K, V, C> partCtxBuilder, PartitionDataBuilder<K, V, C, D> partDataBuilder) { + LearningEnvironmentBuilder envBuilder, + PartitionContextBuilder<K, V, C> partCtxBuilder, + PartitionDataBuilder<K, V, C, D> partDataBuilder) { UUID datasetId = UUID.randomUUID(); // Retrieves affinity function of the upstream Ignite Cache. 
@@ -106,25 +123,24 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> { ComputeUtils.initContext( ignite, upstreamCache.getName(), + transformerBuilder, filter, - transformersChain, datasetCache.getName(), partCtxBuilder, + envBuilder, RETRIES, RETRY_INTERVAL ); - return new CacheBasedDataset<>(ignite, upstreamCache, filter, transformersChain, datasetCache, partDataBuilder, datasetId); + return new CacheBasedDataset<>(ignite, upstreamCache, filter, transformerBuilder, datasetCache, envBuilder, partDataBuilder, datasetId); } /** {@inheritDoc} */ - @Override public UpstreamTransformerChain<K, V> upstreamTransformersChain() { - return transformersChain; + @Override public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder<K, V> builder) { + return new CacheBasedDatasetBuilder<>(ignite, upstreamCache, filter, transformerBuilder.andThen(builder)); } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public DatasetBuilder<K, V> withFilter(IgniteBiPredicate<K, V> filterToAdd) { return new CacheBasedDatasetBuilder<>(ignite, upstreamCache, (e1, e2) -> filter.apply(e1, e2) && filterToAdd.apply(e1, e2)); http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java index 4f18a18..1dc5591 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java @@ -26,6 +26,8 @@ import java.util.HashMap; import java.util.Iterator; import java.util.Map; import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import 
java.util.concurrent.ConcurrentMap; import java.util.concurrent.locks.LockSupport; import java.util.stream.Stream; import org.apache.ignite.Ignite; @@ -41,7 +43,10 @@ import org.apache.ignite.lang.IgniteFuture; import org.apache.ignite.ml.dataset.PartitionContextBuilder; import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.dataset.UpstreamEntry; -import org.apache.ignite.ml.dataset.UpstreamTransformerChain; +import org.apache.ignite.ml.dataset.UpstreamTransformer; +import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder; +import org.apache.ignite.ml.environment.LearningEnvironment; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.util.Utils; @@ -49,11 +54,12 @@ import org.apache.ignite.ml.util.Utils; * Util class that provides common methods to perform computations on top of the Ignite Compute Grid. */ public class ComputeUtils { - /** - * Template of the key used to store partition {@code data} in local storage. - */ + /** Template of the key used to store partition {@code data} in local storage. */ private static final String DATA_STORAGE_KEY_TEMPLATE = "part_data_storage_%s"; + /** Template of the key used to store partition {@link LearningEnvironment} in local storage. */ + private static final String ENVIRONMENT_STORAGE_KEY_TEMPLATE = "part_environment_storage_%s"; + /** * Calls the specified {@code fun} function on all partitions so that is't guaranteed that partitions with the same * index of all specified caches will be placed on the same node and will not be moved before computation is @@ -134,6 +140,30 @@ public class ComputeUtils { } /** + * Gets learning environment for given partition. If learning environment is not found in local node map, + * it will be created with specified {@link LearningEnvironmentBuilder}. + * + * @param ignite Ignite instance. + * @param datasetId Dataset id. 
+ * @param part Partition index. + * @param envBuilder {@link LearningEnvironmentBuilder}. + * @return Learning environment for given partition. + */ + public static LearningEnvironment getLearningEnvironment(Ignite ignite, + UUID datasetId, + int part, + LearningEnvironmentBuilder envBuilder) { + + @SuppressWarnings("unchecked") + ConcurrentMap<Integer, LearningEnvironment> envStorage = (ConcurrentMap<Integer, LearningEnvironment>)ignite + .cluster() + .nodeLocalMap() + .computeIfAbsent(String.format(ENVIRONMENT_STORAGE_KEY_TEMPLATE, datasetId), key -> new ConcurrentHashMap<>()); + + return envStorage.computeIfAbsent(part, envBuilder::buildForWorker); + } + + /** * Extracts partition {@code data} from the local storage, if it's not found in local storage recovers this {@code * data} from a partition {@code upstream} and {@code context}. Be aware that this method should be called from * the node where partition is placed. @@ -141,11 +171,11 @@ public class ComputeUtils { * @param ignite Ignite instance. * @param upstreamCacheName Name of an {@code upstream} cache. * @param filter Filter for {@code upstream} data. - * @param transformersChain Upstream transformers. + * @param transformerBuilder Builder of upstream transformers. * @param datasetCacheName Name of a partition {@code context} cache. * @param datasetId Dataset ID. - * @param part Partition index. * @param partDataBuilder Partition data builder. + * @param env Learning environment. * @param <K> Type of a key in {@code upstream} data. * @param <V> Type of a value in {@code upstream} data. * @param <C> Type of a partition {@code context}. 
@@ -155,17 +185,18 @@ public class ComputeUtils { public static <K, V, C extends Serializable, D extends AutoCloseable> D getData( Ignite ignite, String upstreamCacheName, IgniteBiPredicate<K, V> filter, - UpstreamTransformerChain<K, V> transformersChain, - String datasetCacheName, - UUID datasetId, - int part, - PartitionDataBuilder<K, V, C, D> partDataBuilder) { + UpstreamTransformerBuilder<K, V> transformerBuilder, + String datasetCacheName, UUID datasetId, + PartitionDataBuilder<K, V, C, D> partDataBuilder, + LearningEnvironment env) { PartitionDataStorage dataStorage = (PartitionDataStorage)ignite .cluster() .nodeLocalMap() .computeIfAbsent(String.format(DATA_STORAGE_KEY_TEMPLATE, datasetId), key -> new PartitionDataStorage()); + final int part = env.partition(); + return dataStorage.computeDataIfAbsent(part, () -> { IgniteCache<Integer, C> learningCtxCache = ignite.cache(datasetCacheName); C ctx = learningCtxCache.get(part); @@ -177,25 +208,24 @@ public class ComputeUtils { qry.setPartition(part); qry.setFilter(filter); - UpstreamTransformerChain<K, V> chainCopy = Utils.copy(transformersChain); - chainCopy.modifySeed(s -> s + part); + UpstreamTransformer<K, V> transformer = transformerBuilder.build(env); + UpstreamTransformer<K, V> transformerCp = Utils.copy(transformer); - long cnt = computeCount(upstreamCache, qry, chainCopy); + long cnt = computeCount(upstreamCache, qry, transformer); if (cnt > 0) { try (QueryCursor<UpstreamEntry<K, V>> cursor = upstreamCache.query(qry, e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) { Iterator<UpstreamEntry<K, V>> it = cursor.iterator(); - if (!chainCopy.isEmpty()) { - Stream<UpstreamEntry<K, V>> transformedStream = chainCopy.transform(Utils.asStream(it, cnt)); - it = transformedStream.iterator(); - } + Stream<UpstreamEntry<K, V>> transformedStream = transformerCp.transform(Utils.asStream(it, cnt)); + it = transformedStream.iterator(); + Iterator<UpstreamEntry<K, V>> iter = new 
IteratorWithConcurrentModificationChecker<>(it, cnt, "Cache expected to be not modified during dataset data building [partition=" + part + ']'); - return partDataBuilder.build(iter, cnt, ctx); + return partDataBuilder.build(env, iter, cnt, ctx); } } @@ -214,27 +244,40 @@ public class ComputeUtils { } /** + * Removes learning environments of the dataset with the given ID from the node-local map of {@code ignite}. + * + * @param ignite Ignite instance. + * @param datasetId Dataset ID. + */ + public static void removeLearningEnv(Ignite ignite, UUID datasetId) { + ignite.cluster().nodeLocalMap().remove(String.format(ENVIRONMENT_STORAGE_KEY_TEMPLATE, datasetId)); + } + + /** * Initializes partition {@code context} by loading it from a partition {@code upstream}. - * @param <K> Type of a key in {@code upstream} data. - * @param <V> Type of a value in {@code upstream} data. - * @param <C> Type of a partition {@code context}. * @param ignite Ignite instance. * @param upstreamCacheName Name of an {@code upstream} cache. * @param filter Filter for {@code upstream} data. - * @param transformersChain Upstream data {@link Stream} transformers chain. + * @param transformerBuilder Upstream transformer builder. * @param ctxBuilder Partition {@code context} builder. + * @param envBuilder Environment builder. + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + * @param <C> Type of a partition {@code context}.
*/ public static <K, V, C extends Serializable> void initContext( Ignite ignite, String upstreamCacheName, + UpstreamTransformerBuilder<K, V> transformerBuilder, IgniteBiPredicate<K, V> filter, - UpstreamTransformerChain<K, V> transformersChain, String datasetCacheName, PartitionContextBuilder<K, V, C> ctxBuilder, + LearningEnvironmentBuilder envBuilder, int retries, int interval) { affinityCallWithRetries(ignite, Arrays.asList(datasetCacheName, upstreamCacheName), part -> { Ignite locIgnite = Ignition.localIgnite(); + LearningEnvironment env = envBuilder.buildForWorker(part); IgniteCache<K, V> locUpstreamCache = locIgnite.cache(upstreamCacheName); @@ -244,25 +287,24 @@ public class ComputeUtils { qry.setFilter(filter); C ctx; - UpstreamTransformerChain<K, V> chainCopy = Utils.copy(transformersChain); - chainCopy.modifySeed(s -> s + part); + UpstreamTransformer<K, V> transformer = transformerBuilder.build(env); + UpstreamTransformer<K, V> transformerCp = Utils.copy(transformer); - long cnt = computeCount(locUpstreamCache, qry, transformersChain); + long cnt = computeCount(locUpstreamCache, qry, transformer); try (QueryCursor<UpstreamEntry<K, V>> cursor = locUpstreamCache.query(qry, e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) { Iterator<UpstreamEntry<K, V>> it = cursor.iterator(); - if (!chainCopy.isEmpty()) { - Stream<UpstreamEntry<K, V>> transformedStream = chainCopy.transform(Utils.asStream(it, cnt)); - it = transformedStream.iterator(); - } + Stream<UpstreamEntry<K, V>> transformedStream = transformerCp.transform(Utils.asStream(it, cnt)); + it = transformedStream.iterator(); + Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>( it, cnt, "Cache expected to be not modified during dataset data building [partition=" + part + ']'); - ctx = ctxBuilder.build(iter, cnt); + ctx = ctxBuilder.build(env, iter, cnt); } IgniteCache<Integer, C> datasetCache = locIgnite.cache(datasetCacheName); @@ -279,9 +321,10 @@ public class 
ComputeUtils { * @param ignite Ignite instance. * @param upstreamCacheName Name of an {@code upstream} cache. * @param filter Filter for {@code upstream} data. - * @param transformersChain Transformers of upstream data. + * @param transformerBuilder Builder of transformer of upstream data. * @param datasetCacheName Name of a partition {@code context} cache. * @param ctxBuilder Partition {@code context} builder. + * @param envBuilder Environment builder. * @param retries Number of retries for the case when one of partitions not found on the node. * @param <K> Type of a key in {@code upstream} data. * @param <V> Type of a value in {@code upstream} data. @@ -291,11 +334,12 @@ public class ComputeUtils { Ignite ignite, String upstreamCacheName, IgniteBiPredicate<K, V> filter, - UpstreamTransformerChain<K, V> transformersChain, + UpstreamTransformerBuilder<K, V> transformerBuilder, String datasetCacheName, PartitionContextBuilder<K, V, C> ctxBuilder, + LearningEnvironmentBuilder envBuilder, int retries) { - initContext(ignite, upstreamCacheName, filter, transformersChain, datasetCacheName, ctxBuilder, retries, 0); + initContext(ignite, upstreamCacheName, transformerBuilder, filter, datasetCacheName, ctxBuilder, envBuilder, retries, 0); } /** @@ -328,25 +372,21 @@ public class ComputeUtils { /** * Computes number of entries selected from the cache by the query. * - * @param <K> Type of a key in {@code upstream} data. - * @param <V> Type of a value in {@code upstream} data. * @param cache Ignite cache with upstream data. * @param qry Cache query. - * @param transformersChain Transformers of stream of upstream data. + * @param transformer Upstream transformer. + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. * @return Number of entries supplied by the iterator. 
*/ private static <K, V> long computeCount( IgniteCache<K, V> cache, ScanQuery<K, V> qry, - UpstreamTransformerChain<K, V> transformersChain) { + UpstreamTransformer<K, V> transformer) { try (QueryCursor<UpstreamEntry<K, V>> cursor = cache.query(qry, e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) { - // 'If' statement below is just for optimization, to avoid unnecessary iterator -> stream -> iterator - // operations. - return transformersChain.isEmpty() ? - computeCount(cursor.iterator()) : - computeCount(transformersChain.transform(Utils.asStream(cursor.iterator())).iterator()); + return computeCount(transformer.transform(Utils.asStream(cursor.iterator())).iterator()); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java index 975beda..8c67c02 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java @@ -20,6 +20,7 @@ package org.apache.ignite.ml.dataset.impl.local; import java.io.Serializable; import java.util.List; import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.environment.LearningEnvironment; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.functions.IgniteBinaryOperator; import org.apache.ignite.ml.math.functions.IgniteTriFunction; @@ -32,6 +33,9 @@ import org.apache.ignite.ml.math.functions.IgniteTriFunction; * @param <D> Type of a partition {@code data}. */ public class LocalDataset<C extends Serializable, D extends AutoCloseable> implements Dataset<C, D> { + /** Partition {@link LearningEnvironment} storage.
*/ + private final List<LearningEnvironment> envs; + /** Partition {@code context} storage. */ private final List<C> ctx; @@ -42,38 +46,42 @@ public class LocalDataset<C extends Serializable, D extends AutoCloseable> imple * Constructs a new instance of dataset based on local data structures such as {@code Map} and {@code List} and * doesn't requires Ignite environment. * + * @param envs List of {@link LearningEnvironment}. * @param ctx Partition {@code context} storage. * @param data Partition {@code data} storage. */ - LocalDataset(List<C> ctx, List<D> data) { + LocalDataset(List<LearningEnvironment> envs, List<C> ctx, List<D> data) { + this.envs = envs; this.ctx = ctx; this.data = data; } /** {@inheritDoc} */ - @Override public <R> R computeWithCtx(IgniteTriFunction<C, D, Integer, R> map, IgniteBinaryOperator<R> reduce, + @Override public <R> R computeWithCtx(IgniteTriFunction<C, D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce, R identity) { R res = identity; for (int part = 0; part < ctx.size(); part++) { D partData = data.get(part); + LearningEnvironment env = envs.get(part); if (partData != null) - res = reduce.apply(res, map.apply(ctx.get(part), partData, part)); + res = reduce.apply(res, map.apply(ctx.get(part), partData, env)); } return res; } /** {@inheritDoc} */ - @Override public <R> R compute(IgniteBiFunction<D, Integer, R> map, IgniteBinaryOperator<R> reduce, R identity) { + @Override public <R> R compute(IgniteBiFunction<D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce, R identity) { R res = identity; for (int part = 0; part < data.size(); part++) { D partData = data.get(part); + LearningEnvironment env = envs.get(part); if (partData != null) - res = reduce.apply(res, map.apply(partData, part)); + res = reduce.apply(res, map.apply(partData, env)); } return res; http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java 
---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java index 2514f3e..b8cd8dc 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java @@ -22,12 +22,17 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.apache.ignite.lang.IgniteBiPredicate; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.PartitionContextBuilder; import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.dataset.UpstreamEntry; -import org.apache.ignite.ml.dataset.UpstreamTransformerChain; +import org.apache.ignite.ml.dataset.UpstreamTransformer; +import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder; +import org.apache.ignite.ml.environment.LearningEnvironment; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.util.Utils; @@ -49,7 +54,7 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> { private final IgniteBiPredicate<K, V> filter; /** Upstream transformers. */ - private final UpstreamTransformerChain<K, V> upstreamTransformers; + private final UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder; /** * Constructs a new instance of local dataset builder that makes {@link LocalDataset} with default predicate that @@ -68,16 +73,34 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> { * @param upstreamMap {@code Map} with upstream data. * @param filter Filter for {@code upstream} data. 
* @param partitions Number of partitions. + * @param upstreamTransformerBuilder Builder of upstream transformer. */ - public LocalDatasetBuilder(Map<K, V> upstreamMap, IgniteBiPredicate<K, V> filter, int partitions) { + public LocalDatasetBuilder(Map<K, V> upstreamMap, + IgniteBiPredicate<K, V> filter, + int partitions, + UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder) { this.upstreamMap = upstreamMap; this.filter = filter; this.partitions = partitions; - this.upstreamTransformers = UpstreamTransformerChain.empty(); + this.upstreamTransformerBuilder = upstreamTransformerBuilder; + } + + /** + * Constructs a new instance of local dataset builder that makes {@link LocalDataset}. + * + * @param upstreamMap {@code Map} with upstream data. + * @param filter Filter for {@code upstream} data. + * @param partitions Number of partitions. + */ + public LocalDatasetBuilder(Map<K, V> upstreamMap, + IgniteBiPredicate<K, V> filter, + int partitions) { + this(upstreamMap, filter, partitions, UpstreamTransformerBuilder.identity()); } /** {@inheritDoc} */ @Override public <C extends Serializable, D extends AutoCloseable> LocalDataset<C, D> build( + LearningEnvironmentBuilder envBuilder, PartitionContextBuilder<K, V, C> partCtxBuilder, PartitionDataBuilder<K, V, C, D> partDataBuilder) { List<C> ctxList = new ArrayList<>(); List<D> dataList = new ArrayList<>(); @@ -99,36 +122,29 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> { int ptr = 0; + List<LearningEnvironment> envs = IntStream.range(0, partitions).boxed().map(envBuilder::buildForWorker) + .collect(Collectors.toList()); + for (int part = 0; part < partitions; part++) { - int cnt = part == partitions - 1 ? 
entriesList.size() - ptr : Math.min(partSize, entriesList.size() - ptr); - - int p = part; - upstreamTransformers.modifySeed(s -> s + p); - - if (!upstreamTransformers.isEmpty()) { - cnt = (int)upstreamTransformers.transform( - Utils.asStream(new IteratorWindow<>(thirdKeysIter, k -> k, cnt))).count(); - } - - Iterator<UpstreamEntry<K, V>> iter; - if (upstreamTransformers.isEmpty()) - iter = new IteratorWindow<>(firstKeysIter, k -> k, cnt); - - else { - iter = upstreamTransformers.transform( - Utils.asStream(new IteratorWindow<>(firstKeysIter, k -> k, cnt))).iterator(); - } - C ctx = cnt > 0 ? partCtxBuilder.build(iter, cnt) : null; - - Iterator<UpstreamEntry<K, V>> iter1; - if (upstreamTransformers.isEmpty()) { - iter1 = upstreamTransformers.transform( - Utils.asStream(new IteratorWindow<>(secondKeysIter, k -> k, cnt))).iterator(); - } - else - iter1 = new IteratorWindow<>(secondKeysIter, k -> k, cnt); - - D data = cnt > 0 ? partDataBuilder.build( + int cntBeforeTransform = + part == partitions - 1 ? entriesList.size() - ptr : Math.min(partSize, entriesList.size() - ptr); + LearningEnvironment env = envs.get(part); + UpstreamTransformer<K, V> transformer1 = upstreamTransformerBuilder.build(env); + UpstreamTransformer<K, V> transformer2 = Utils.copy(transformer1); + UpstreamTransformer<K, V> transformer3 = Utils.copy(transformer1); + + int cnt = (int)transformer1.transform(Utils.asStream(new IteratorWindow<>(thirdKeysIter, k -> k, cntBeforeTransform))).count(); + + Iterator<UpstreamEntry<K, V>> iter = + transformer2.transform(Utils.asStream(new IteratorWindow<>(firstKeysIter, k -> k, cntBeforeTransform))).iterator(); + + C ctx = cntBeforeTransform > 0 ? partCtxBuilder.build(env, iter, cnt) : null; + + Iterator<UpstreamEntry<K, V>> iter1 = transformer3.transform( + Utils.asStream(new IteratorWindow<>(secondKeysIter, k -> k, cntBeforeTransform))).iterator(); + + D data = cntBeforeTransform > 0 ? 
partDataBuilder.build( + env, iter1, cnt, ctx @@ -137,20 +153,18 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> { ctxList.add(ctx); dataList.add(data); - ptr += cnt; + ptr += cntBeforeTransform; } - return new LocalDataset<>(ctxList, dataList); + return new LocalDataset<>(envs, ctxList, dataList); } /** {@inheritDoc} */ - @Override public UpstreamTransformerChain<K, V> upstreamTransformersChain() { - return upstreamTransformers; + @Override public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder<K, V> builder) { + return new LocalDatasetBuilder<>(upstreamMap, filter, partitions, upstreamTransformerBuilder.andThen(builder)); } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public DatasetBuilder<K, V> withFilter(IgniteBiPredicate<K, V> filterToAdd) { return new LocalDatasetBuilder<>(upstreamMap, (e1, e2) -> filter.apply(e1, e2) && filterToAdd.apply(e1, e2), partitions); @@ -164,24 +178,16 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> { * @param <T> Target type of entries. */ private static class IteratorWindow<K, T> implements Iterator<T> { - /** - * Delegate iterator. - */ + /** Delegate iterator. */ private final Iterator<K> delegate; - /** - * Transformer that transforms entries from one type to another. - */ + /** Transformer that transforms entries from one type to another. */ private final IgniteFunction<K, T> map; - /** - * Count of entries to produce. - */ + /** Count of entries to produce. */ private final int cnt; - /** - * Number of already produced entries. - */ + /** Number of already produced entries. 
*/ private int ptr; /** @@ -197,16 +203,12 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> { this.cnt = cnt; } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public boolean hasNext() { return delegate.hasNext() && ptr < cnt; } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public T next() { ++ptr; http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/DatasetWrapper.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/DatasetWrapper.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/DatasetWrapper.java index 578a149..270c7eb 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/DatasetWrapper.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/DatasetWrapper.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.dataset.primitive; import java.io.Serializable; import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.environment.LearningEnvironment; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.functions.IgniteBinaryOperator; import org.apache.ignite.ml.math.functions.IgniteTriFunction; @@ -46,13 +47,13 @@ public class DatasetWrapper<C extends Serializable, D extends AutoCloseable> imp } /** {@inheritDoc} */ - @Override public <R> R computeWithCtx(IgniteTriFunction<C, D, Integer, R> map, IgniteBinaryOperator<R> reduce, + @Override public <R> R computeWithCtx(IgniteTriFunction<C, D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce, R identity) { return delegate.computeWithCtx(map, reduce, identity); } /** {@inheritDoc} */ - @Override public <R> R compute(IgniteBiFunction<D, Integer, R> map, IgniteBinaryOperator<R> reduce, R identity) { + @Override public <R> R compute(IgniteBiFunction<D, LearningEnvironment, R> 
map, IgniteBinaryOperator<R> reduce, R identity) { return delegate.compute(map, reduce, identity); } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java index be1724c..5273fa6 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.Iterator; import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.dataset.UpstreamEntry; +import org.apache.ignite.ml.environment.LearningEnvironment; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.tree.data.DecisionTreeData; @@ -56,7 +57,11 @@ public class FeatureMatrixWithLabelsOnHeapDataBuilder<K, V, C extends Serializab } /** {@inheritDoc} */ - @Override public FeatureMatrixWithLabelsOnHeapData build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) { + @Override public FeatureMatrixWithLabelsOnHeapData build( + LearningEnvironment env, + Iterator<UpstreamEntry<K, V>> upstreamData, + long upstreamDataSize, + C ctx) { double[][] features = new double[Math.toIntExact(upstreamDataSize)][]; double[] labels = new double[Math.toIntExact(upstreamDataSize)]; http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/context/EmptyContextBuilder.java 
---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/context/EmptyContextBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/context/EmptyContextBuilder.java index 03b69b5..9fd77b5 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/context/EmptyContextBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/context/EmptyContextBuilder.java @@ -21,6 +21,7 @@ import java.util.Iterator; import org.apache.ignite.ml.dataset.PartitionContextBuilder; import org.apache.ignite.ml.dataset.UpstreamEntry; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironment; /** * A partition {@code context} builder that makes {@link EmptyContext}. @@ -33,7 +34,7 @@ public class EmptyContextBuilder<K, V> implements PartitionContextBuilder<K, V, private static final long serialVersionUID = 6620781747993467186L; /** {@inheritDoc} */ - @Override public EmptyContext build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize) { + @Override public EmptyContext build(LearningEnvironment env, Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize) { return new EmptyContext(); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleDatasetDataBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleDatasetDataBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleDatasetDataBuilder.java index cf5bc7a..b14d8a2 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleDatasetDataBuilder.java +++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleDatasetDataBuilder.java @@ -22,6 +22,7 @@ import java.util.Iterator; import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.dataset.UpstreamEntry; import org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData; +import org.apache.ignite.ml.environment.LearningEnvironment; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; @@ -50,7 +51,9 @@ public class SimpleDatasetDataBuilder<K, V, C extends Serializable> } /** {@inheritDoc} */ - @Override public SimpleDatasetData build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) { + @Override public SimpleDatasetData build( + LearningEnvironment env, + Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) { // Prepares the matrix of features in flat column-major format. int cols = -1; double[] features = null; http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleLabeledDatasetDataBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleLabeledDatasetDataBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleLabeledDatasetDataBuilder.java index 6286255..48166ee 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleLabeledDatasetDataBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleLabeledDatasetDataBuilder.java @@ -22,6 +22,7 @@ import java.util.Iterator; import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.dataset.UpstreamEntry; import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData; +import 
org.apache.ignite.ml.environment.LearningEnvironment; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; @@ -56,7 +57,9 @@ public class SimpleLabeledDatasetDataBuilder<K, V, C extends Serializable> } /** {@inheritDoc} */ - @Override public SimpleLabeledDatasetData build(Iterator<UpstreamEntry<K, V>> upstreamData, + @Override public SimpleLabeledDatasetData build( + LearningEnvironment env, + Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) { // Prepares the matrix of features in flat column-major format. int featureCols = -1; http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/environment/DefaultLearningEnvironmentBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/environment/DefaultLearningEnvironmentBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/environment/DefaultLearningEnvironmentBuilder.java new file mode 100644 index 0000000..4aef8f2 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/environment/DefaultLearningEnvironmentBuilder.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.environment; + +import java.util.Random; +import org.apache.ignite.ml.environment.logging.MLLogger; +import org.apache.ignite.ml.environment.logging.NoOpLogger; +import org.apache.ignite.ml.environment.parallelism.DefaultParallelismStrategy; +import org.apache.ignite.ml.environment.parallelism.NoParallelismStrategy; +import org.apache.ignite.ml.environment.parallelism.ParallelismStrategy; +import org.apache.ignite.ml.math.functions.IgniteFunction; + +import static org.apache.ignite.ml.math.functions.IgniteFunction.constant; + +/** + * Builder for {@link LearningEnvironment}. + */ +public class DefaultLearningEnvironmentBuilder implements LearningEnvironmentBuilder { + /** Serial version id. */ + private static final long serialVersionUID = 8502532880517447662L; + + /** Dependency (partition -> Parallelism strategy). */ + private IgniteFunction<Integer, ParallelismStrategy> parallelismStgy; + + /** Dependency (partition -> Logging factory). */ + private IgniteFunction<Integer, MLLogger.Factory> loggingFactory; + + /** Dependency (partition -> Random number generator seed). */ + private IgniteFunction<Integer, Long> seed; + + /** Dependency (partition -> Random numbers generator supplier). */ + private IgniteFunction<Integer, Random> rngSupplier; + + /** + * Creates an instance of DefaultLearningEnvironmentBuilder. 
+ */ + DefaultLearningEnvironmentBuilder() { + parallelismStgy = constant(NoParallelismStrategy.INSTANCE); + loggingFactory = constant(NoOpLogger.factory()); + seed = constant(new Random().nextLong()); + rngSupplier = constant(new Random()); + } + + /** {@inheritDoc} */ + @Override public LearningEnvironmentBuilder withRNGSeedDependency(IgniteFunction<Integer, Long> seed) { + this.seed = seed; + + return this; + } + + /** {@inheritDoc} */ + @Override public LearningEnvironmentBuilder withRandomDependency(IgniteFunction<Integer, Random> rngSupplier) { + this.rngSupplier = rngSupplier; + + return this; + } + + /** {@inheritDoc} */ + @Override public DefaultLearningEnvironmentBuilder withParallelismStrategyDependency( + IgniteFunction<Integer, ParallelismStrategy> stgy) { + this.parallelismStgy = stgy; + + return this; + } + + /** {@inheritDoc} */ + @Override public DefaultLearningEnvironmentBuilder withParallelismStrategyTypeDependency( + IgniteFunction<Integer, ParallelismStrategy.Type> stgyType) { + this.parallelismStgy = part -> strategyByType(stgyType.apply(part)); + + return this; + } + + /** + * Get parallelism strategy by {@link ParallelismStrategy.Type}. + * + * @param stgyType Strategy type. + * @return {@link ParallelismStrategy}. 
+ */ + private static ParallelismStrategy strategyByType(ParallelismStrategy.Type stgyType) { + switch (stgyType) { + case NO_PARALLELISM: + return NoParallelismStrategy.INSTANCE; + case ON_DEFAULT_POOL: + return new DefaultParallelismStrategy(); + } + throw new IllegalStateException("Wrong type"); + } + + + /** {@inheritDoc} */ + @Override public DefaultLearningEnvironmentBuilder withLoggingFactoryDependency( + IgniteFunction<Integer, MLLogger.Factory> loggingFactory) { + this.loggingFactory = loggingFactory; + return this; + } + + /** {@inheritDoc} */ + @Override public LearningEnvironment buildForWorker(int part) { + Random random = rngSupplier.apply(part); + random.setSeed(seed.apply(part)); + return new LearningEnvironmentImpl(part, random, parallelismStgy.apply(part), loggingFactory.apply(part)); + } + + /** Default LearningEnvironment implementation. */ + private class LearningEnvironmentImpl implements LearningEnvironment { + /** Parallelism strategy. */ + private final ParallelismStrategy parallelismStgy; + + /** Logging factory. */ + private final MLLogger.Factory loggingFactory; + + /** Partition. */ + private final int part; + + /** Random numbers generator. */ + private final Random randomNumGen; + + /** + * Creates an instance of LearningEnvironmentImpl. + * + * @param part Partition. + * @param rng Random numbers generator. + * @param parallelismStgy Parallelism strategy. + * @param loggingFactory Logging factory. 
+ */ + private LearningEnvironmentImpl( + int part, + Random rng, + ParallelismStrategy parallelismStgy, + MLLogger.Factory loggingFactory) { + this.part = part; + this.parallelismStgy = parallelismStgy; + this.loggingFactory = loggingFactory; + randomNumGen = rng; + } + + /** {@inheritDoc} */ + @Override public ParallelismStrategy parallelismStrategy() { + return parallelismStgy; + } + + /** {@inheritDoc} */ + @Override public MLLogger logger() { + return loggingFactory.create(getClass()); + } + + /** {@inheritDoc} */ + @Override public Random randomNumbersGenerator() { + return randomNumGen; + } + + /** {@inheritDoc} */ + @Override public <T> MLLogger logger(Class<T> clazz) { + return loggingFactory.create(clazz); + } + + /** {@inheritDoc} */ + @Override public int partition() { + return part; + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironment.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironment.java b/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironment.java index f5fb693..f1e4f32 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironment.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironment.java @@ -17,6 +17,8 @@ package org.apache.ignite.ml.environment; +import java.util.Random; +import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.environment.logging.MLLogger; import org.apache.ignite.ml.environment.parallelism.ParallelismStrategy; @@ -26,7 +28,7 @@ import org.apache.ignite.ml.environment.parallelism.ParallelismStrategy; */ public interface LearningEnvironment { /** Default environment */ - public static final LearningEnvironment DEFAULT = builder().build(); + public static final LearningEnvironment DEFAULT_TRAINER_ENV = 
LearningEnvironmentBuilder.defaultBuilder().buildForTrainer(); /** * Returns Parallelism Strategy instance. @@ -39,6 +41,13 @@ public interface LearningEnvironment { public MLLogger logger(); /** + * Random numbers generator. + * + * @return Random numbers generator. + */ + public Random randomNumbersGenerator(); + + /** * Returns an instance of logger for specific class. * * @param forCls Logging class context. @@ -46,9 +55,9 @@ public interface LearningEnvironment { public <T> MLLogger logger(Class<T> forCls); /** - * Creates an instance of LearningEnvironmentBuilder. + * Gets current partition. If this is called not in one of compute tasks of {@link Dataset}, will return -1. + * + * @return Partition. */ - public static LearningEnvironmentBuilder builder() { - return new LearningEnvironmentBuilder(); - } + public int partition(); } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilder.java index 98f584f..8fcc6b2 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilder.java @@ -17,109 +17,134 @@ package org.apache.ignite.ml.environment; +import java.io.Serializable; +import java.util.Random; import org.apache.ignite.ml.environment.logging.MLLogger; -import org.apache.ignite.ml.environment.logging.NoOpLogger; -import org.apache.ignite.ml.environment.parallelism.DefaultParallelismStrategy; -import org.apache.ignite.ml.environment.parallelism.NoParallelismStrategy; import org.apache.ignite.ml.environment.parallelism.ParallelismStrategy; +import 
org.apache.ignite.ml.math.functions.IgniteFunction; + +import static org.apache.ignite.ml.math.functions.IgniteFunction.constant; /** - * Builder for LearningEnvironment. + * Builder of learning environment. */ -public class LearningEnvironmentBuilder { - /** Parallelism strategy. */ - private ParallelismStrategy parallelismStgy; - /** Logging factory. */ - private MLLogger.Factory loggingFactory; +public interface LearningEnvironmentBuilder extends Serializable { + /** + * Builds {@link LearningEnvironment} for worker on given partition. + * + * @param part Partition. + * @return {@link LearningEnvironment} for worker on given partition. + */ + public LearningEnvironment buildForWorker(int part); /** - * Creates an instance of LearningEnvironmentBuilder. + * Builds learning environment for trainer. + * + * @return Learning environment for trainer. */ - public LearningEnvironmentBuilder() { - parallelismStgy = NoParallelismStrategy.INSTANCE; - loggingFactory = NoOpLogger.factory(); + public default LearningEnvironment buildForTrainer() { + return buildForWorker(-1); } /** - * Specifies Parallelism Strategy for LearningEnvironment. + * Specifies dependency (partition -> Parallelism Strategy Type for LearningEnvironment). * - * @param stgy Parallelism Strategy. + * @param stgyType Function describing dependency (partition -> Parallelism Strategy Type). + * @return This object. */ - public <T> LearningEnvironmentBuilder withParallelismStrategy(ParallelismStrategy stgy) { - this.parallelismStgy = stgy; + public LearningEnvironmentBuilder withParallelismStrategyTypeDependency( + IgniteFunction<Integer, ParallelismStrategy.Type> stgyType); - return this; + /** + * Specifies Parallelism Strategy Type for LearningEnvironment. Same strategy type will be used for all partitions. + * + * @param stgyType Parallelism Strategy Type. + * @return This object. 
+ */ + public default LearningEnvironmentBuilder withParallelismStrategyType(ParallelismStrategy.Type stgyType) { + return withParallelismStrategyTypeDependency(constant(stgyType)); } /** - * Specifies Parallelism Strategy for LearningEnvironment. + * Specifies dependency (partition -> Parallelism Strategy for LearningEnvironment). * - * @param stgyType Parallelism Strategy Type. + * @param stgy Function describing dependency (partition -> Parallelism Strategy). + * @return This object. */ - public LearningEnvironmentBuilder withParallelismStrategy(ParallelismStrategy.Type stgyType) { - switch (stgyType) { - case NO_PARALLELISM: - this.parallelismStgy = NoParallelismStrategy.INSTANCE; - break; - case ON_DEFAULT_POOL: - this.parallelismStgy = new DefaultParallelismStrategy(); - break; - } - return this; + public LearningEnvironmentBuilder withParallelismStrategyDependency(IgniteFunction<Integer, ParallelismStrategy> stgy); + + /** + * Specifies Parallelism Strategy for LearningEnvironment. Same strategy type will be used for all partitions. + * + * @param stgy Parallelism Strategy. + * @param <T> Parallelism strategy type. + * @return This object. + */ + public default <T extends ParallelismStrategy & Serializable> LearningEnvironmentBuilder withParallelismStrategy(T stgy) { + return withParallelismStrategyDependency(constant(stgy)); } /** - * Specifies Logging factory for LearningEnvironment. + * Specify dependency (partition -> logging factory). * - * @param loggingFactory Logging Factory. + * @param loggingFactory Function describing (partition -> logging factory). + * @return This object. */ - public LearningEnvironmentBuilder withLoggingFactory(MLLogger.Factory loggingFactory) { - this.loggingFactory = loggingFactory; - return this; + public LearningEnvironmentBuilder withLoggingFactoryDependency(IgniteFunction<Integer, MLLogger.Factory> loggingFactory); + + /** + * Specify logging factory. + * + * @param loggingFactory Logging factory. + * @return This object. 
+ */ + public default <T extends MLLogger.Factory & Serializable> LearningEnvironmentBuilder withLoggingFactory(T loggingFactory) { + return withLoggingFactoryDependency(constant(loggingFactory)); } /** - * Create an instance of LearningEnvironment. + * Specify dependency (partition -> seed for random number generator). Same seed will be used for all partitions. + * + * @param seed Function describing dependency (partition -> seed for random number generator). + * @return This object. */ - public LearningEnvironment build() { - return new LearningEnvironmentImpl(parallelismStgy, loggingFactory); + public LearningEnvironmentBuilder withRNGSeedDependency(IgniteFunction<Integer, Long> seed); + + /** + * Specify seed for random number generator. + * + * @param seed Seed for random number generator. + * @return This object. + */ + public default LearningEnvironmentBuilder withRNGSeed(long seed) { + return withRNGSeedDependency(constant(seed)); } /** - * Default LearningEnvironment implementation. + * Specify dependency (partition -> random numbers generator). + * + * @param rngSupplier Function describing dependency (partition -> random numbers generator). + * @return This object. + */ + public LearningEnvironmentBuilder withRandomDependency(IgniteFunction<Integer, Random> rngSupplier); + + /** + * Specify random numbers generator for learning environment. Same random will be used for all partitions. + * + * @param random Rrandom numbers generator for learning environment. + * @return This object. + */ + public default LearningEnvironmentBuilder withRandom(Random random) { + return withRandomDependency(constant(random)); + } + + /** + * Get default {@link LearningEnvironmentBuilder}. + * + * @return Default {@link LearningEnvironmentBuilder}. */ - private class LearningEnvironmentImpl implements LearningEnvironment { - /** Parallelism strategy. */ - private final ParallelismStrategy parallelismStgy; - /** Logging factory. 
*/ - private final MLLogger.Factory loggingFactory; - - /** - * Creates an instance of LearningEnvironmentImpl. - * - * @param parallelismStgy Parallelism strategy. - * @param loggingFactory Logging factory. - */ - private LearningEnvironmentImpl(ParallelismStrategy parallelismStgy, - MLLogger.Factory loggingFactory) { - this.parallelismStgy = parallelismStgy; - this.loggingFactory = loggingFactory; - } - - /** {@inheritDoc} */ - @Override public ParallelismStrategy parallelismStrategy() { - return parallelismStgy; - } - - /** {@inheritDoc} */ - @Override public MLLogger logger() { - return loggingFactory.create(getClass()); - } - - /** {@inheritDoc} */ - @Override public <T> MLLogger logger(Class<T> clazz) { - return loggingFactory.create(clazz); - } + public static LearningEnvironmentBuilder defaultBuilder() { + return new DefaultLearningEnvironmentBuilder(); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/ConsoleLogger.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/ConsoleLogger.java b/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/ConsoleLogger.java index e064fc3..c124e06 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/ConsoleLogger.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/ConsoleLogger.java @@ -82,6 +82,9 @@ public class ConsoleLogger implements MLLogger { * ConsoleLogger factory. */ private static class Factory implements MLLogger.Factory { + /** Serial version uuid. */ + private static final long serialVersionUID = 5864605548782107893L; + /** Max Verbose level. 
*/ private final VerboseLevel maxVerboseLevel; http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/environment/parallelism/ParallelismStrategy.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/environment/parallelism/ParallelismStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/environment/parallelism/ParallelismStrategy.java index e7228f8..329ce89 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/environment/parallelism/ParallelismStrategy.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/environment/parallelism/ParallelismStrategy.java @@ -26,7 +26,6 @@ import org.apache.ignite.ml.math.functions.IgniteSupplier; * bagging, learning submodels for One-vs-All model, Cross-Validation etc. */ public interface ParallelismStrategy { - /** * The type of parallelism. */ http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java index d7bccd8..8239ebd 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java @@ -21,6 +21,7 @@ import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.structures.LabeledVector; @@ -35,12 +36,15 @@ public class KNNUtils { /** * Builds dataset. 
* + * @param envBuilder Learning environment builder. * @param datasetBuilder Dataset builder. * @param featureExtractor Feature extractor. * @param lbExtractor Label extractor. * @return Dataset. */ - @Nullable public static <K, V> Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> buildDataset(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + @Nullable public static <K, V> Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> buildDataset( + LearningEnvironmentBuilder envBuilder, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>( featureExtractor, @@ -51,7 +55,8 @@ public class KNNUtils { if (datasetBuilder != null) { dataset = datasetBuilder.build( - (upstream, upstreamSize) -> new EmptyContext(), + envBuilder, + (env, upstream, upstreamSize) -> new EmptyContext(), partDataBuilder ); } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java index e56a10a..c32ca56 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java @@ -31,6 +31,7 @@ import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import 
org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.distances.DistanceMeasure; import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.math.functions.IgniteBiFunction; @@ -105,6 +106,12 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass return mdl.getDistanceMeasure().equals(distance) && mdl.getCandidates().rowSize() == k; } + /** {@inheritDoc} */ + @Override public ANNClassificationTrainer withEnvironmentBuilder( + LearningEnvironmentBuilder envBuilder) { + return (ANNClassificationTrainer)super.withEnvironmentBuilder(envBuilder); + } + /** */ @NotNull private LabeledVectorSet<ProbableLabel, LabeledVector> buildLabelsForCandidates(List<Vector> centers, CentroidStat centroidStat) { @@ -180,7 +187,8 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass ); try (Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset = datasetBuilder.build( - (upstream, upstreamSize) -> new EmptyContext(), + envBuilder, + (env, upstream, upstreamSize) -> new EmptyContext(), partDataBuilder )) { return dataset.compute(data -> { http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java index 1a3ff73..ed55318 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java @@ -18,6 +18,7 @@ package org.apache.ignite.ml.knn.classification; import org.apache.ignite.ml.dataset.DatasetBuilder; +import 
org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.knn.KNNUtils; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; @@ -46,7 +47,7 @@ public class KNNClassificationTrainer extends SingleLabelDatasetTrainer<KNNClass DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { - KNNClassificationModel res = new KNNClassificationModel(KNNUtils.buildDataset(datasetBuilder, + KNNClassificationModel res = new KNNClassificationModel(KNNUtils.buildDataset(envBuilder, datasetBuilder, featureExtractor, lbExtractor)); if (mdl != null) res.copyStateFrom(mdl); @@ -54,6 +55,11 @@ public class KNNClassificationTrainer extends SingleLabelDatasetTrainer<KNNClass } /** {@inheritDoc} */ + @Override public KNNClassificationTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { + return (KNNClassificationTrainer)super.withEnvironmentBuilder(envBuilder); + } + + /** {@inheritDoc} */ @Override protected boolean checkState(KNNClassificationModel mdl) { return true; } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java index 7a42dc8..9b348f3 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java @@ -42,10 +42,13 @@ public class KNNRegressionTrainer extends SingleLabelDatasetTrainer<KNNRegressio } /** {@inheritDoc} */ - @Override public <K, V> KNNRegressionModel updateModel(KNNRegressionModel mdl, DatasetBuilder<K, 
V> datasetBuilder, - IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + @Override public <K, V> KNNRegressionModel updateModel( + KNNRegressionModel mdl, + DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { - KNNRegressionModel res = new KNNRegressionModel(KNNUtils.buildDataset(datasetBuilder, + KNNRegressionModel res = new KNNRegressionModel(KNNUtils.buildDataset(envBuilder, datasetBuilder, featureExtractor, lbExtractor)); if (mdl != null) res.copyStateFrom(mdl); http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteFunction.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteFunction.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteFunction.java index 9d19592..2673b90 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteFunction.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteFunction.java @@ -26,5 +26,15 @@ import java.util.function.Function; * @see java.util.function.Function */ public interface IgniteFunction<T, R> extends Function<T, R>, Serializable { - + /** + * {@link IgniteFunction} returning specified constant. + * + * @param r Constant to return. + * @param <T> Type of input. + * @param <R> Type of output. + * @return {@link IgniteFunction} returning specified constant. 
+ */ + static <T, R> IgniteFunction<T, R> constant(R r) { + return t -> r; + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java index 14356e1..e0376b8 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java @@ -23,6 +23,7 @@ import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; /** * Distributed implementation of LSQR algorithm based on {@link AbstractLSQR} and {@link Dataset}. @@ -35,12 +36,15 @@ public class LSQROnHeap<K, V> extends AbstractLSQR implements AutoCloseable { * Constructs a new instance of OnHeap LSQR algorithm implementation. * * @param datasetBuilder Dataset builder. + * @param envBuilder Learning environment builder. * @param partDataBuilder Partition data builder. 
*/ public LSQROnHeap(DatasetBuilder<K, V> datasetBuilder, + LearningEnvironmentBuilder envBuilder, PartitionDataBuilder<K, V, LSQRPartitionContext, SimpleLabeledDatasetData> partDataBuilder) { this.dataset = datasetBuilder.build( - (upstream, upstreamSize) -> new LSQRPartitionContext(), + envBuilder, + (env, upstream, upstreamSize) -> new LSQRPartitionContext(), partDataBuilder ); } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java index 5e1341b..3c580c3 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java @@ -40,6 +40,17 @@ public class VectorUtils { } /** + * Creates a new vector of the specified size n with every element set to the specified value. + * + * @param val Value. + * @param n Size. + * @return New vector of the specified size n with every element set to the specified value. + */ + public static DenseVector fill(double val, int n) { + return (DenseVector)new DenseVector(n).assign(val); + } + + /** * Turn number into a local Vector of given size with one-hot encoding. * * @param num Number to turn into vector.
http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java index 7426506..f265318 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java @@ -115,7 +115,8 @@ public class OneVsRestTrainer<M extends Model<Vector, Double>> List<Double> res = new ArrayList<>(); try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build( - (upstream, upstreamSize) -> new EmptyContext(), + envBuilder, + (env, upstream, upstreamSize) -> new EmptyContext(), partDataBuilder )) { final Set<Double> clsLabels = dataset.compute(data -> { http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java index 7ee423d..cdaac5a 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java @@ -24,6 +24,7 @@ import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.UpstreamEntry; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import 
org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; @@ -59,14 +60,20 @@ public class GaussianNaiveBayesTrainer extends SingleLabelDatasetTrainer<Gaussia } /** {@inheritDoc} */ + @Override public GaussianNaiveBayesTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { + return (GaussianNaiveBayesTrainer)super.withEnvironmentBuilder(envBuilder); + } + + /** {@inheritDoc} */ @Override protected <K, V> GaussianNaiveBayesModel updateModel(GaussianNaiveBayesModel mdl, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { assert datasetBuilder != null; try (Dataset<EmptyContext, GaussianNaiveBayesSumsHolder> dataset = datasetBuilder.build( - (upstream, upstreamSize) -> new EmptyContext(), - (upstream, upstreamSize, ctx) -> { + envBuilder, + (env, upstream, upstreamSize) -> new EmptyContext(), + (env, upstream, upstreamSize, ctx) -> { GaussianNaiveBayesSumsHolder res = new GaussianNaiveBayesSumsHolder(); while (upstream.hasNext()) { http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java index c75c5bb..ea0bb6c 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java @@ -124,6 +124,7 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer assert updatesStgy!= null; try (Dataset<EmptyContext, SimpleLabeledDatasetData> dataset = datasetBuilder.build( + envBuilder, new EmptyContextBuilder<>(), new SimpleLabeledDatasetDataBuilder<>(featureExtractor, lbExtractor) )) 
{ http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/pipeline/Pipeline.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/pipeline/Pipeline.java b/modules/ml/src/main/java/org/apache/ignite/ml/pipeline/Pipeline.java index 8bfcb34..1aeac6b 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/pipeline/Pipeline.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/pipeline/Pipeline.java @@ -26,6 +26,7 @@ import org.apache.ignite.ml.Model; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.preprocessing.PreprocessingTrainer; @@ -53,6 +54,9 @@ public class Pipeline<K, V, R> { /** Final trainer stage. */ private DatasetTrainer finalStage; + /** Learning environment builder. */ + private LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder(); + /** * Adds feature extractor as a zero stage. * @@ -110,6 +114,15 @@ public class Pipeline<K, V, R> { } /** + * Set learning environment builder. + * + * @param envBuilder Learning environment builder. + */ + public void setEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { + this.envBuilder = envBuilder; + } + + /** * Fits the pipeline to the input mock data. * * @param data Data. @@ -132,6 +145,7 @@ public class Pipeline<K, V, R> { preprocessors.forEach(e -> { finalFeatureExtractor = e.fit( + envBuilder, datasetBuilder, finalFeatureExtractor );
