IGNITE-9284: [ML] Add a Standard Scaler this closes #4964
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/41f4225c Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/41f4225c Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/41f4225c Branch: refs/heads/ignite-9720 Commit: 41f4225c4b2f2735bce4ce861b9a51afc80d5815 Parents: 46a84fd Author: Ravil Galeyev <[email protected]> Authored: Tue Nov 27 14:05:17 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Tue Nov 27 14:05:17 2018 +0300 ---------------------------------------------------------------------- .../ml/preprocessing/StandardScalerExample.java | 84 +++++++++++++++ .../standardscaling/StandardScalerData.java | 56 ++++++++++ .../StandardScalerPreprocessor.java | 91 +++++++++++++++++ .../standardscaling/StandardScalerTrainer.java | 101 +++++++++++++++++++ .../standardscaling/package-info.java | 22 ++++ .../StandardScalerPreprocessorTest.java | 59 +++++++++++ .../StandardScalerTrainerTest.java | 85 ++++++++++++++++ 7 files changed, 498 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/41f4225c/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/StandardScalerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/StandardScalerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/StandardScalerExample.java new file mode 100644 index 0000000..13d8635 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/StandardScalerExample.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.preprocessing; + +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.examples.ml.dataset.model.Person; +import org.apache.ignite.examples.ml.util.DatasetHelper; +import org.apache.ignite.ml.dataset.DatasetFactory; +import org.apache.ignite.ml.dataset.primitive.SimpleDataset; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.preprocessing.standardscaling.StandardScalerTrainer; + +/** + * Example that shows how to use StandardScaler preprocessor to scale the given data. + * + * Machine learning preprocessors are built as a chain. Most often the first preprocessor is a feature extractor as + * shown in this example. The second preprocessor here is a {@code StandardScaler} preprocessor which is built on top of + * the feature extractor and represents a chain of itself and the underlying feature extractor. + */ +public class StandardScalerExample { + /** Run example. */ + public static void main(String[] args) throws Exception { + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Standard scaler example started."); + + IgniteCache<Integer, Person> persons = createCache(ignite); + + // Defines first preprocessor that extracts features from an upstream data. + IgniteBiFunction<Integer, Person, Vector> featureExtractor = (k, v) -> VectorUtils.of( + v.getAge(), + v.getSalary() + ); + + // Defines second preprocessor that processes features. + IgniteBiFunction<Integer, Person, Vector> preprocessor = new StandardScalerTrainer<Integer, Person>() + .fit(ignite, persons, featureExtractor); + + // Creates a cache based simple dataset containing features and providing standard dataset API. + try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(ignite, persons, preprocessor)) { + new DatasetHelper(dataset).describe(); + } + + System.out.println(">>> Standard scaler example completed."); + } + } + + /** */ + private static IgniteCache<Integer, Person> createCache(Ignite ignite) { + CacheConfiguration<Integer, Person> cacheConfiguration = new CacheConfiguration<>(); + + cacheConfiguration.setName("PERSONS"); + cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 2)); + + IgniteCache<Integer, Person> persons = ignite.createCache(cacheConfiguration); + + persons.put(1, new Person("Mike", 42, 10000)); + persons.put(2, new Person("John", 32, 64000)); + persons.put(3, new Person("George", 53, 120000)); + persons.put(4, new Person("Karl", 24, 70000)); + + return persons; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/41f4225c/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerData.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerData.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerData.java new file mode 100644 index 0000000..f96dcc5 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerData.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ignite.ml.preprocessing.standardscaling; + +/** A Service class for {@link StandardScalerTrainer} which used for sums holing. */ +public class StandardScalerData implements AutoCloseable { + /** Sum values of every feature. */ + double[] sum; + /** Sum of squared values of every feature. */ + double[] squaredSum; + /** Rows count */ + long cnt; + + /** + * Creates {@code StandardScalerData}. + * + * @param sum Sum values of every feature. + * @param squaredSum Sum of squared values of every feature. + * @param cnt Rows count. + */ + public StandardScalerData(double[] sum, double[] squaredSum, long cnt) { + this.sum = sum; + this.squaredSum = squaredSum; + this.cnt = cnt; + } + + /** Merges to current. */ + StandardScalerData merge(StandardScalerData that) { + for (int i = 0; i < sum.length; i++) { + sum[i] += that.sum[i]; + squaredSum[i] += that.squaredSum[i]; + } + + cnt += that.cnt; + return this; + } + + /** */ + @Override public void close() { + // Do nothing, GC will clean up. + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/41f4225c/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerPreprocessor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerPreprocessor.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerPreprocessor.java new file mode 100644 index 0000000..293e86a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerPreprocessor.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.preprocessing.standardscaling; + +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; + +/** + * The preprocessing function that makes standard scaling, transforms features to make {@code mean} equal to {@code 0} + * and {@code variance} equal to {@code 1}. From mathematical point of view it's the following function which is applied + * to every element in a dataset: + * + * {@code a_i = (a_i - mean_i) / sigma_i for all i}, + * + * where {@code i} is a number of column, {@code mean_i} is the mean value this column and {@code sigma_i} is the + * standard deviation in this column. + * + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + */ +public class StandardScalerPreprocessor<K, V> implements IgniteBiFunction<K, V, Vector> { + /** */ + private static final long serialVersionUID = -5977957318991608203L; + + /** Means for each column. */ + private final double[] means; + /** Standard deviation for each column. */ + private final double[] sigmas; + + /** Base preprocessor. */ + private final IgniteBiFunction<K, V, Vector> basePreprocessor; + + /** + * Constructs a new instance of standardscaling preprocessor. + * + * @param means Means of each column. + * @param sigmas Standard deviations in each column. + * @param basePreprocessor Base preprocessor. + */ + public StandardScalerPreprocessor(double[] means, double[] sigmas, + IgniteBiFunction<K, V, Vector> basePreprocessor) { + assert means.length == sigmas.length; + + this.means = means; + this.sigmas = sigmas; + this.basePreprocessor = basePreprocessor; + } + + /** + * Applies this preprocessor. + * + * @param k Key. + * @param v Value. + * @return Preprocessed row. + */ + @Override public Vector apply(K k, V v) { + Vector res = basePreprocessor.apply(k, v); + + assert res.size() == means.length; + + for (int i = 0; i < res.size(); i++) + res.set(i, (res.get(i) - means[i]) / sigmas[i]); + + return res; + } + + /** */ + public double[] getMeans() { + return means; + } + + /** */ + public double[] getSigmas() { + return sigmas; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/41f4225c/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainer.java new file mode 100644 index 0000000..3661772 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainer.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.preprocessing.standardscaling; + +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.UpstreamEntry; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.preprocessing.PreprocessingTrainer; + +/** + * Trainer of the standard scaler preprocessor. + * + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + */ +public class StandardScalerTrainer<K, V> implements PreprocessingTrainer<K, V, Vector, Vector> { + /** {@inheritDoc} */ + @Override public StandardScalerPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> basePreprocessor) { + StandardScalerData standardScalerData = computeSum(datasetBuilder, basePreprocessor); + + int n = standardScalerData.sum.length; + long cnt = standardScalerData.cnt; + double[] mean = new double[n]; + double[] sigma = new double[n]; + + for (int i = 0; i < n; i++) { + mean[i] = standardScalerData.sum[i] / cnt; + double variace = (standardScalerData.squaredSum[i] - Math.pow(standardScalerData.sum[i], 2) / cnt) / cnt; + sigma[i] = Math.sqrt(variace); + } + return new StandardScalerPreprocessor<>(mean, sigma, basePreprocessor); + } + + /** Computes sum, squared sum and row count. */ + private StandardScalerData computeSum(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> basePreprocessor) { + try (Dataset<EmptyContext, StandardScalerData> dataset = datasetBuilder.build( + (upstream, upstreamSize) -> new EmptyContext(), + (upstream, upstreamSize, ctx) -> { + double[] sum = null; + double[] squaredSum = null; + long cnt = 0; + + while (upstream.hasNext()) { + UpstreamEntry<K, V> entity = upstream.next(); + Vector row = basePreprocessor.apply(entity.getKey(), entity.getValue()); + + if (sum == null) { + sum = new double[row.size()]; + squaredSum = new double[row.size()]; + } + else { + assert sum.length == row.size() : "Base preprocessor must return exactly " + sum.length + + " features"; + } + + ++cnt; + for (int i = 0; i < row.size(); i++) { + double x = row.get(i); + sum[i] += x; + squaredSum[i] += x * x; + } + } + return new StandardScalerData(sum, squaredSum, cnt); + } + )) { + + return dataset.compute(data -> data, + (a, b) -> { + if (a == null) + return b; + if (b == null) + return a; + + return a.merge(b); + }); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/41f4225c/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/package-info.java new file mode 100644 index 0000000..5f5de3b --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains Standard scaler preprocessor. + */ +package org.apache.ignite.ml.preprocessing.standardscaling; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/41f4225c/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerPreprocessorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerPreprocessorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerPreprocessorTest.java new file mode 100644 index 0000000..3c325b3 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerPreprocessorTest.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.preprocessing.standardscaling; + +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; + +/** + * Tests for {@link StandardScalerPreprocessor}. + */ +public class StandardScalerPreprocessorTest { + + /** Test {@code apply()} method. */ + @Test + public void testApply() { + double[][] inputData = new double[][] { + {0, 2., 4., .1}, + {0, 1., -18., 2.2}, + {1, 4., 10., -.1}, + {1, 0., 22., 1.3} + }; + double[] means = new double[] {0.5, 1.75, 4.5, 0.875}; + double[] sigmas = new double[] {0.5, 1.47901995, 14.51723114, 0.93374247}; + + StandardScalerPreprocessor<Integer, Vector> preprocessor = new StandardScalerPreprocessor<>( + means, + sigmas, + (k, v) -> v + ); + + double[][] expectedData = new double[][] { + {-1., 0.16903085, -0.03444183, -0.82999331}, + {-1., -0.50709255, -1.54988233, 1.41902081}, + {1., 1.52127766, 0.37886012, -1.04418513}, + {1., -1.18321596, 1.20546403, 0.45515762} + }; + + for (int i = 0; i < inputData.length; i++) + assertArrayEquals(expectedData[i], preprocessor.apply(i, VectorUtils.of(inputData[i])).asArray(), 1e-8); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/41f4225c/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainerTest.java new file mode 100644 index 0000000..679cc48 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainerTest.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.preprocessing.standardscaling; + +import java.util.HashMap; +import java.util.Map; +import org.apache.ignite.ml.common.TrainerTest; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; + +/** + * Tests for {@link StandardScalerTrainer}. + */ +public class StandardScalerTrainerTest extends TrainerTest { + + /** Data. */ + private DatasetBuilder<Integer, Vector> datasetBuilder; + + /** Trainer to be tested. */ + private StandardScalerTrainer<Integer, Vector> standardizationTrainer; + + /** */ + @Before + public void prepareDataset() { + Map<Integer, Vector> data = new HashMap<>(); + data.put(1, VectorUtils.of(0, 2., 4., .1)); + data.put(2, VectorUtils.of(0, 1., -18., 2.2)); + data.put(3, VectorUtils.of(1, 4., 10., -.1)); + data.put(4, VectorUtils.of(1, 0., 22., 1.3)); + datasetBuilder = new LocalDatasetBuilder<>(data, parts); + } + + /** */ + @Before + public void createTrainer() { + standardizationTrainer = new StandardScalerTrainer<>(); + } + + /** Test {@code fit()} method. */ + @Test + public void testCalculatesCorrectMeans() { + double[] expectedMeans = new double[] {0.5, 1.75, 4.5, 0.875}; + + StandardScalerPreprocessor<Integer, Vector> preprocessor = standardizationTrainer.fit( + datasetBuilder, + (k, v) -> v + ); + + assertArrayEquals(expectedMeans, preprocessor.getMeans(), 1e-8); + } + + /** Test {@code fit()} method. */ + @Test + public void testCalculatesCorrectStandardDeviations() { + double[] expectedSigmas = new double[] {0.5, 1.47901995, 14.51723114, 0.93374247}; + + StandardScalerPreprocessor<Integer, Vector> preprocessor = standardizationTrainer.fit( + datasetBuilder, + (k, v) -> v + ); + + assertArrayEquals(expectedSigmas, preprocessor.getSigmas(), 1e-8); + } +}
