Repository: ignite Updated Branches: refs/heads/master 9249efda5 -> 7a5aa7c6b
IGNITE-8664: Encoding categorical features with One-of-K Encoder this closes #4106 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/7a5aa7c6 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/7a5aa7c6 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/7a5aa7c6 Branch: refs/heads/master Commit: 7a5aa7c6b91bbf8c6bbdabc232f3d09ac1b015f9 Parents: 9249efd Author: zaleslaw <[email protected]> Authored: Fri Jun 1 17:20:40 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Fri Jun 1 17:20:40 2018 +0300 ---------------------------------------------------------------------- .../preprocessing/UnknownStringValue.java | 35 +++++ .../ml/preprocessing/encoding/package-info.java | 22 +++ .../StringEncoderPartitionData.java | 62 ++++++++ .../StringEncoderPreprocessor.java | 70 +++++++++ .../stringencoder/StringEncoderTrainer.java | 152 +++++++++++++++++++ .../encoding/stringencoder/package-info.java | 22 +++ .../preprocessing/PreprocessingTestSuite.java | 6 +- .../encoding/StringEncoderPreprocessorTest.java | 67 ++++++++ .../encoding/StringEncoderTrainerTest.java | 78 ++++++++++ 9 files changed, 513 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/7a5aa7c6/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/preprocessing/UnknownStringValue.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/preprocessing/UnknownStringValue.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/preprocessing/UnknownStringValue.java new file mode 100644 index 0000000..f2312a1 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/preprocessing/UnknownStringValue.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more 
+ * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.exceptions.preprocessing; + +import org.apache.ignite.IgniteException; + +/** + * Signals that the String Encoder was asked to encode a String value for which it has no encoding. + */ +public class UnknownStringValue extends IgniteException { + /** */ + private static final long serialVersionUID = 0L; + + /** Text that precedes the offending value in the exception message. */ + private static final String MSG_PREFIX = "This String value is unknown for StringEncoder: "; + + /** + * Creates the exception for the given value. + * + * @param unknownString String value that has no encoding and caused this exception. + */ + public UnknownStringValue(String unknownString) { + super(MSG_PREFIX + unknownString); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/7a5aa7c6/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/package-info.java new file mode 100644 index 0000000..436ad8f --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains encoding preprocessors. + */ +package org.apache.ignite.ml.preprocessing.encoding; http://git-wip-us.apache.org/repos/asf/ignite/blob/7a5aa7c6/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPartitionData.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPartitionData.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPartitionData.java new file mode 100644 index 0000000..acd2aa2 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPartitionData.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.preprocessing.encoding.stringencoder; + +import java.util.Map; + +/** + * Holder of the per-partition statistics gathered while fitting the String Encoder. + * + * @see StringEncoderTrainer + * @see StringEncoderPreprocessor + */ +public class StringEncoderPartitionData implements AutoCloseable { + /** + * Category frequencies of each categorical feature: element {@code i} maps the string values of feature {@code i} + * to their number of occurrences in this partition, as populated by {@link StringEncoderTrainer}. + */ + private Map<String, Integer>[] categoryFrequencies; + + /** + * Creates empty partition data; frequencies are attached later via {@link #withCategoryFrequencies(Map[])}. + */ + public StringEncoderPartitionData() { + // No-op. + } + + /** + * Attaches the per-feature category frequencies computed for this partition. + * + * @param categoryFrequencies Frequencies of values, one map per feature. + * @return This partition data, to allow call chaining. + */ + public StringEncoderPartitionData withCategoryFrequencies(Map<String, Integer>[] categoryFrequencies) { + this.categoryFrequencies = categoryFrequencies; + return this; + } + + /** + * Returns the per-feature category frequencies of this partition (the stored array itself, not a copy). + * + * @return Frequencies of values, one map per feature, or {@code null} if none were set. + */ + public Map<String, Integer>[] categoryFrequencies() { + return categoryFrequencies; + } + + /** Does nothing: this partition data holds no resources that need explicit release. */ + @Override public void close() { + // Nothing to release, the garbage collector reclaims the maps.
+ } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/7a5aa7c6/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPreprocessor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPreprocessor.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPreprocessor.java new file mode 100644 index 0000000..4b21e67 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPreprocessor.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.preprocessing.encoding.stringencoder; + +import java.util.Map; +import org.apache.ignite.ml.math.exceptions.preprocessing.UnknownStringValue; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; + +/** + * Preprocessing function that makes String encoding. + * + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data.
+ */ public class StringEncoderPreprocessor<K, V> implements IgniteBiFunction<K, V, double[]> { + /** */ + private static final long serialVersionUID = 6237812226382623469L; + + /** Encoding values: for each feature, a map from a category string to its numeric code. */ + private final Map<String, Integer>[] encodingValues; + + /** Base preprocessor. */ + private final IgniteBiFunction<K, V, String[]> basePreprocessor; + + /** + * Constructs a new instance of String Encoder preprocessor. + * + * @param encodingValues Encoding values: element {@code i} maps every known value of feature {@code i} to its code. + * @param basePreprocessor Base preprocessor that extracts the string features from an upstream entry. + */ + public StringEncoderPreprocessor(Map<String, Integer>[] encodingValues, + IgniteBiFunction<K, V, String[]> basePreprocessor) { + this.encodingValues = encodingValues; + this.basePreprocessor = basePreprocessor; + } + + /** + * Applies this preprocessor: replaces every string feature produced by the base preprocessor with its code. + * + * @param k Key. + * @param v Value. + * @return Preprocessed row. + * @throws UnknownStringValue If a feature value has no encoding. + */ + @Override public double[] apply(K k, V v) { + String[] tmp = basePreprocessor.apply(k, v); + double[] res = new double[tmp.length]; + + for (int i = 0; i < res.length; i++) { + // Single lookup: a null result means the value has no encoding for this feature. + Integer code = encodingValues[i].get(tmp[i]); + + if (code == null) + throw new UnknownStringValue(tmp[i]); + + res[i] = code; + } + + return res; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/7a5aa7c6/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderTrainer.java new file mode 100644 index 0000000..5a4d090 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderTrainer.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.preprocessing.encoding.stringencoder; + +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.UpstreamEntry; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.preprocessing.PreprocessingTrainer; + +/** + * Trainer of the String Encoder preprocessor. + * The String Encoder encodes string values (categories) to double values in range [0.0, amountOfCategories) + * where the most popular value will be presented as 0.0 and the least popular value presented with amountOfCategories-1 value. + * + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. 
+ */ public class StringEncoderTrainer<K, V> implements PreprocessingTrainer<K, V, String[], double[]> { + /** {@inheritDoc} */ + @Override public StringEncoderPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, String[]> basePreprocessor) { + try (Dataset<EmptyContext, StringEncoderPartitionData> dataset = datasetBuilder.build( + (upstream, upstreamSize) -> new EmptyContext(), + (upstream, upstreamSize, ctx) -> { + // Stays null for a partition without rows; the reduce step below skips such partitions. + Map<String, Integer>[] categoryFrequencies = null; + + while (upstream.hasNext()) { + UpstreamEntry<K, V> entity = upstream.next(); + String[] row = basePreprocessor.apply(entity.getKey(), entity.getValue()); + categoryFrequencies = calculateFrequencies(row, categoryFrequencies); + } + + return new StringEncoderPartitionData() + .withCategoryFrequencies(categoryFrequencies); + } + )) { + Map<String, Integer>[] encodingValues = calculateEncodingValuesByFrequencies(dataset); + + return new StringEncoderPreprocessor<>(encodingValues, basePreprocessor); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** + * Calculates the encoding values from the frequencies kept in the given dataset. + * + * @param dataset The dataset of frequencies for each feature aggregated in each partition. + * @return Encoding values for each feature. + * @throws IllegalStateException If the dataset contains no rows, so there is nothing to encode.
+ */ private Map<String, Integer>[] calculateEncodingValuesByFrequencies( + Dataset<EmptyContext, StringEncoderPartitionData> dataset) { + Map<String, Integer>[] frequencies = dataset.compute( + StringEncoderPartitionData::categoryFrequencies, + (a, b) -> { + if (a == null) + return b; + + if (b == null) + return a; + + assert a.length == b.length; + + // Merges the counts of 'a' into the maps of 'b' in place; 'b' may be a partition's own array, which is + // harmless here because the dataset is closed right after this single computation. + for (int i = 0; i < a.length; i++) { + int finalI = i; + a[i].forEach((k, v) -> b[finalI].merge(k, v, Integer::sum)); + } + + return b; + } + ); + + // Every partition was empty: without this check the loop below would fail with a bare NPE. + if (frequencies == null) + throw new IllegalStateException("Failed to fit StringEncoder: the dataset is empty"); + + Map<String, Integer>[] res = new HashMap[frequencies.length]; + + for (int i = 0; i < frequencies.length; i++) + res[i] = transformFrequenciesToEncodingValues(frequencies[i]); + + return res; + } + + /** + * Transforms frequencies to the encoding values: the most frequent category gets code 0 and the least frequent + * one gets {@code frequencies.size() - 1}. Categories are sorted by frequency only, so the relative codes of + * categories with equal frequency are not specified. + * + * @param frequencies Frequencies of categories for the specific feature. + * @return Encoding values. + */ + private Map<String, Integer> transformFrequenciesToEncodingValues(Map<String, Integer> frequencies) { + final HashMap<String, Integer> resMap = frequencies.entrySet() + .stream() + .sorted(Map.Entry.comparingByValue()) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, + (oldValue, newValue) -> oldValue, LinkedHashMap::new)); + + // Entries are in ascending frequency order, so counting down gives the smallest code to the most frequent. + int amountOfLabels = frequencies.size(); + + for (Map.Entry<String, Integer> m : resMap.entrySet()) + m.setValue(--amountOfLabels); + + return resMap; + } + + /** + * Updates frequencies by values and features. + * + * @param row Feature vector. + * @param categoryFrequencies Holds the frequencies of categories by values and features, or {@code null} before + * the first row of the partition. + * @return Updated frequencies by values and features.
+ */ private Map<String, Integer>[] calculateFrequencies(String[] row, Map<String, Integer>[] categoryFrequencies) { + if (categoryFrequencies == null) { + categoryFrequencies = new HashMap[row.length]; + for (int i = 0; i < categoryFrequencies.length; i++) + categoryFrequencies[i] = new HashMap<>(); + } + else + assert categoryFrequencies.length == row.length : "Base preprocessor must return exactly " + categoryFrequencies.length + + " features"; + + // Counts one more occurrence of each feature value (starting at 1 for a value not seen before). + for (int i = 0; i < categoryFrequencies.length; i++) + categoryFrequencies[i].merge(row[i], 1, Integer::sum); + + return categoryFrequencies; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/7a5aa7c6/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/package-info.java new file mode 100644 index 0000000..7cdb40f --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains string encoding preprocessor. + */ +package org.apache.ignite.ml.preprocessing.encoding.stringencoder; http://git-wip-us.apache.org/repos/asf/ignite/blob/7a5aa7c6/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/PreprocessingTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/PreprocessingTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/PreprocessingTestSuite.java index f0c566c..cb29ecb 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/PreprocessingTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/PreprocessingTestSuite.java @@ -19,6 +19,8 @@ package org.apache.ignite.ml.preprocessing; import org.apache.ignite.ml.preprocessing.binarization.BinarizationPreprocessorTest; import org.apache.ignite.ml.preprocessing.binarization.BinarizationTrainerTest; +import org.apache.ignite.ml.preprocessing.encoding.StringEncoderPreprocessorTest; +import org.apache.ignite.ml.preprocessing.encoding.StringEncoderTrainerTest; import org.apache.ignite.ml.preprocessing.imputing.ImputerPreprocessorTest; import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainerTest; import org.apache.ignite.ml.preprocessing.normalization.NormalizationPreprocessorTest; @@ -36,7 +38,9 @@ import org.junit.runners.Suite; BinarizationPreprocessorTest.class, BinarizationTrainerTest.class, ImputerPreprocessorTest.class, - 
ImputerTrainerTest.class + ImputerTrainerTest.class, + StringEncoderTrainerTest.class, + StringEncoderPreprocessorTest.class }) public class PreprocessingTestSuite { // No-op. http://git-wip-us.apache.org/repos/asf/ignite/blob/7a5aa7c6/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java new file mode 100644 index 0000000..d74b923 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.preprocessing.encoding; + +import java.util.HashMap; +import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderPreprocessor; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; + +/** + * Tests for {@link StringEncoderPreprocessor}. 
+ */ +public class StringEncoderPreprocessorTest { + /** Tests {@code apply()} method. */ + @Test + public void testApply() { + String[][] data = new String[][]{ + {"1", "Moscow", "A"}, + {"2", "Moscow", "B"}, + {"2", "Moscow", "B"}, + }; + + StringEncoderPreprocessor<Integer, String[]> preprocessor = new StringEncoderPreprocessor<Integer, String[]>( + new HashMap[]{new HashMap() { + { + put("1", 1); + put("2", 0); + } + }, new HashMap() { + { + put("Moscow", 0); + } + }, new HashMap() { + { + put("A", 1); + put("B", 0); + } + }}, + (k, v) -> v + ); + + double[][] postProcessedData = new double[][]{ + {1.0, 0.0, 1.0}, + {0.0, 0.0, 0.0}, + {0.0, 0.0, 0.0}, + }; + + for (int i = 0; i < data.length; i++) + assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]), 1e-8); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/7a5aa7c6/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java new file mode 100644 index 0000000..aa17beb --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.preprocessing.encoding; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderPreprocessor; +import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.junit.Assert.assertArrayEquals; + +/** + * Tests for {@link StringEncoderTrainer}. + */ +@RunWith(Parameterized.class) +public class StringEncoderTrainerTest { + /** Numbers of partitions to split the data into; the test runs once per value. */ + @Parameterized.Parameters(name = "Data divided on {0} partitions") + public static Iterable<Integer[]> data() { + return Arrays.asList( + new Integer[] {1}, + new Integer[] {2}, + new Integer[] {3}, + new Integer[] {5}, + new Integer[] {7}, + new Integer[] {100}, + new Integer[] {1000} + ); + } + + /** Number of partitions. */ + @Parameterized.Parameter + public int parts; + + /** Tests that {@code fit()} assigns codes in order of descending category frequency.
+ */ @Test public void testFit() { + Map<Integer, String[]> data = new HashMap<>(); + data.put(1, new String[] {"Monday", "September"}); + data.put(2, new String[] {"Monday", "August"}); + data.put(3, new String[] {"Monday", "August"}); + data.put(4, new String[] {"Friday", "June"}); + data.put(5, new String[] {"Friday", "June"}); + data.put(6, new String[] {"Sunday", "August"}); + + DatasetBuilder<Integer, String[]> datasetBuilder = new LocalDatasetBuilder<>(data, parts); + + StringEncoderTrainer<Integer, String[]> strEncoderTrainer = new StringEncoderTrainer<>(); + + StringEncoderPreprocessor<Integer, String[]> preprocessor = strEncoderTrainer.fit( + datasetBuilder, + (k, v) -> v + ); + + // Day frequencies: Monday 3, Friday 2, Sunday 1; month frequencies: August 3, June 2, September 1. + assertArrayEquals(new double[] {0.0, 2.0}, preprocessor.apply(7, new String[] {"Monday", "September"}), 1e-8); + assertArrayEquals(new double[] {1.0, 1.0}, preprocessor.apply(8, new String[] {"Friday", "June"}), 1e-8); + assertArrayEquals(new double[] {2.0, 0.0}, preprocessor.apply(9, new String[] {"Sunday", "August"}), 1e-8); + } +}
