IGNITE-9145:[ML] Add different strategies to index labels in StringEncoderTrainer
this closes #5481 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/9137af73 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/9137af73 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/9137af73 Branch: refs/heads/ignite-9720 Commit: 9137af73ef20228ee98e4bc95a8ccb15dadd0010 Parents: cdaeda1 Author: zaleslaw <[email protected]> Authored: Mon Nov 26 14:10:51 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Mon Nov 26 14:10:51 2018 +0300 ---------------------------------------------------------------------- .../encoding/EncoderSortingStrategy.java | 31 ++++++++++++++++++++ .../preprocessing/encoding/EncoderTrainer.java | 25 +++++++++++++++- .../encoding/EncoderTrainerTest.java | 27 +++++++++++++++++ 3 files changed, 82 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/9137af73/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderSortingStrategy.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderSortingStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderSortingStrategy.java new file mode 100644 index 0000000..22cca53 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderSortingStrategy.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.preprocessing.encoding; + +/** + * Describes the encoder sorting strategy, which defines how integer indices are assigned to the values of a categorical feature. + * + * @see EncoderTrainer + */ +public enum EncoderSortingStrategy { + /** Descending order by label frequency (the most frequent label is assigned index 0). */ + FREQUENCY_DESC, + + /** Ascending order by label frequency (the least frequent label is assigned index 0). */ + FREQUENCY_ASC +} http://git-wip-us.apache.org/repos/asf/ignite/blob/9137af73/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java index 9a97a6d..d5668e4 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java @@ -17,6 +17,8 @@ package org.apache.ignite.ml.preprocessing.encoding; +import java.util.Collections; +import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; @@ -47,6 +49,9 @@ public class EncoderTrainer<K, V> implements PreprocessingTrainer<K, V, Object[] /** Encoder preprocessor type. */ private EncoderType encoderType = EncoderType.ONE_HOT_ENCODER; + /** Encoder sorting strategy used to assign integer indices to categorical values; defaults to {@link EncoderSortingStrategy#FREQUENCY_DESC}.
*/ + private EncoderSortingStrategy encoderSortingStgy = EncoderSortingStrategy.FREQUENCY_DESC; + /** {@inheritDoc} */ @Override public EncoderPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Object[]> basePreprocessor) { @@ -129,9 +134,16 @@ public class EncoderTrainer<K, V> implements PreprocessingTrainer<K, V, Object[] * @return Encoding values. */ private Map<String, Integer> transformFrequenciesToEncodingValues(Map<String, Integer> frequencies) { + Comparator<Map.Entry<String, Integer>> comp; /* NOTE(review): the sort order here is deliberately the inverse of the strategy name. Indices appear to be assigned in reverse of this sorted order (that code is outside this hunk; the FREQUENCY_ASC test expects the least frequent value to get 0), so FREQUENCY_DESC sorts ascending to give the most frequent label index 0 - confirm. */ + + if (encoderSortingStgy.equals(EncoderSortingStrategy.FREQUENCY_DESC)) + comp = Map.Entry.comparingByValue(); + else + comp = Collections.reverseOrder(Map.Entry.comparingByValue()); + final HashMap<String, Integer> resMap = frequencies.entrySet() .stream() - .sorted(Map.Entry.comparingByValue()) + .sorted(comp) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (oldValue, newValue) -> oldValue, LinkedHashMap::new)); @@ -211,6 +223,17 @@ public class EncoderTrainer<K, V> implements PreprocessingTrainer<K, V, Object[] } /** + * Sets the encoder sorting (indexing) strategy, i.e. how integer indices are assigned to categorical values. The value is not validated here; a {@code null} strategy would presumably fail with a {@link NullPointerException} when it is compared during fitting. + * + * @param encoderSortingStgy The encoder sorting strategy to use. + * @return This trainer, to allow call chaining. + */ + public EncoderTrainer<K, V> withEncoderIndexingStrategy(EncoderSortingStrategy encoderSortingStgy) { + this.encoderSortingStgy = encoderSortingStgy; + return this; + } + + /** * Sets the encoder preprocessor type. * * @param type The encoder preprocessor type.
http://git-wip-us.apache.org/repos/asf/ignite/blob/9137af73/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java index 23afd30..7c7eabe 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java @@ -115,4 +115,31 @@ public class EncoderTrainerTest extends TrainerTest { } fail("UnknownCategorialFeatureValue"); } + + /** Tests {@code fit()} with {@link EncoderSortingStrategy#FREQUENCY_ASC}: the least frequent value of each encoded feature gets index 0. */ + @Test + public void testFitOnStringCategorialFeaturesWithReversedOrder() { + Map<Integer, String[]> data = new HashMap<>(); + data.put(1, new String[] {"Monday", "September"}); + data.put(2, new String[] {"Monday", "August"}); + data.put(3, new String[] {"Monday", "August"}); + data.put(4, new String[] {"Friday", "June"}); + data.put(5, new String[] {"Friday", "June"}); + data.put(6, new String[] {"Sunday", "August"}); + + DatasetBuilder<Integer, String[]> datasetBuilder = new LocalDatasetBuilder<>(data, parts); + + EncoderTrainer<Integer, String[]> strEncoderTrainer = new EncoderTrainer<Integer, String[]>() + .withEncoderType(EncoderType.STRING_ENCODER) + .withEncoderIndexingStrategy(EncoderSortingStrategy.FREQUENCY_ASC) + .withEncodedFeature(0) + .withEncodedFeature(1); + + EncoderPreprocessor<Integer, String[]> preprocessor = strEncoderTrainer.fit( + datasetBuilder, + (k, v) -> v + ); + + assertArrayEquals(new double[] {2.0, 0.0}, preprocessor.apply(7, new String[] {"Monday", "September"}).asArray(), 1e-8); /* Days: Sunday(1) -> 0, Friday(2) -> 1, Monday(3) -> 2; months: September(1) is the least frequent -> 0. */ + } }
