http://git-wip-us.apache.org/repos/asf/ignite/blob/54bac750/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationPreprocessor.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationPreprocessor.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationPreprocessor.java
new file mode 100644
index 0000000..7c94b8f
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationPreprocessor.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.preprocessing.normalization;
+
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+
+/**
+ * Preprocessing function that makes normalization. From mathematical point of view it's the following function which
+ * is applied to every element in dataset:
+ *
+ * {@code a_i = (a_i - min_i) / (max_i - min_i) for all i},
+ *
+ * where {@code i} is a number of column, {@code max_i} is the value of the maximum element in this column,
+ * {@code min_i} is the value of the minimal element in this column. For a constant column ({@code max_i == min_i})
+ * the denominator is replaced with {@code 1}, so the column is mapped to {@code a_i - min_i} instead of {@code NaN}
+ * or an infinity. The row returned by the base preprocessor is never modified in place.
+ *
+ * @param <K> Type of a key in {@code upstream} data.
+ * @param <V> Type of a value in {@code upstream} data.
+ */
+public class NormalizationPreprocessor<K, V> implements IgniteBiFunction<K, V, double[]> {
+    /** */
+    private static final long serialVersionUID = 6997800576392623469L;
+
+    /** Minimal values. */
+    private final double[] min;
+
+    /** Maximum values. */
+    private final double[] max;
+
+    /** Base preprocessor. */
+    private final IgniteBiFunction<K, V, double[]> basePreprocessor;
+
+    /**
+     * Constructs a new instance of normalization preprocessor.
+     *
+     * @param min Minimal values.
+     * @param max Maximum values.
+     * @param basePreprocessor Base preprocessor.
+     */
+    public NormalizationPreprocessor(double[] min, double[] max, IgniteBiFunction<K, V, double[]> basePreprocessor) {
+        this.min = min;
+        this.max = max;
+        this.basePreprocessor = basePreprocessor;
+    }
+
+    /**
+     * Applies this preprocessor. The scaled values are written into a new array: the base preprocessor may return
+     * the upstream value itself (for example {@code (k, v) -> v}), and scaling that array in place would modify the
+     * upstream data.
+     *
+     * @param k Key.
+     * @param v Value.
+     * @return Preprocessed row (a newly allocated array).
+     */
+    @Override public double[] apply(K k, V v) {
+        double[] row = basePreprocessor.apply(k, v);
+
+        assert row.length == min.length;
+        assert row.length == max.length;
+
+        double[] res = new double[row.length];
+
+        for (int i = 0; i < row.length; i++) {
+            double range = max[i] - min[i];
+
+            // Zero range (constant column): divide by 1 instead of 0 to avoid NaN or an infinity.
+            res[i] = (row[i] - min[i]) / (range == 0 ? 1 : range);
+        }
+
+        return res;
+    }
+
+    /** */
+    public double[] getMin() {
+        return min;
+    }
+
+    /** */
+    public double[] getMax() {
+        return max;
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/54bac750/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java
new file mode 100644
index 0000000..16623ba
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.preprocessing.normalization;
+
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
+
+/**
+ * Trainer of the normalization preprocessor. It computes the minimal and maximal values of every column over the
+ * whole dataset and passes them to a new {@link NormalizationPreprocessor}.
+ *
+ * @param <K> Type of a key in {@code upstream} data.
+ * @param <V> Type of a value in {@code upstream} data.
+ */
+public class NormalizationTrainer<K, V> implements PreprocessingTrainer<K, V, double[], double[]> {
+    /**
+     * Computes column-wise minimal and maximal values over the whole dataset and wraps them into a
+     * {@link NormalizationPreprocessor}.
+     *
+     * @param datasetBuilder Dataset builder.
+     * @param basePreprocessor Base preprocessor that extracts a row of features from an upstream entry.
+     * @param cols Number of columns.
+     * @return Normalization preprocessor.
+     */
+    @Override public NormalizationPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, double[]> basePreprocessor, int cols) {
+        // Every partition computes the bounds of its own rows; the partial bounds are merged afterwards.
+        try (Dataset<EmptyContext, NormalizationPartitionData> dataset = datasetBuilder.build(
+            (upstream, upstreamSize) -> new EmptyContext(),
+            (upstream, upstreamSize, ctx) -> {
+                double[] colMin = new double[cols];
+                double[] colMax = new double[cols];
+
+                // Start from the opposite extremes of the double range.
+                for (int col = 0; col < cols; col++) {
+                    colMin[col] = Double.MAX_VALUE;
+                    colMax[col] = -Double.MAX_VALUE;
+                }
+
+                while (upstream.hasNext()) {
+                    UpstreamEntry<K, V> entry = upstream.next();
+                    double[] features = basePreprocessor.apply(entry.getKey(), entry.getValue());
+
+                    for (int col = 0; col < cols; col++) {
+                        double x = features[col];
+
+                        if (x < colMin[col])
+                            colMin[col] = x;
+
+                        if (x > colMax[col])
+                            colMax[col] = x;
+                    }
+                }
+
+                return new NormalizationPartitionData(colMin, colMax);
+            }
+        )) {
+            double[][] bounds = dataset.compute(
+                data -> new double[][] {data.getMin(), data.getMax()},
+                (left, right) -> {
+                    // A null side carries no bounds, so the other side is returned unchanged.
+                    if (left == null || right == null)
+                        return left == null ? right : left;
+
+                    double[] mergedMin = new double[left[0].length];
+                    for (int col = 0; col < mergedMin.length; col++)
+                        mergedMin[col] = Math.min(left[0][col], right[0][col]);
+
+                    double[] mergedMax = new double[left[1].length];
+                    for (int col = 0; col < mergedMax.length; col++)
+                        mergedMax[col] = Math.max(left[1][col], right[1][col]);
+
+                    return new double[][] {mergedMin, mergedMax};
+                }
+            );
+
+            return new NormalizationPreprocessor<>(bounds[0], bounds[1], basePreprocessor);
+        }
+        catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/54bac750/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/package-info.java
b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/package-info.java new file mode 100644 index 0000000..5c3146f --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains normalization preprocessor and its trainer. + */ +package org.apache.ignite.ml.preprocessing.normalization; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/54bac750/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/package-info.java new file mode 100644 index 0000000..ca04410 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Base package for machine learning preprocessing classes. + */ +package org.apache.ignite.ml.preprocessing; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/54bac750/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java index c42efc5..7102d6a 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java @@ -18,10 +18,12 @@ package org.apache.ignite.ml; import org.apache.ignite.ml.clustering.ClusteringTestSuite; +import org.apache.ignite.ml.dataset.DatasetTestSuite; import org.apache.ignite.ml.knn.KNNTestSuite; import org.apache.ignite.ml.math.MathImplMainTestSuite; import org.apache.ignite.ml.nn.MLPTestSuite; import org.apache.ignite.ml.optimization.OptimizationTestSuite; +import org.apache.ignite.ml.preprocessing.PreprocessingTestSuite; import org.apache.ignite.ml.regressions.RegressionsTestSuite; import org.apache.ignite.ml.svm.SVMTestSuite; import 
org.apache.ignite.ml.trainers.group.TrainersGroupTestSuite; @@ -43,7 +45,9 @@ import org.junit.runners.Suite; LocalModelsTest.class, MLPTestSuite.class, TrainersGroupTestSuite.class, - OptimizationTestSuite.class + OptimizationTestSuite.class, + DatasetTestSuite.class, + PreprocessingTestSuite.class }) public class IgniteMLTestSuite { // No-op. http://git-wip-us.apache.org/repos/asf/ignite/blob/54bac750/modules/ml/src/test/java/org/apache/ignite/ml/dataset/DatasetTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/DatasetTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/DatasetTestSuite.java new file mode 100644 index 0000000..3be79a4 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/DatasetTestSuite.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + package org.apache.ignite.ml.dataset; + + import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilderTest; + import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetTest; + import org.apache.ignite.ml.dataset.impl.cache.util.ComputeUtilsTest; + import org.apache.ignite.ml.dataset.impl.cache.util.DatasetAffinityFunctionWrapperTest; + import org.apache.ignite.ml.dataset.impl.cache.util.PartitionDataStorageTest; + import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilderTest; + import org.apache.ignite.ml.dataset.primitive.DatasetWrapperTest; + import org.junit.runner.RunWith; + import org.junit.runners.Suite; + + /** + * Test suite for all tests located in the {@code org.apache.ignite.ml.dataset.*} packages. + */ + @RunWith(Suite.class) + @Suite.SuiteClasses({ + DatasetWrapperTest.class, + ComputeUtilsTest.class, + DatasetAffinityFunctionWrapperTest.class, + PartitionDataStorageTest.class, + CacheBasedDatasetBuilderTest.class, + CacheBasedDatasetTest.class, + LocalDatasetBuilderTest.class + }) + public class DatasetTestSuite { + // No-op: the suite is fully defined by the annotations above. + } http://git-wip-us.apache.org/repos/asf/ignite/blob/54bac750/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilderTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilderTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilderTest.java new file mode 100644 index 0000000..c35cdc3 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilderTest.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + package org.apache.ignite.ml.dataset.impl.cache; + + import java.util.Collection; + import java.util.UUID; + import org.apache.ignite.Ignite; + import org.apache.ignite.IgniteCache; + import org.apache.ignite.cache.affinity.Affinity; + import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; + import org.apache.ignite.cluster.ClusterNode; + import org.apache.ignite.configuration.CacheConfiguration; + import org.apache.ignite.internal.util.IgniteUtils; + import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; + + /** + * Tests for {@link CacheBasedDatasetBuilder}. + */ + public class CacheBasedDatasetBuilderTest extends GridCommonAbstractTest { + /** Number of nodes in grid. */ + private static final int NODE_COUNT = 10; + + /** Ignite instance the tests run on (the last started node). */ + private Ignite ignite; + + /** {@inheritDoc} */ + @Override protected void beforeTestsStarted() throws Exception { + for (int i = 1; i <= NODE_COUNT; i++) + startGrid(i); + } + + /** {@inheritDoc} */ + @Override protected void afterTestsStopped() { + stopAllGrids(); + } + + /** {@inheritDoc} */ + @Override protected void beforeTest() throws Exception { + /* Grid instance. NOTE(review): peer class loading is switched on below for a node that is already started; confirm that it has any effect.
*/ + ignite = grid(NODE_COUNT); + ignite.configuration().setPeerClassLoadingEnabled(true); + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + } + + /** + * Tests that partitions of the dataset cache are placed on the same primary and backup nodes as upstream cache. + */ + public void testBuild() { + IgniteCache<Integer, String> upstreamCache = createTestCache(100, 10); + CacheBasedDatasetBuilder<Integer, String> builder = new CacheBasedDatasetBuilder<>(ignite, upstreamCache); + + CacheBasedDataset<Integer, String, Long, AutoCloseable> dataset = builder.build( + (upstream, upstreamSize) -> upstreamSize, + (upstream, upstreamSize, ctx) -> null + ); + + Affinity<Integer> upstreamAffinity = ignite.affinity(upstreamCache.getName()); + Affinity<Integer> datasetAffinity = ignite.affinity(dataset.getDatasetCache().getName()); + + int upstreamPartitions = upstreamAffinity.partitions(); + int datasetPartitions = datasetAffinity.partitions(); + + assertEquals(upstreamPartitions, datasetPartitions); + + for (int part = 0; part < upstreamPartitions; part++) { + Collection<ClusterNode> upstreamPartNodes = upstreamAffinity.mapPartitionToPrimaryAndBackups(part); + Collection<ClusterNode> datasetPartNodes = datasetAffinity.mapPartitionToPrimaryAndBackups(part); + + assertEqualsCollections(upstreamPartNodes, datasetPartNodes); + } + } + + /** + * Generates an Ignite Cache with the specified size and number of partitions for testing purposes. + * + * @param size Number of entries ({@code i -> "DATA_" + i}) to put into the cache. + * @param parts Number of partitions. + * @return Ignite Cache instance.
+ */ + private IgniteCache<Integer, String> createTestCache(int size, int parts) { + CacheConfiguration<Integer, String> cacheConfiguration = new CacheConfiguration<>(); + cacheConfiguration.setName(UUID.randomUUID().toString()); + cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, parts)); + + IgniteCache<Integer, String> cache = ignite.createCache(cacheConfiguration); + + for (int i = 0; i < size; i++) + cache.put(i, "DATA_" + i); + + return cache; + } + } http://git-wip-us.apache.org/repos/asf/ignite/blob/54bac750/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java new file mode 100644 index 0000000..f9ecb0b --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + + package org.apache.ignite.ml.dataset.impl.cache; + + import java.util.ArrayList; + import java.util.HashMap; + import java.util.List; + import java.util.Map; + import java.util.UUID; + import java.util.concurrent.locks.LockSupport; + import org.apache.ignite.Ignite; + import org.apache.ignite.IgniteAtomicLong; + import org.apache.ignite.IgniteCache; + import org.apache.ignite.IgniteLock; + import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; + import org.apache.ignite.configuration.CacheConfiguration; + import org.apache.ignite.internal.IgniteKernal; + import org.apache.ignite.internal.processors.affinity.AffinityTopologyVersion; + import org.apache.ignite.internal.processors.cache.IgniteCacheProxy; + import org.apache.ignite.internal.processors.cache.distributed.dht.GridDhtCacheAdapter; + import org.apache.ignite.internal.processors.cache.distributed.dht.GridDhtLocalPartition; + import org.apache.ignite.internal.processors.cache.distributed.dht.GridDhtPartitionTopology; + import org.apache.ignite.internal.util.IgniteUtils; + import org.apache.ignite.internal.util.typedef.G; + import org.apache.ignite.lang.IgnitePredicate; + import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; + + /** + * Tests for {@link CacheBasedDataset}. + */ + public class CacheBasedDatasetTest extends GridCommonAbstractTest { + /** Number of nodes in grid. */ + private static final int NODE_COUNT = 4; + + /** Ignite instance the tests run on (the last started node). */ + private Ignite ignite; + + /** {@inheritDoc} */ + @Override protected void beforeTestsStarted() throws Exception { + for (int i = 1; i <= NODE_COUNT; i++) + startGrid(i); + } + + /** {@inheritDoc} */ + @Override protected void afterTestsStopped() { + stopAllGrids(); + } + + /** {@inheritDoc} */ + @Override protected void beforeTest() throws Exception { + /* Grid instance. NOTE(review): peer class loading is switched on below for a node that is already started; confirm that it has any effect.
*/ + ignite = grid(NODE_COUNT); + ignite.configuration().setPeerClassLoadingEnabled(true); + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + } + + /** + * Tests that partitions of the upstream cache and the partition {@code context} cache are reserved during + * {@code compute} calls on the dataset. Reservation means that partitions won't be unloaded from the node before computation is + * completed. + */ + public void testPartitionExchangeDuringComputeCall() { + int partitions = 4; + + IgniteCache<Integer, String> upstreamCache = generateTestData(4, 0); + + CacheBasedDatasetBuilder<Integer, String> builder = new CacheBasedDatasetBuilder<>(ignite, upstreamCache); + + CacheBasedDataset<Integer, String, Long, AutoCloseable> dataset = builder.build( + (upstream, upstreamSize) -> upstreamSize, + (upstream, upstreamSize, ctx) -> null + ); + + assertTrue("Before computation all partitions should not be reserved", + areAllPartitionsNotReserved(upstreamCache.getName(), dataset.getDatasetCache().getName())); + + UUID numOfStartedComputationsId = UUID.randomUUID(); + IgniteAtomicLong numOfStartedComputations = ignite.atomicLong(numOfStartedComputationsId.toString(), 0, true); + + UUID computationsLockId = UUID.randomUUID(); + IgniteLock computationsLock = ignite.reentrantLock(computationsLockId.toString(), false, true, true); + + // Hold the distributed lock, so every computation blocks right after it has been counted. + computationsLock.lock(); + + try { + new Thread(() -> dataset.compute((data, partIndex) -> { + // Count this computation as started, then block until the test releases the lock. + ignite.atomicLong(numOfStartedComputationsId.toString(), 0, false).incrementAndGet(); + ignite.reentrantLock(computationsLockId.toString(), false, true, false).lock(); + ignite.reentrantLock(computationsLockId.toString(), false, true, false).unlock(); + })).start(); + // Wait for all computations to start. NOTE(review): busy-wait without a timeout, the test hangs if fewer than `partitions` computations ever start.
+ + while (numOfStartedComputations.get() < partitions) { + } + + assertTrue("During computation all partitions should be
reserved", + areAllPartitionsReserved(upstreamCache.getName(), dataset.getDatasetCache().getName())); + } + finally { + computationsLock.unlock(); + } + + assertTrue("All partitions should be released", + areAllPartitionsNotReserved(upstreamCache.getName(), dataset.getDatasetCache().getName())); + } + + /** + * Tests that partitions of the upstream cache and the partition {@code context} cache are reserved during + * {@code computeWithCtx} calls on the dataset. Reservation means that partitions won't be unloaded from the node before computation is + * completed. + */ + public void testPartitionExchangeDuringComputeWithCtxCall() { + int partitions = 4; + + IgniteCache<Integer, String> upstreamCache = generateTestData(4, 0); + + CacheBasedDatasetBuilder<Integer, String> builder = new CacheBasedDatasetBuilder<>(ignite, upstreamCache); + + CacheBasedDataset<Integer, String, Long, AutoCloseable> dataset = builder.build( + (upstream, upstreamSize) -> upstreamSize, + (upstream, upstreamSize, ctx) -> null + ); + + assertTrue("Before computation all partitions should not be reserved", + areAllPartitionsNotReserved(upstreamCache.getName(), dataset.getDatasetCache().getName())); + + UUID numOfStartedComputationsId = UUID.randomUUID(); + IgniteAtomicLong numOfStartedComputations = ignite.atomicLong(numOfStartedComputationsId.toString(), 0, true); + + UUID computationsLockId = UUID.randomUUID(); + IgniteLock computationsLock = ignite.reentrantLock(computationsLockId.toString(), false, true, true); + + // Hold the distributed lock, so every computation blocks right after it has been counted. + computationsLock.lock(); + + try { + new Thread(() -> dataset.computeWithCtx((ctx, data, partIndex) -> { + // Count this computation as started, then block until the test releases the lock. 
+ ignite.atomicLong(numOfStartedComputationsId.toString(), 0, false).incrementAndGet(); + ignite.reentrantLock(computationsLockId.toString(), false, true, false).lock(); + ignite.reentrantLock(computationsLockId.toString(), false, true, false).unlock(); + })).start(); + // Wait for all computations
to start. NOTE(review): busy-wait without a timeout, the test hangs if fewer than `partitions` computations ever start. + + while (numOfStartedComputations.get() < partitions) { + } + + assertTrue("During computation all partitions should be reserved", + areAllPartitionsReserved(upstreamCache.getName(), dataset.getDatasetCache().getName())); + } + finally { + computationsLock.unlock(); + } + + assertTrue("All partitions should be released", + areAllPartitionsNotReserved(upstreamCache.getName(), dataset.getDatasetCache().getName())); + } + + /** + * Checks that all partitions of all specified caches are not reserved (have zero reservations). + * + * @param cacheNames Cache names to be checked. + * @return {@code true} if all partitions are not reserved, otherwise {@code false}. + */ + private boolean areAllPartitionsNotReserved(String... cacheNames) { + return checkAllPartitions(partition -> partition.reservations() == 0, cacheNames); + } + + /** + * Checks that all partitions of all specified caches are reserved (have a non-zero number of reservations). + * + * @param cacheNames Cache names to be checked. + * @return {@code true} if all partitions are reserved, otherwise {@code false}. + */ + private boolean areAllPartitionsReserved(String... cacheNames) { + return checkAllPartitions(partition -> partition.reservations() != 0, cacheNames); + } + + /** + * Checks that all local partitions of all specified caches on every node satisfy the given predicate, re-checking every 200 ms for up to 30 seconds. + * + * @param pred Predicate. + * @param cacheNames Cache names. + * @return {@code true} if all partitions satisfy the given predicate before the timeout, otherwise {@code false}. + */ + private boolean checkAllPartitions(IgnitePredicate<GridDhtLocalPartition> pred, String...
cacheNames) { + boolean flag = false; + long checkingStartTs = System.currentTimeMillis(); + + while (!flag && (System.currentTimeMillis() - checkingStartTs) < 30_000) { + LockSupport.parkNanos(200 * 1000 * 1000); + flag = true; + + for (String cacheName : cacheNames) { + IgniteClusterPartitionsState state = IgniteClusterPartitionsState.getCurrentState(cacheName); + + for (IgniteInstancePartitionsState instanceState : state.instances.values()) + for (GridDhtLocalPartition partition : instanceState.parts) + if (partition != null) + flag &= pred.apply(partition); + } + } + + return flag; + } + + /** + * Aggregated data about cache partitions in Ignite cluster. + */ + private static class IgniteClusterPartitionsState { + /** Cache name. */ + private final String cacheName; + + /** Partitions state of every node, keyed by node ID. */ + private final Map<UUID, IgniteInstancePartitionsState> instances; + + /** Collects the local partitions of the given cache on every started node. */ + static IgniteClusterPartitionsState getCurrentState(String cacheName) { + Map<UUID, IgniteInstancePartitionsState> instances = new HashMap<>(); + + for (Ignite ignite : G.allGrids()) { + IgniteKernal igniteKernal = (IgniteKernal)ignite; + IgniteCacheProxy<?, ?> cache = igniteKernal.context().cache().jcache(cacheName); + + GridDhtCacheAdapter<?, ?> dht = dht(cache); + + GridDhtPartitionTopology top = dht.topology(); + + AffinityTopologyVersion topVer = dht.context().shared().exchange().readyAffinityVersion(); + List<GridDhtLocalPartition> parts = new ArrayList<>(); + for (int p = 0; p < cache.context().config().getAffinity().partitions(); p++) { + GridDhtLocalPartition part = top.localPartition(p, AffinityTopologyVersion.NONE, false); + parts.add(part); + } + instances.put(ignite.cluster().localNode().id(), new IgniteInstancePartitionsState(topVer, parts)); + } + + return new IgniteClusterPartitionsState(cacheName, instances); + } + + /** Constructs a new instance of cluster partitions state. 
*/ + IgniteClusterPartitionsState(String cacheName, + Map<UUID, IgniteInstancePartitionsState> instances) { + this.cacheName = cacheName; + this.instances = instances; + } + + /** Renders the partitions state of every node as a text table. */
@Override public String toString() { + StringBuilder builder = new StringBuilder(); + builder.append("Cache ").append(cacheName).append(" is in following state:").append("\n"); + for (Map.Entry<UUID, IgniteInstancePartitionsState> e : instances.entrySet()) { + UUID instanceId = e.getKey(); + IgniteInstancePartitionsState instanceState = e.getValue(); + builder.append("\n\t") + .append("Node ") + .append(instanceId) + .append(" with topology version [") + .append(instanceState.topVer.topologyVersion()) + .append(", ") + .append(instanceState.topVer.minorTopologyVersion()) + .append("] contains following partitions:") + .append("\n\n"); + builder.append("\t\t---------------------------------------------------------------------------------"); + builder.append("--------------------\n"); + builder.append("\t\t| ID | STATE | RELOAD | RESERVATIONS | SHOULD BE RENTING | PRIMARY |"); + builder.append(" DATA STORE SIZE |\n"); + builder.append("\t\t---------------------------------------------------------------------------------"); + builder.append("--------------------\n"); + for (GridDhtLocalPartition partition : instanceState.parts) + if (partition != null) { + builder.append("\t\t") + .append(String.format("| %3d |", partition.id())) + .append(String.format(" %7s |", partition.state())) + .append(String.format(" %7s |", partition.reload())) + .append(String.format(" %13s |", partition.reservations())) + .append(String.format(" %18s |", partition.shouldBeRenting())) + .append(String.format(" %8s |", partition.primary(instanceState.topVer))) + .append(String.format(" %16d |", partition.dataStore().fullSize())) + .append("\n"); + builder.append("\t\t-------------------------------------------------------------------------"); + builder.append("----------------------------\n"); + } + } + return builder.toString(); + } + } + + /** + * Aggregated data about cache partitions in Ignite instance.
+ */ + private static class IgniteInstancePartitionsState { + /** Ready affinity topology version of the node. */ + private final AffinityTopologyVersion topVer; + + /** Local partitions indexed by partition ID (elements may be {@code null}). */ + private final List<GridDhtLocalPartition> parts; + + /** Constructs a new instance of node partitions state. */ + IgniteInstancePartitionsState(AffinityTopologyVersion topVer, + List<GridDhtLocalPartition> parts) { + this.topVer = topVer; + this.parts = parts; + } + + /** @return Ready affinity topology version of the node. */ + public AffinityTopologyVersion getTopVer() { + return topVer; + } + + /** @return Local partitions indexed by partition ID (elements may be {@code null}). */ + public List<GridDhtLocalPartition> getParts() { + return parts; + } + } + + /** + * Generates an Ignite Cache with the given number of partitions and backups, filled with 1000 entries {@code i -> "TEST" + i}. + * + * @return Ignite Cache with data for tests. + */ + private IgniteCache<Integer, String> generateTestData(int partitions, int backups) { + CacheConfiguration<Integer, String> cacheConfiguration = new CacheConfiguration<>(); + + cacheConfiguration.setName(UUID.randomUUID().toString()); + cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, partitions)); + cacheConfiguration.setBackups(backups); + + IgniteCache<Integer, String> cache = ignite.createCache(cacheConfiguration); + + for (int i = 0; i < 1000; i++) + cache.put(i, "TEST" + i); + + return cache; + } + } http://git-wip-us.apache.org/repos/asf/ignite/blob/54bac750/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtilsTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtilsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtilsTest.java new file mode 100644 index 0000000..4926a90 --- /dev/null +++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtilsTest.java @@ -0,0 +1,309 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.dataset.impl.cache.util; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.UUID; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteAtomicLong; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.affinity.AffinityFunction; +import org.apache.ignite.cache.affinity.AffinityFunctionContext; +import org.apache.ignite.cluster.ClusterNode; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.dataset.UpstreamEntry; +import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; + +/** + * Tests for {@link ComputeUtils}. + */ +public class ComputeUtilsTest extends GridCommonAbstractTest { + /** Number of nodes in grid. */ + private static final int NODE_COUNT = 10; + + /** Ignite instance. */ + private Ignite ignite; + + /** {@inheritDoc} */ + @Override protected void beforeTestsStarted() throws Exception { + for (int i = 1; i <= NODE_COUNT; i++) + startGrid(i); + } + + /** {@inheritDoc} */ + @Override protected void afterTestsStopped() { + stopAllGrids(); + } + + /** {@inheritDoc} */ + @Override protected void beforeTest() throws Exception { + /* Grid instance.
*/ + ignite = grid(NODE_COUNT); + ignite.configuration().setPeerClassLoadingEnabled(true); + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + } + + /** + * Tests that in case two caches maintain their partitions on different nodes, affinity call won't be completed. + */ + public void testAffinityCallWithRetriesNegative() { + ClusterNode node1 = grid(1).cluster().localNode(); + ClusterNode node2 = grid(2).cluster().localNode(); + + String firstCacheName = "CACHE_1_" + UUID.randomUUID(); + String secondCacheName = "CACHE_2_" + UUID.randomUUID(); + + CacheConfiguration<Integer, Integer> cacheConfiguration1 = new CacheConfiguration<>(); + cacheConfiguration1.setName(firstCacheName); + cacheConfiguration1.setAffinity(new TestAffinityFunction(node1)); + IgniteCache<Integer, Integer> cache1 = ignite.createCache(cacheConfiguration1); + + CacheConfiguration<Integer, Integer> cacheConfiguration2 = new CacheConfiguration<>(); + cacheConfiguration2.setName(secondCacheName); + cacheConfiguration2.setAffinity(new TestAffinityFunction(node2)); + IgniteCache<Integer, Integer> cache2 = ignite.createCache(cacheConfiguration2); + + try { + try { + ComputeUtils.affinityCallWithRetries( + ignite, + Arrays.asList(firstCacheName, secondCacheName), + part -> part, + 0 + ); + } + catch (IllegalStateException expectedException) { + return; + } + + fail("Missing IllegalStateException"); + } + finally { + cache1.destroy(); + cache2.destroy(); + } + } + + /** + * Test that in case two caches maintain their partitions on the same node, affinity call will be completed.
*/ + public void testAffinityCallWithRetriesPositive() { + ClusterNode node = grid(1).cluster().localNode(); + + String firstCacheName = "CACHE_1_" + UUID.randomUUID(); + String secondCacheName = "CACHE_2_" + UUID.randomUUID(); + + CacheConfiguration<Integer, Integer> cacheConfiguration1 = new CacheConfiguration<>(); + cacheConfiguration1.setName(firstCacheName); + cacheConfiguration1.setAffinity(new TestAffinityFunction(node)); + IgniteCache<Integer, Integer> cache1 = ignite.createCache(cacheConfiguration1); + + CacheConfiguration<Integer, Integer> cacheConfiguration2 = new CacheConfiguration<>(); + cacheConfiguration2.setName(secondCacheName); + cacheConfiguration2.setAffinity(new TestAffinityFunction(node)); + IgniteCache<Integer, Integer> cache2 = ignite.createCache(cacheConfiguration2); + + try (IgniteAtomicLong cnt = ignite.atomicLong("COUNTER_" + UUID.randomUUID(), 0, true)) { + + ComputeUtils.affinityCallWithRetries(ignite, Arrays.asList(firstCacheName, secondCacheName), part -> { + Ignite locIgnite = Ignition.localIgnite(); + + assertEquals(node, locIgnite.cluster().localNode()); + + cnt.incrementAndGet(); + + return part; + }, 0); + + assertEquals(1, cnt.get()); + } + finally { + cache1.destroy(); + cache2.destroy(); + } + } + + /** + * Tests that {@code getData()} builds partition data only once and reuses it on subsequent calls.
*/ + public void testGetData() { + ClusterNode node = grid(1).cluster().localNode(); + + String upstreamCacheName = "CACHE_1_" + UUID.randomUUID(); + String datasetCacheName = "CACHE_2_" + UUID.randomUUID(); + + CacheConfiguration<Integer, Integer> upstreamCacheConfiguration = new CacheConfiguration<>(); + upstreamCacheConfiguration.setName(upstreamCacheName); + upstreamCacheConfiguration.setAffinity(new TestAffinityFunction(node)); + IgniteCache<Integer, Integer> upstreamCache = ignite.createCache(upstreamCacheConfiguration); + + CacheConfiguration<Integer, Integer> datasetCacheConfiguration = new CacheConfiguration<>(); + datasetCacheConfiguration.setName(datasetCacheName); + datasetCacheConfiguration.setAffinity(new TestAffinityFunction(node)); + IgniteCache<Integer, Integer> datasetCache = ignite.createCache(datasetCacheConfiguration); + + UUID datasetId = UUID.randomUUID(); + + /* Grids are shared by all tests of this class, so the counter and both caches are released when done. */ + try (IgniteAtomicLong cnt = ignite.atomicLong("CNT_" + datasetId, 0, true)) { + upstreamCache.put(42, 42); + datasetCache.put(0, 0); + + for (int i = 0; i < 10; i++) { + Collection<TestPartitionData> data = ComputeUtils.affinityCallWithRetries( + ignite, + Arrays.asList(datasetCacheName, upstreamCacheName), + part -> ComputeUtils.<Integer, Integer, Serializable, TestPartitionData>getData( + ignite, + upstreamCacheName, + datasetCacheName, + datasetId, + 0, + (upstream, upstreamSize, ctx) -> { + cnt.incrementAndGet(); + + assertEquals(1, upstreamSize); + + UpstreamEntry<Integer, Integer> e = upstream.next(); + return new TestPartitionData(e.getKey() + e.getValue()); + } + ), + 0 + ); + + assertEquals(1, data.size()); + + TestPartitionData dataElement = data.iterator().next(); + assertEquals(84, dataElement.val.intValue()); + } + + /* Partition data must be built by the first call only. */ + assertEquals(1, cnt.get()); + } + finally { + upstreamCache.destroy(); + datasetCache.destroy(); + } + } + + /** + * Tests that {@code initContext()} builds the context from upstream data and puts it into the dataset cache.
*/ + public void testInitContext() { + ClusterNode node = grid(1).cluster().localNode(); + + String upstreamCacheName = "CACHE_1_" + UUID.randomUUID(); + String datasetCacheName = "CACHE_2_" + UUID.randomUUID(); + + CacheConfiguration<Integer, Integer> upstreamCacheConfiguration = new CacheConfiguration<>(); + upstreamCacheConfiguration.setName(upstreamCacheName); + upstreamCacheConfiguration.setAffinity(new TestAffinityFunction(node)); + IgniteCache<Integer, Integer> upstreamCache = ignite.createCache(upstreamCacheConfiguration); + + CacheConfiguration<Integer, Integer> datasetCacheConfiguration = new CacheConfiguration<>(); + datasetCacheConfiguration.setName(datasetCacheName); + datasetCacheConfiguration.setAffinity(new TestAffinityFunction(node)); + IgniteCache<Integer, Integer> datasetCache = ignite.createCache(datasetCacheConfiguration); + + /* Grids are shared by all tests of this class, so both caches are destroyed when done. */ + try { + upstreamCache.put(42, 42); + + ComputeUtils.<Integer, Integer, Integer>initContext( + ignite, + upstreamCacheName, + datasetCacheName, + (upstream, upstreamSize) -> { + + assertEquals(1, upstreamSize); + + UpstreamEntry<Integer, Integer> e = upstream.next(); + return e.getKey() + e.getValue(); + }, + 0 + ); + + assertEquals(1, datasetCache.size()); + assertEquals(84, datasetCache.get(0).intValue()); + } + finally { + upstreamCache.destroy(); + datasetCache.destroy(); + } + } + + /** + * Test partition data. + */ + private static class TestPartitionData implements AutoCloseable { + /** Value. */ + private final Integer val; + + /** + * Constructs a new instance of test partition data. + * + * @param val Value. + */ + TestPartitionData(Integer val) { + this.val = val; + } + + /** {@inheritDoc} */ + @Override public void close() throws Exception { + // Do nothing, GC will clean up. + } + } + + /** + * Affinity function used in tests in this class. Defines a single partition and assigns it to the specified cluster node.
*/ + private static class TestAffinityFunction implements AffinityFunction { + /** */ + private static final long serialVersionUID = -1353725303983563094L; + + /** Cluster node the only partition is assigned to. */ + private final ClusterNode node; + + /** + * Constructs a new instance of test affinity function. + * + * @param node Cluster node the only partition is assigned to. + */ + TestAffinityFunction(ClusterNode node) { + this.node = node; + } + + /** {@inheritDoc} */ + @Override public void reset() { + // Do nothing. + } + + /** {@inheritDoc} */ + @Override public int partitions() { + return 1; + } + + /** {@inheritDoc} */ + @Override public int partition(Object key) { + return 0; + } + + /** {@inheritDoc} */ + @Override public List<List<ClusterNode>> assignPartitions(AffinityFunctionContext affCtx) { + return Collections.singletonList(Collections.singletonList(node)); + } + + /** {@inheritDoc} */ + @Override public void removeNode(UUID nodeId) { + // Do nothing. + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/54bac750/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/DatasetAffinityFunctionWrapperTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/DatasetAffinityFunctionWrapperTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/DatasetAffinityFunctionWrapperTest.java new file mode 100644 index 0000000..2628aa6 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/DatasetAffinityFunctionWrapperTest.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.dataset.impl.cache.util; + +import java.util.Collections; +import java.util.List; +import java.util.UUID; +import org.apache.ignite.cache.affinity.AffinityFunction; +import org.apache.ignite.cache.affinity.AffinityFunctionContext; +import org.apache.ignite.cluster.ClusterNode; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link DatasetAffinityFunctionWrapper}. + */ +@RunWith(MockitoJUnitRunner.class) +public class DatasetAffinityFunctionWrapperTest { + /** Mocked affinity function. */ + @Mock + private AffinityFunction affinityFunction; + + /** Wrapper under test, built around the mocked affinity function. */ + private DatasetAffinityFunctionWrapper wrapper; + + /** Creates a fresh wrapper around the mock before each test. */ + @Before + public void beforeTest() { + wrapper = new DatasetAffinityFunctionWrapper(affinityFunction); + } + + /** Tests that {@code reset()} is delegated to the wrapped function. */ + @Test + public void testReset() { + wrapper.reset(); + + verify(affinityFunction, times(1)).reset(); + } + + /** Tests that {@code partitions()} returns the value of the wrapped function.
*/ + @Test + public void testPartitions() { + doReturn(42).when(affinityFunction).partitions(); + + int partitions = wrapper.partitions(); + + assertEquals(42, partitions); + verify(affinityFunction, times(1)).partitions(); + } + + /** Tests that {@code partition()} maps a key to itself without consulting the wrapped function (which would answer 0). */ + @Test + public void testPartition() { + doReturn(0).when(affinityFunction).partition(eq(42)); + + int part = wrapper.partition(42); + + assertEquals(42, part); + verify(affinityFunction, times(0)).partition(any()); + } + + /** Tests that {@code assignPartitions()} returns the assignment of the wrapped function. */ + @Test + public void testAssignPartitions() { + List<List<ClusterNode>> nodes = Collections.singletonList(Collections.singletonList(mock(ClusterNode.class))); + + doReturn(nodes).when(affinityFunction).assignPartitions(any()); + + List<List<ClusterNode>> resNodes = wrapper.assignPartitions(mock(AffinityFunctionContext.class)); + + assertEquals(nodes, resNodes); + verify(affinityFunction, times(1)).assignPartitions(any()); + } + + /** Tests that {@code removeNode()} is delegated to the wrapped function. */ + @Test + public void testRemoveNode() { + UUID nodeId = UUID.randomUUID(); + + wrapper.removeNode(nodeId); + + verify(affinityFunction, times(1)).removeNode(eq(nodeId)); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/54bac750/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/PartitionDataStorageTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/PartitionDataStorageTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/PartitionDataStorageTest.java new file mode 100644 index 0000000..eab2be1 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/PartitionDataStorageTest.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.dataset.impl.cache.util; + +import java.util.concurrent.atomic.AtomicLong; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link PartitionDataStorage}. + */ +public class PartitionDataStorageTest { + /** Data storage under test. */ + private PartitionDataStorage dataStorage = new PartitionDataStorage(); + + /** Tests that {@code computeDataIfAbsent()} invokes the supplier only once and then returns the stored value.
*/ + @Test + public void testComputeDataIfAbsent() { + AtomicLong cnt = new AtomicLong(); + + for (int i = 0; i < 10; i++) { + Integer res = (Integer) dataStorage.computeDataIfAbsent(0, () -> { + cnt.incrementAndGet(); + + return 42; + }); + + assertEquals(42, res.intValue()); + } + + assertEquals(1, cnt.intValue()); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/54bac750/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilderTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilderTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilderTest.java new file mode 100644 index 0000000..0628580 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilderTest.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.ignite.ml.dataset.impl.local; + +import java.io.Serializable; +import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.atomic.AtomicLong; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link LocalDatasetBuilder}. + */ +public class LocalDatasetBuilderTest { + /** Tests {@code build()} method on a sorted upstream map, so that the expected partition contents are deterministic. */ + @Test + public void testBuild() { + Map<Integer, Integer> data = new TreeMap<>(); + for (int i = 0; i < 100; i++) + data.put(i, i); + + LocalDatasetBuilder<Integer, Integer> builder = new LocalDatasetBuilder<>(data, 10); + + LocalDataset<Serializable, TestPartitionData> dataset = builder.build( + (upstream, upstreamSize) -> null, + (upstream, upstreamSize, ctx) -> { + int[] arr = new int[Math.toIntExact(upstreamSize)]; + + int ptr = 0; + while (upstream.hasNext()) + arr[ptr++] = upstream.next().getValue(); + + return new TestPartitionData(arr); + } + ); + + AtomicLong cnt = new AtomicLong(); + + dataset.compute((partData, partIdx) -> { + cnt.incrementAndGet(); + + int[] arr = partData.data; + + assertEquals(10, arr.length); + + for (int i = 0; i < 10; i++) + assertEquals(partIdx * 10 + i, arr[i]); + }); + + assertEquals(10, cnt.intValue()); + } + + /** + * Test partition data holding the upstream values of one partition. + */ + private static class TestPartitionData implements AutoCloseable { + /** Data. */ + private final int[] data; + + /** + * Constructs a new test partition data instance. + * + * @param data Data. + */ + TestPartitionData(int[] data) { + this.data = data; + } + + /** {@inheritDoc} */ + @Override public void close() throws Exception { + // Do nothing, GC will clean up.
+ } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/54bac750/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/DatasetWrapperTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/DatasetWrapperTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/DatasetWrapperTest.java new file mode 100644 index 0000000..b42b604 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/DatasetWrapperTest.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.ignite.ml.dataset.primitive; + +import java.io.Serializable; +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteBinaryOperator; +import org.apache.ignite.ml.math.functions.IgniteTriFunction; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link DatasetWrapper}. + */ +@RunWith(MockitoJUnitRunner.class) +public class DatasetWrapperTest { + /** Mocked dataset. */ + @Mock + private Dataset<Serializable, AutoCloseable> dataset; + + /** Dataset wrapper under test, built around the mocked dataset. */ + private DatasetWrapper<Serializable, AutoCloseable> wrapper; + + /** Wraps the mocked dataset before each test. */ + @Before + public void beforeTest() { + wrapper = new DatasetWrapper<>(dataset); + } + + /** Tests that {@code computeWithCtx()} is delegated to the wrapped dataset and its result is returned. */ + @Test + @SuppressWarnings("unchecked") + public void testComputeWithCtx() { + doReturn(42).when(dataset).computeWithCtx(any(IgniteTriFunction.class), any(), any()); + + Integer res = wrapper.computeWithCtx(mock(IgniteTriFunction.class), mock(IgniteBinaryOperator.class), null); + + assertEquals(42, res.intValue()); + verify(dataset, times(1)).computeWithCtx(any(IgniteTriFunction.class), any(), any()); + } + + /** Tests that {@code compute()} is delegated to the wrapped dataset and its result is returned.
*/ + @Test + @SuppressWarnings("unchecked") + public void testCompute() { + doReturn(42).when(dataset).compute(any(IgniteBiFunction.class), any(), any()); + + Integer res = wrapper.compute(mock(IgniteBiFunction.class), mock(IgniteBinaryOperator.class), null); + + assertEquals(42, res.intValue()); + verify(dataset, times(1)).compute(any(IgniteBiFunction.class), any(), any()); + } + + /** Tests that {@code close()} closes the wrapped dataset. */ + @Test + public void testClose() throws Exception { + wrapper.close(); + + verify(dataset, times(1)).close(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/54bac750/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/PreprocessingTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/PreprocessingTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/PreprocessingTestSuite.java new file mode 100644 index 0000000..1b25908 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/PreprocessingTestSuite.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.ignite.ml.preprocessing; + +import org.apache.ignite.ml.preprocessing.normalization.NormalizationPreprocessorTest; +import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainerTest; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; + +/** + * Test suite for all tests located in the {@code org.apache.ignite.ml.preprocessing.*} packages. + */ +@RunWith(Suite.class) +@Suite.SuiteClasses({ + NormalizationPreprocessorTest.class, + NormalizationTrainerTest.class +}) +public class PreprocessingTestSuite { + // No-op. +} http://git-wip-us.apache.org/repos/asf/ignite/blob/54bac750/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationPreprocessorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationPreprocessorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationPreprocessorTest.java new file mode 100644 index 0000000..c9eb765 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationPreprocessorTest.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.preprocessing.normalization; + +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; + +/** + * Tests for {@link NormalizationPreprocessor}. + */ +public class NormalizationPreprocessorTest { + /** Tests that {@code apply()} maps every column value to {@code (x - min) / (max - min)}. */ + @Test + public void testApply() { + double[][] data = new double[][]{ + {2., 4., 1.}, + {1., 8., 22.}, + {4., 10., 100.}, + {0., 22., 300.} + }; + + NormalizationPreprocessor<Integer, double[]> preprocessor = new NormalizationPreprocessor<>( + new double[] {0, 4, 1}, + new double[] {4, 22, 300}, + (k, v) -> v + ); + + double[][] expectedData = new double[][]{ + {2. / 4, (4. - 4.) / 18., 0.}, + {1. / 4, (8. - 4.) / 18., (22. - 1.) / 299.}, + {1., (10. - 4.) / 18., (100. - 1.) / 299.}, + {0., (22. - 4.) / 18., (300. - 1.) / 299.} + }; + + for (int i = 0; i < data.length; i++) + assertArrayEquals(expectedData[i], preprocessor.apply(i, data[i]), 1e-8); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/54bac750/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java new file mode 100644 index 0000000..1548253 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.preprocessing.normalization; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.junit.Assert.assertArrayEquals; + +/** + * Tests for {@link NormalizationTrainer}. + */ +@RunWith(Parameterized.class) +public class NormalizationTrainerTest { + /** Partition counts to split the 4 test rows into; counts above 4 necessarily produce empty partitions. */ + @Parameterized.Parameters(name = "Data divided on {0} partitions") + public static Iterable<Integer[]> data() { + return Arrays.asList( + new Integer[] {1}, + new Integer[] {2}, + new Integer[] {3}, + new Integer[] {5}, + new Integer[] {7}, + new Integer[] {100}, + new Integer[] {1000} + ); + } + + /** Number of partitions. */ + @Parameterized.Parameter + public int parts; + + /** Tests that {@code fit()} finds the per-column minimum and maximum regardless of partitioning.
*/ + @Test + public void testFit() { + Map<Integer, double[]> data = new HashMap<>(); + data.put(1, new double[] {2, 4, 1}); + data.put(2, new double[] {1, 8, 22}); + data.put(3, new double[] {4, 10, 100}); + data.put(4, new double[] {0, 22, 300}); + + DatasetBuilder<Integer, double[]> datasetBuilder = new LocalDatasetBuilder<>(data, parts); + + NormalizationTrainer<Integer, double[]> normalizationTrainer = new NormalizationTrainer<>(); + + NormalizationPreprocessor<Integer, double[]> preprocessor = normalizationTrainer.fit( + datasetBuilder, + (k, v) -> v, + 3 + ); + + assertArrayEquals(new double[] {0, 4, 1}, preprocessor.getMin(), 1e-8); + assertArrayEquals(new double[] {4, 22, 300}, preprocessor.getMax(), 1e-8); + } +}