Repository: ignite Updated Branches: refs/heads/master 44098bc6e -> 8cf9aa273
IGNITE-8680: Added One-Hot Encoder this closes #4469 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/8cf9aa27 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/8cf9aa27 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/8cf9aa27 Branch: refs/heads/master Commit: 8cf9aa273ec6e8c61a653fcd1d2b935ff8e430f2 Parents: 44098bc Author: Zinoviev Alexey <[email protected]> Authored: Fri Aug 3 14:25:17 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Fri Aug 3 14:25:17 2018 +0300 ---------------------------------------------------------------------- .../examples/ml/tutorial/Step_3_Categorial.java | 6 +- .../Step_3_Categorial_with_One_Hot_Encoder.java | 103 +++++++++ .../ml/tutorial/Step_4_Add_age_fare.java | 6 +- .../examples/ml/tutorial/Step_5_Scaling.java | 6 +- .../ignite/examples/ml/tutorial/Step_6_KNN.java | 6 +- .../ml/tutorial/Step_7_Split_train_test.java | 6 +- .../ignite/examples/ml/tutorial/Step_8_CV.java | 6 +- .../ml/tutorial/Step_8_CV_with_Param_Grid.java | 6 +- .../ml/tutorial/Step_9_Go_to_LogReg.java | 6 +- .../src/main/resources/datasets/titanic.csv | 2 +- .../UnknownCategorialFeatureValue.java | 35 +++ .../preprocessing/UnknownStringValue.java | 35 --- .../encoding/EncoderPartitionData.java | 59 +++++ .../encoding/EncoderPreprocessor.java | 56 +++++ .../preprocessing/encoding/EncoderTrainer.java | 225 +++++++++++++++++++ .../ml/preprocessing/encoding/EncoderType.java | 31 +++ .../OneHotEncoderPreprocessor.java | 149 ++++++++++++ .../encoding/onehotencoder/package-info.java | 22 ++ .../StringEncoderPartitionData.java | 62 ----- .../StringEncoderPreprocessor.java | 34 +-- .../stringencoder/StringEncoderTrainer.java | 196 ---------------- .../preprocessing/PreprocessingTestSuite.java | 6 +- .../encoding/EncoderTrainerTest.java | 139 ++++++++++++ .../encoding/OneHotEncoderPreprocessorTest.java | 134 +++++++++++ .../encoding/StringEncoderPreprocessorTest.java | 4 +- 
.../encoding/StringEncoderTrainerTest.java | 80 ------- 26 files changed, 1003 insertions(+), 417 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java index ee2ef8b..e623083 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java @@ -23,7 +23,8 @@ import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer; +import org.apache.ignite.ml.preprocessing.encoding.EncoderType; +import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer; import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; @@ -51,7 +52,8 @@ public class Step_3_Categorial { IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double) v[1]; - IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>() + IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new EncoderTrainer<Integer, Object[]>() + .withEncoderType(EncoderType.STRING_ENCODER) .encodeFeature(1) .encodeFeature(4) .fit(ignite, 
http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java new file mode 100644 index 0000000..d80f647 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.examples.ml.tutorial; + +import java.io.FileNotFoundException; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.preprocessing.encoding.EncoderType; +import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer; +import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; +import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator; +import org.apache.ignite.ml.selection.scoring.metric.Accuracy; +import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; +import org.apache.ignite.ml.tree.DecisionTreeNode; +import org.apache.ignite.thread.IgniteThread; + +/** + * Let's add two categorial features "sex", "embarked" to predict more precisely. + * + * To encode categorial features the EncoderTrainer with the ONE_HOT_ENCODER type will be used. + */ +public class Step_3_Categorial_with_One_Hot_Encoder { + /** Run example. */ + public static void main(String[] args) throws InterruptedException { + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + Step_3_Categorial_with_One_Hot_Encoder.class.getSimpleName(), () -> { + try { + IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite); + + // Defines first preprocessor that extracts features from upstream data. 
+ IgniteBiFunction<Integer, Object[], Object[]> featureExtractor + = (k, v) -> new Object[]{v[0], v[3], v[5], v[6], v[10] + }; // "pclass", "sex", "sibsp", "parch", "embarked" + + IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double) v[1]; + + IgniteBiFunction<Integer, Object[], Vector> oneHotEncoderPreprocessor = new EncoderTrainer<Integer, Object[]>() + .withEncoderType(EncoderType.ONE_HOT_ENCODER) + .encodeFeature(0) + .encodeFeature(1) + .encodeFeature(4) + .fit(ignite, + dataCache, + featureExtractor + ); + + IgniteBiFunction<Integer, Object[], Vector> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>() + .fit(ignite, + dataCache, + oneHotEncoderPreprocessor + ); + + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0); + + // Train decision tree model. + DecisionTreeNode mdl = trainer.fit( + ignite, + dataCache, + imputingPreprocessor, + lbExtractor + ); + + double accuracy = Evaluator.evaluate( + dataCache, + mdl, + imputingPreprocessor, + lbExtractor, + new Accuracy<>() + ); + + System.out.println("\n>>> Accuracy " + accuracy); + System.out.println("\n>>> Test Error " + (1 - accuracy)); + } + catch (FileNotFoundException e) { + e.printStackTrace(); + } + }); + + igniteThread.start(); + + igniteThread.join(); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java index c3bf389..2ea9860 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java @@ -23,7 +23,8 @@ import org.apache.ignite.IgniteCache; 
import org.apache.ignite.Ignition; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer; +import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer; +import org.apache.ignite.ml.preprocessing.encoding.EncoderType; import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; @@ -50,7 +51,8 @@ public class Step_4_Add_age_fare { IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double) v[1]; - IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>() + IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new EncoderTrainer<Integer, Object[]>() + .withEncoderType(EncoderType.STRING_ENCODER) .encodeFeature(1) .encodeFeature(6) // <--- Changed index here .fit(ignite, http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java index 549ab77..01a4c3f 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java @@ -23,7 +23,8 @@ import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer; +import 
org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer; +import org.apache.ignite.ml.preprocessing.encoding.EncoderType; import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; @@ -52,7 +53,8 @@ public class Step_5_Scaling { IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double) v[1]; - IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>() + IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new EncoderTrainer<Integer, Object[]>() + .withEncoderType(EncoderType.STRING_ENCODER) .encodeFeature(1) .encodeFeature(6) // <--- Changed index here .fit(ignite, http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java index 5308287..e07e9f8 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java @@ -26,7 +26,8 @@ import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer; import org.apache.ignite.ml.knn.classification.KNNStrategy; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer; +import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer; +import org.apache.ignite.ml.preprocessing.encoding.EncoderType; import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; import 
org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; @@ -53,7 +54,8 @@ public class Step_6_KNN { IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double) v[1]; - IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>() + IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new EncoderTrainer<Integer, Object[]>() + .withEncoderType(EncoderType.STRING_ENCODER) .encodeFeature(1) .encodeFeature(6) // <--- Changed index here .fit(ignite, http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java index c8fad61..f62054e 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java @@ -23,7 +23,8 @@ import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer; +import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer; +import org.apache.ignite.ml.preprocessing.encoding.EncoderType; import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; @@ -59,7 +60,8 @@ public class Step_7_Split_train_test { 
TrainTestSplit<Integer, Object[]> split = new TrainTestDatasetSplitter<Integer, Object[]>() .split(0.75); - IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>() + IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new EncoderTrainer<Integer, Object[]>() + .withEncoderType(EncoderType.STRING_ENCODER) .encodeFeature(1) .encodeFeature(6) // <--- Changed index here .fit(ignite, http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java index 981e119..d7e6e27 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java @@ -24,7 +24,8 @@ import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer; +import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer; +import org.apache.ignite.ml.preprocessing.encoding.EncoderType; import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; @@ -70,7 +71,8 @@ public class Step_8_CV { TrainTestSplit<Integer, Object[]> split = new TrainTestDatasetSplitter<Integer, Object[]>() .split(0.75); - IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>() + IgniteBiFunction<Integer, Object[], Vector> 
strEncoderPreprocessor = new EncoderTrainer<Integer, Object[]>() + .withEncoderType(EncoderType.STRING_ENCODER) .encodeFeature(1) .encodeFeature(6) // <--- Changed index here .fit(ignite, http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java index 6104299..9311cfb 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java @@ -24,7 +24,8 @@ import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer; +import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer; +import org.apache.ignite.ml.preprocessing.encoding.EncoderType; import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; @@ -72,7 +73,8 @@ public class Step_8_CV_with_Param_Grid { TrainTestSplit<Integer, Object[]> split = new TrainTestDatasetSplitter<Integer, Object[]>() .split(0.75); - IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>() + IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new EncoderTrainer<Integer, Object[]>() + .withEncoderType(EncoderType.STRING_ENCODER) .encodeFeature(1) .encodeFeature(6) // <--- Changed index here 
.fit(ignite, http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java index 88b642b..9fcc9ba 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java @@ -27,7 +27,8 @@ import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer; +import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer; +import org.apache.ignite.ml.preprocessing.encoding.EncoderType; import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; @@ -64,7 +65,8 @@ public class Step_9_Go_to_LogReg { TrainTestSplit<Integer, Object[]> split = new TrainTestDatasetSplitter<Integer, Object[]>() .split(0.75); - IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>() + IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new EncoderTrainer<Integer, Object[]>() + .withEncoderType(EncoderType.STRING_ENCODER) .encodeFeature(1) .encodeFeature(6) // <--- Changed index here .fit(ignite, http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/examples/src/main/resources/datasets/titanic.csv 
---------------------------------------------------------------------- diff --git a/examples/src/main/resources/datasets/titanic.csv b/examples/src/main/resources/datasets/titanic.csv index eff6ba6..6994016 100644 --- a/examples/src/main/resources/datasets/titanic.csv +++ b/examples/src/main/resources/datasets/titanic.csv @@ -1,6 +1,6 @@ pclass;survived;name;sex;age;sibsp;parch;ticket;fare;cabin;embarked;boat;body;homedest 1;1;Allen, Miss. Elisabeth Walton;;29;;;24160;211,3375;B5;;2;;St Louis, MO -1;1;Allison, Master. Hudson Trevor;male;0,9167;1;2;113781;151,55;C22 C26;S;11;;Montreal, PQ / Chesterville, ON +1;1;Allison, Master. Hudson Trevor;male;0,9167;1;2;113781;151,55;C22 C26;AA;11;;Montreal, PQ / Chesterville, ON 1;0;Allison, Miss. Helen Loraine;female;2;1;2;113781;151,55;C22 C26;S;;;Montreal, PQ / Chesterville, ON 1;0;Allison, Mr. Hudson Joshua Creighton;male;30;1;2;113781;151,55;C22 C26;S;;135;Montreal, PQ / Chesterville, ON 1;0;Allison, Mrs. Hudson J C (Bessie Waldo Daniels);female;25;1;2;113781;151,55;C22 C26;S;;;Montreal, PQ / Chesterville, ON http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/preprocessing/UnknownCategorialFeatureValue.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/preprocessing/UnknownCategorialFeatureValue.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/preprocessing/UnknownCategorialFeatureValue.java new file mode 100644 index 0000000..6e97d62 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/preprocessing/UnknownCategorialFeatureValue.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.exceptions.preprocessing; + +import org.apache.ignite.IgniteException; + +/** + * Indicates an unknown categorial feature value for Encoder. + */ +public class UnknownCategorialFeatureValue extends IgniteException { + /** */ + private static final long serialVersionUID = 0L; + + /** + * @param unknownStr Categorial value that caused this exception. + */ + public UnknownCategorialFeatureValue(String unknownStr) { + super("This categorial value is unknown for Encoder: " + unknownStr); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/preprocessing/UnknownStringValue.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/preprocessing/UnknownStringValue.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/preprocessing/UnknownStringValue.java deleted file mode 100644 index 2fc6cee..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/preprocessing/UnknownStringValue.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.math.exceptions.preprocessing; - -import org.apache.ignite.IgniteException; - -/** - * Indicates an unknown String value for StringEncoder. - */ -public class UnknownStringValue extends IgniteException { - /** */ - private static final long serialVersionUID = 0L; - - /** - * @param unknownStr String value that caused this exception. - */ - public UnknownStringValue(String unknownStr) { - super("This String value is unknown for StringEncoder: " + unknownStr); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderPartitionData.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderPartitionData.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderPartitionData.java new file mode 100644 index 0000000..ae53716 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderPartitionData.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.preprocessing.encoding; + +import java.util.Map; + +/** + * Partition data used in Encoder preprocessor. + */ +public class EncoderPartitionData implements AutoCloseable { + /** Frequencies of categories for each categorial feature presented as strings. */ + private Map<String, Integer>[] categoryFrequencies; + + /** + * Constructs a new instance of String Encoder partition data. + */ + public EncoderPartitionData() { + } + + /** + * Gets the array of maps of frequencies by value in partition for each feature in the dataset. + * + * @return The frequencies. + */ + public Map<String, Integer>[] categoryFrequencies() { + return categoryFrequencies; + } + + /** + * Sets the array of maps of frequencies by value in partition for each feature in the dataset. + * + * @param categoryFrequencies The given value. + * @return The partition data. + */ + public EncoderPartitionData withCategoryFrequencies(Map<String, Integer>[] categoryFrequencies) { + this.categoryFrequencies = categoryFrequencies; + return this; + } + + /** */ + @Override public void close() { + // Do nothing, GC will clean up. 
+ } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderPreprocessor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderPreprocessor.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderPreprocessor.java new file mode 100644 index 0000000..7df44f3 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderPreprocessor.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.preprocessing.encoding; + +import java.util.Map; +import java.util.Set; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; + +/** + * Preprocessing function that makes encoding. + * + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. 
+ */ +public abstract class EncoderPreprocessor<K, V> implements IgniteBiFunction<K, V, Vector> { + /** */ + protected static final String KEY_FOR_NULL_VALUES = ""; + + /** Filling values. */ + protected final Map<String, Integer>[] encodingValues; + + /** Base preprocessor. */ + protected final IgniteBiFunction<K, V, Object[]> basePreprocessor; + + /** Feature indices to apply encoder. */ + protected final Set<Integer> handledIndices; + + /** + * Constructs a new instance of String Encoder preprocessor. + * + * @param basePreprocessor Base preprocessor. + * @param handledIndices Handled indices. + */ + public EncoderPreprocessor(Map<String, Integer>[] encodingValues, + IgniteBiFunction<K, V, Object[]> basePreprocessor, Set<Integer> handledIndices) { + this.handledIndices = handledIndices; + this.encodingValues = encodingValues; + this.basePreprocessor = basePreprocessor; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java new file mode 100644 index 0000000..f716d96 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.preprocessing.encoding; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.UpstreamEntry; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.preprocessing.PreprocessingTrainer; +import org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor; +import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderPreprocessor; +import org.jetbrains.annotations.NotNull; + +/** + * Trainer of the String Encoder preprocessor. + * The String Encoder encodes string values (categories) to double values in range [0.0, amountOfCategories) + * where the most popular value will be presented as 0.0 and the least popular value presented with amountOfCategories-1 value. + * + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + */ +public class EncoderTrainer<K, V> implements PreprocessingTrainer<K, V, Object[], Vector> { + /** Indices of features which should be encoded. */ + private Set<Integer> handledIndices = new HashSet<>(); + + /** Encoder preprocessor type. 
*/ + private EncoderType encoderType = EncoderType.ONE_HOT_ENCODER; + + /** {@inheritDoc} */ + @Override public EncoderPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Object[]> basePreprocessor) { + if (handledIndices.isEmpty()) + throw new RuntimeException("Add indices of handled features"); + + try (Dataset<EmptyContext, EncoderPartitionData> dataset = datasetBuilder.build( + (upstream, upstreamSize) -> new EmptyContext(), + (upstream, upstreamSize, ctx) -> { + // This array will contain not null values for handled indices + Map<String, Integer>[] categoryFrequencies = null; + + while (upstream.hasNext()) { + UpstreamEntry<K, V> entity = upstream.next(); + Object[] row = basePreprocessor.apply(entity.getKey(), entity.getValue()); + categoryFrequencies = calculateFrequencies(row, categoryFrequencies); + } + return new EncoderPartitionData() + .withCategoryFrequencies(categoryFrequencies); + } + )) { + Map<String, Integer>[] encodingValues = calculateEncodingValuesByFrequencies(dataset); + + switch (encoderType) { + case ONE_HOT_ENCODER: + return new OneHotEncoderPreprocessor<>(encodingValues, basePreprocessor, handledIndices); + case STRING_ENCODER: + return new StringEncoderPreprocessor<>(encodingValues, basePreprocessor, handledIndices); + default: + throw new IllegalStateException("Define the type of the resulting prerocessor."); + } + + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** + * Calculates the encoding values by the frequencies kept in the given dataset. + * + * @param dataset The dataset of frequencies for each feature aggregated in each partition. + * @return Encoding values for each feature. 
+ */ + private Map<String, Integer>[] calculateEncodingValuesByFrequencies( + Dataset<EmptyContext, EncoderPartitionData> dataset) { + Map<String, Integer>[] frequencies = dataset.compute( + EncoderPartitionData::categoryFrequencies, + (a, b) -> { + if (a == null) + return b; + + if (b == null) + return a; + + assert a.length == b.length; + + for (int i = 0; i < a.length; i++) { + if (handledIndices.contains(i)) { + int finalI = i; + a[i].forEach((k, v) -> b[finalI].merge(k, v, (f1, f2) -> f1 + f2)); + } + } + return b; + } + ); + + Map<String, Integer>[] res = new HashMap[frequencies.length]; + + for (int i = 0; i < frequencies.length; i++) + if (handledIndices.contains(i)) + res[i] = transformFrequenciesToEncodingValues(frequencies[i]); + + return res; + } + + /** + * Transforms frequencies to the encoding values. + * + * @param frequencies Frequencies of categories for the specific feature. + * @return Encoding values. + */ + private Map<String, Integer> transformFrequenciesToEncodingValues(Map<String, Integer> frequencies) { + final HashMap<String, Integer> resMap = frequencies.entrySet() + .stream() + .sorted(Map.Entry.comparingByValue()) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, + (oldValue, newValue) -> oldValue, LinkedHashMap::new)); + + int amountOfLabels = frequencies.size(); + + for (Map.Entry<String, Integer> m : resMap.entrySet()) + m.setValue(--amountOfLabels); + + return resMap; + } + + /** + * Updates frequencies by values and features. + * + * @param row Feature vector. + * @param categoryFrequencies Holds the frequencies of categories by values and features. + * @return Updated frequencies by values and features. 
+ */ + private Map<String, Integer>[] calculateFrequencies(Object[] row, Map<String, Integer>[] categoryFrequencies) { + if (categoryFrequencies == null) + categoryFrequencies = initializeCategoryFrequencies(row); + else + assert categoryFrequencies.length == row.length : "Base preprocessor must return exactly " + + categoryFrequencies.length + " features"; + + for (int i = 0; i < categoryFrequencies.length; i++) { + if (handledIndices.contains(i)) { + String strVal; + Object featureVal = row[i]; + + if (featureVal.equals(Double.NaN)) { + strVal = ""; + row[i] = strVal; + } else if (featureVal instanceof String) + strVal = (String) featureVal; + else if (featureVal instanceof Double) + strVal = String.valueOf(featureVal); + else + throw new RuntimeException("The type " + featureVal.getClass() + " is not supported for the feature values."); + + Map<String, Integer> map = categoryFrequencies[i]; + + if (map.containsKey(strVal)) + map.put(strVal, (map.get(strVal)) + 1); + else + map.put(strVal, 1); + } + } + return categoryFrequencies; + } + + /** + * Initialize frequencies for handled indices only. + * + * @param row Feature vector. + * @return The array contains not null values for handled indices. + */ + @NotNull private Map<String, Integer>[] initializeCategoryFrequencies(Object[] row) { + Map<String, Integer>[] categoryFrequencies = new HashMap[row.length]; + + for (int i = 0; i < categoryFrequencies.length; i++) + if (handledIndices.contains(i)) + categoryFrequencies[i] = new HashMap<>(); + + return categoryFrequencies; + } + + /** + * Add the index of encoded feature. + * + * @param idx The index of encoded feature. + * @return The changed trainer. + */ + public EncoderTrainer<K, V> encodeFeature(int idx) { + handledIndices.add(idx); + return this; + } + + /** + * Sets the encoder preprocessor type. + * + * @param type The encoder preprocessor type. + * @return The changed trainer. 
+ */ + public EncoderTrainer<K, V> withEncoderType(EncoderType type) { + this.encoderType = type; + return this; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderType.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderType.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderType.java new file mode 100644 index 0000000..79e216c --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderType.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.preprocessing.encoding; + +/** + * Describes Encoder preprocessor types to define resulting model in EncoderTrainer. + * + * @see EncoderTrainer + */ +public enum EncoderType { + /** One hot encoder. */ + ONE_HOT_ENCODER, + + /** String encoder. 
*/ + STRING_ENCODER +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/onehotencoder/OneHotEncoderPreprocessor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/onehotencoder/OneHotEncoderPreprocessor.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/onehotencoder/OneHotEncoderPreprocessor.java new file mode 100644 index 0000000..7aadadf --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/onehotencoder/OneHotEncoderPreprocessor.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.preprocessing.encoding.onehotencoder; + +import java.util.Map; +import java.util.Set; +import org.apache.ignite.ml.math.exceptions.preprocessing.UnknownCategorialFeatureValue; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.preprocessing.encoding.EncoderPreprocessor; +import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer; + +/** + * Preprocessing function that makes one-hot encoding. + * + * One-hot encoding maps a categorical feature, + * represented as a label index (Double or String value), + * to a binary vector with at most a single one-value indicating the presence of a specific feature value + * from among the set of all feature values. + * + * This preprocessor can transform multiple columns whose indices are handled during the training process. + * + * Each one-hot encoded binary vector adds its cells to the end of the current feature vector. + * + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + * @see EncoderTrainer + * + * This preprocessor always creates a separate column for the NULL values. + * + * NOTE: the index value associated with NULL will be located in the binary vector according to the frequency of NULL values. + */ +public class OneHotEncoderPreprocessor<K, V> extends EncoderPreprocessor<K, V> { + /** */ + private static final long serialVersionUID = 6237812226552623469L; + + /** + * Constructs a new instance of One Hot Encoder preprocessor. + * + * @param basePreprocessor Base preprocessor. + * @param handledIndices Handled indices. 
+ */ + public OneHotEncoderPreprocessor(Map<String, Integer>[] encodingValues, + IgniteBiFunction<K, V, Object[]> basePreprocessor, Set<Integer> handledIndices) { + super(encodingValues, basePreprocessor, handledIndices); + } + + /** + * Applies this preprocessor. + * + * @param k Key. + * @param v Value. + * @return Preprocessed row. + */ + @Override public Vector apply(K k, V v) { + Object[] tmp = basePreprocessor.apply(k, v); + + double[] res = new double[tmp.length + getAdditionalSize(encodingValues)]; + + int categorialFeatureCntr = 0; + + for (int i = 0; i < tmp.length; i++) { + Object tmpObj = tmp[i]; + if (handledIndices.contains(i)) { + categorialFeatureCntr++; + + if (tmpObj.equals(Double.NaN) && encodingValues[i].containsKey(KEY_FOR_NULL_VALUES)) { + final Integer indexedVal = encodingValues[i].get(KEY_FOR_NULL_VALUES); + + res[i] = indexedVal; + + res[tmp.length + getIdxOffset(categorialFeatureCntr, indexedVal, encodingValues)] = 1.0; + } else { + final String key = String.valueOf(tmpObj); + + if (encodingValues[i].containsKey(key)) { + final Integer indexedVal = encodingValues[i].get(key); + + res[i] = indexedVal; + + res[tmp.length + getIdxOffset(categorialFeatureCntr, indexedVal, encodingValues)] = 1.0; + + } else + throw new UnknownCategorialFeatureValue(tmpObj.toString()); + } + + } else + res[i] = (double) tmpObj; + } + return VectorUtils.of(res); + } + + /** + * Calculates the additional size of feature vector based on trainer's stats. + * It adds amount of column for each categorial feature equal to amount of categories. + * + * @param encodingValues The given trainer stats which helps to calculates the actual size of feature vector. + * @return The additional size. 
+ */ + private int getAdditionalSize(Map<String, Integer>[] encodingValues) { + int newSize = 0; + for (Map<String, Integer> encodingValue : encodingValues) { + if (encodingValue != null) + newSize += encodingValue.size(); // - 1 if we don't keep NULL values and it has NULL values + } + return newSize; + } + + /** + * Calculates the offset in feature vector to set up 1.0 accordingly the index value. + * + * @param categorialFeatureCntr The actual order number for the current categorial feature. + * @param indexedVal The indexed value, converted from the raw value. + * @param encodingValues The trainer's stats about category frequencies. + * @return The offset. + */ + private int getIdxOffset(int categorialFeatureCntr, int indexedVal, Map<String, Integer>[] encodingValues) { + int idxOff = 0; + + int locCategorialFeatureCntr = 1; + + for (int i = 0; locCategorialFeatureCntr < categorialFeatureCntr; i++) { + if (encodingValues[i] != null) { + locCategorialFeatureCntr++; + idxOff += encodingValues[i].size(); // - 1 if we don't keep NULL values and it has NULL values + } + } + + idxOff += indexedVal; + + return idxOff; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/onehotencoder/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/onehotencoder/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/onehotencoder/package-info.java new file mode 100644 index 0000000..66d6c55 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/onehotencoder/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains one hot encoding preprocessor. + */ +package org.apache.ignite.ml.preprocessing.encoding.onehotencoder; http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPartitionData.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPartitionData.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPartitionData.java deleted file mode 100644 index acd2aa2..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPartitionData.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.preprocessing.encoding.stringencoder; - -import java.util.Map; - -/** - * Partition data used in String Encoder preprocessor. - * - * @see StringEncoderTrainer - * @see StringEncoderPreprocessor - */ -public class StringEncoderPartitionData implements AutoCloseable { - /** Frequencies of categories for each categorial feature presented as strings. */ - private Map<String, Integer>[] categoryFrequencies; - - /** - * Constructs a new instance of String Encoder partition data. - */ - public StringEncoderPartitionData() { - } - - /** - * Gets the array of maps of frequencies by value in partition for each feature in the dataset. - * - * @return The frequencies. - */ - public Map<String, Integer>[] categoryFrequencies() { - return categoryFrequencies; - } - - /** - * Sets the array of maps of frequencies by value in partition for each feature in the dataset. - * - * @param categoryFrequencies The given value. - * @return The partition data. - */ - public StringEncoderPartitionData withCategoryFrequencies(Map<String, Integer>[] categoryFrequencies) { - this.categoryFrequencies = categoryFrequencies; - return this; - } - - /** */ - @Override public void close() { - // Do nothing, GC will clean up. 
- } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPreprocessor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPreprocessor.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPreprocessor.java index a1c5b77..12f98f6 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPreprocessor.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPreprocessor.java @@ -19,10 +19,11 @@ package org.apache.ignite.ml.preprocessing.encoding.stringencoder; import java.util.Map; import java.util.Set; -import org.apache.ignite.ml.math.exceptions.preprocessing.UnknownStringValue; +import org.apache.ignite.ml.math.exceptions.preprocessing.UnknownCategorialFeatureValue; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.preprocessing.encoding.EncoderPreprocessor; /** * Preprocessing function that makes String encoding. @@ -30,32 +31,19 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils; * @param <K> Type of a key in {@code upstream} data. * @param <V> Type of a value in {@code upstream} data. */ -public class StringEncoderPreprocessor<K, V> implements IgniteBiFunction<K, V, Vector> { +public class StringEncoderPreprocessor<K, V> extends EncoderPreprocessor<K, V> { /** */ - private static final long serialVersionUID = 6237812226382623469L; - /** */ - private static final String KEY_FOR_NULL_VALUES = ""; - - /** Filling values. */ - private final Map<String, Integer>[] encodingValues; - - /** Base preprocessor. 
*/ - private final IgniteBiFunction<K, V, Object[]> basePreprocessor; - - /** Feature indices to apply encoder.*/ - private final Set<Integer> handledIndices; + protected static final long serialVersionUID = 6237712226382623488L; /** * Constructs a new instance of String Encoder preprocessor. * * @param basePreprocessor Base preprocessor. - * @param handledIndices Handled indices. + * @param handledIndices Handled indices. */ public StringEncoderPreprocessor(Map<String, Integer>[] encodingValues, - IgniteBiFunction<K, V, Object[]> basePreprocessor, Set<Integer> handledIndices) { - this.handledIndices = handledIndices; - this.encodingValues = encodingValues; - this.basePreprocessor = basePreprocessor; + IgniteBiFunction<K, V, Object[]> basePreprocessor, Set<Integer> handledIndices) { + super(encodingValues, basePreprocessor, handledIndices); } /** @@ -71,15 +59,15 @@ public class StringEncoderPreprocessor<K, V> implements IgniteBiFunction<K, V, V for (int i = 0; i < res.length; i++) { Object tmpObj = tmp[i]; - if(handledIndices.contains(i)){ - if(tmpObj.equals(Double.NaN) && encodingValues[i].containsKey(KEY_FOR_NULL_VALUES)) + if (handledIndices.contains(i)) { + if (tmpObj.equals(Double.NaN) && encodingValues[i].containsKey(KEY_FOR_NULL_VALUES)) res[i] = encodingValues[i].get(KEY_FOR_NULL_VALUES); else if (encodingValues[i].containsKey(tmpObj)) res[i] = encodingValues[i].get(tmpObj); else - throw new UnknownStringValue(tmpObj.toString()); + throw new UnknownCategorialFeatureValue(tmpObj.toString()); } else - res[i] = (double)tmpObj; + res[i] = (double) tmpObj; } return VectorUtils.of(res); } http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderTrainer.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderTrainer.java deleted file mode 100644 index ec16af4..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderTrainer.java +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.preprocessing.encoding.stringencoder; - -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; -import org.apache.ignite.ml.dataset.Dataset; -import org.apache.ignite.ml.dataset.DatasetBuilder; -import org.apache.ignite.ml.dataset.UpstreamEntry; -import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; -import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.preprocessing.PreprocessingTrainer; -import org.jetbrains.annotations.NotNull; - -/** - * Trainer of the String Encoder preprocessor. 
- * The String Encoder encodes string values (categories) to double values in range [0.0, amountOfCategories) - * where the most popular value will be presented as 0.0 and the least popular value presented with amountOfCategories-1 value. - * - * @param <K> Type of a key in {@code upstream} data. - * @param <V> Type of a value in {@code upstream} data. - */ -public class StringEncoderTrainer<K, V> implements PreprocessingTrainer<K, V, Object[], Vector> { - /** Indices of features which should be encoded. */ - private Set<Integer> handledIndices = new HashSet<>(); - - /** {@inheritDoc} */ - @Override public StringEncoderPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, Object[]> basePreprocessor) { - if(handledIndices.isEmpty()) - throw new RuntimeException("Add indices of handled features"); - - try (Dataset<EmptyContext, StringEncoderPartitionData> dataset = datasetBuilder.build( - (upstream, upstreamSize) -> new EmptyContext(), - (upstream, upstreamSize, ctx) -> { - // This array will contain not null values for handled indices - Map<String, Integer>[] categoryFrequencies = null; - - while (upstream.hasNext()) { - UpstreamEntry<K, V> entity = upstream.next(); - Object[] row = basePreprocessor.apply(entity.getKey(), entity.getValue()); - categoryFrequencies = calculateFrequencies(row, categoryFrequencies); - } - return new StringEncoderPartitionData() - .withCategoryFrequencies(categoryFrequencies); - } - )) { - Map<String, Integer>[] encodingValues = calculateEncodingValuesByFrequencies(dataset); - - return new StringEncoderPreprocessor<>(encodingValues, basePreprocessor, handledIndices); - } - catch (Exception e) { - throw new RuntimeException(e); - } - } - - /** - * Calculates the encoding values values by frequencies keeping in the given dataset. - * - * @param dataset The dataset of frequencies for each feature aggregated in each partition. - * @return Encoding values for each feature. 
- */ - private Map<String, Integer>[] calculateEncodingValuesByFrequencies( - Dataset<EmptyContext, StringEncoderPartitionData> dataset) { - Map<String, Integer>[] frequencies = dataset.compute( - StringEncoderPartitionData::categoryFrequencies, - (a, b) -> { - if (a == null) - return b; - - if (b == null) - return a; - - assert a.length == b.length; - - for (int i = 0; i < a.length; i++) { - if(handledIndices.contains(i)){ - int finalI = i; - a[i].forEach((k, v) -> b[finalI].merge(k, v, (f1, f2) -> f1 + f2)); - } - } - return b; - } - ); - - Map<String, Integer>[] res = new HashMap[frequencies.length]; - - for (int i = 0; i < frequencies.length; i++) - if(handledIndices.contains(i)) - res[i] = transformFrequenciesToEncodingValues(frequencies[i]); - - return res; - } - - /** - * Transforms frequencies to the encoding values. - * - * @param frequencies Frequencies of categories for the specific feature. - * @return Encoding values. - */ - private Map<String, Integer> transformFrequenciesToEncodingValues(Map<String, Integer> frequencies) { - final HashMap<String, Integer> resMap = frequencies.entrySet() - .stream() - .sorted(Map.Entry.comparingByValue()) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, - (oldValue, newValue) -> oldValue, LinkedHashMap::new)); - - int amountOfLabels = frequencies.size(); - - for (Map.Entry<String, Integer> m : resMap.entrySet()) - m.setValue(--amountOfLabels); - - return resMap; - } - - /** - * Updates frequencies by values and features. - * - * @param row Feature vector. - * @param categoryFrequencies Holds the frequencies of categories by values and features. - * @return Updated frequencies by values and features. 
- */ - private Map<String, Integer>[] calculateFrequencies(Object[] row, Map<String, Integer>[] categoryFrequencies) { - if (categoryFrequencies == null) - categoryFrequencies = initializeCategoryFrequencies(row); - else - assert categoryFrequencies.length == row.length : "Base preprocessor must return exactly " - + categoryFrequencies.length + " features"; - - for (int i = 0; i < categoryFrequencies.length; i++) { - if(handledIndices.contains(i)){ - String strVal; - Object featureVal = row[i]; - - if(featureVal.equals(Double.NaN)) { - strVal = ""; - row[i] = strVal; - } - else strVal = (String)featureVal; - - Map<String, Integer> map = categoryFrequencies[i]; - - if (map.containsKey(strVal)) - map.put(strVal, (map.get(strVal)) + 1); - else - map.put(strVal, 1); - } - } - return categoryFrequencies; - } - - /** - * Initialize frequencies for handled indices only. - * @param row Feature vector. - * @return The array contains not null values for handled indices. - */ - @NotNull private Map<String, Integer>[] initializeCategoryFrequencies(Object[] row) { - Map<String, Integer>[] categoryFrequencies = new HashMap[row.length]; - - for (int i = 0; i < categoryFrequencies.length; i++) - if(handledIndices.contains(i)) - categoryFrequencies[i] = new HashMap<>(); - - return categoryFrequencies; - } - - /** - * Add the index of encoded feature. - * @param idx The index of encoded feature. - * @return The changed trainer. 
- */ - public StringEncoderTrainer<K, V> encodeFeature(int idx){ - handledIndices.add(idx); - return this; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/PreprocessingTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/PreprocessingTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/PreprocessingTestSuite.java index b13ed7d..7b3d5fc 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/PreprocessingTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/PreprocessingTestSuite.java @@ -19,8 +19,9 @@ package org.apache.ignite.ml.preprocessing; import org.apache.ignite.ml.preprocessing.binarization.BinarizationPreprocessorTest; import org.apache.ignite.ml.preprocessing.binarization.BinarizationTrainerTest; +import org.apache.ignite.ml.preprocessing.encoding.OneHotEncoderPreprocessorTest; import org.apache.ignite.ml.preprocessing.encoding.StringEncoderPreprocessorTest; -import org.apache.ignite.ml.preprocessing.encoding.StringEncoderTrainerTest; +import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainerTest; import org.apache.ignite.ml.preprocessing.imputing.ImputerPreprocessorTest; import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainerTest; import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerPreprocessorTest; @@ -41,7 +42,8 @@ import org.junit.runners.Suite; BinarizationTrainerTest.class, ImputerPreprocessorTest.class, ImputerTrainerTest.class, - StringEncoderTrainerTest.class, + EncoderTrainerTest.class, + OneHotEncoderPreprocessorTest.class, StringEncoderPreprocessorTest.class, NormalizationTrainerTest.class, NormalizationPreprocessorTest.class 
http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java new file mode 100644 index 0000000..c0157e9 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.preprocessing.encoding; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.math.exceptions.preprocessing.UnknownCategorialFeatureValue; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static junit.framework.TestCase.fail; +import static org.junit.Assert.assertArrayEquals; + +/** + * Tests for {@link EncoderTrainer}. + */ +@RunWith(Parameterized.class) +public class EncoderTrainerTest { + /** Parameters. */ + @Parameterized.Parameters(name = "Data divided on {0} partitions") + public static Iterable<Integer[]> data() { + return Arrays.asList( + new Integer[]{1}, + new Integer[]{2}, + new Integer[]{3}, + new Integer[]{5}, + new Integer[]{7}, + new Integer[]{100}, + new Integer[]{1000} + ); + } + + /** Number of partitions. */ + @Parameterized.Parameter + public int parts; + + /** Tests {@code fit()} method. 
*/ + @Test + public void testFitOnStringCategorialFeatures() { + Map<Integer, String[]> data = new HashMap<>(); + data.put(1, new String[]{"Monday", "September"}); + data.put(2, new String[]{"Monday", "August"}); + data.put(3, new String[]{"Monday", "August"}); + data.put(4, new String[]{"Friday", "June"}); + data.put(5, new String[]{"Friday", "June"}); + data.put(6, new String[]{"Sunday", "August"}); + + DatasetBuilder<Integer, String[]> datasetBuilder = new LocalDatasetBuilder<>(data, parts); + + EncoderTrainer<Integer, String[]> strEncoderTrainer = new EncoderTrainer<Integer, String[]>() + .withEncoderType(EncoderType.STRING_ENCODER) + .encodeFeature(0) + .encodeFeature(1); + + EncoderPreprocessor<Integer, String[]> preprocessor = strEncoderTrainer.fit( + datasetBuilder, + (k, v) -> v + ); + + assertArrayEquals(new double[]{0.0, 2.0}, preprocessor.apply(7, new String[]{"Monday", "September"}).asArray(), 1e-8); + } + + /** Tests {@code fit()} method. */ + @Test + public void testFitOnIntegerCategorialFeatures() { + Map<Integer, Object[]> data = new HashMap<>(); + data.put(1, new Object[]{3.0, 0.0}); + data.put(2, new Object[]{3.0, 12.0}); + data.put(3, new Object[]{3.0, 12.0}); + data.put(4, new Object[]{2.0, 45.0}); + data.put(5, new Object[]{2.0, 45.0}); + data.put(6, new Object[]{14.0, 12.0}); + + DatasetBuilder<Integer, Object[]> datasetBuilder = new LocalDatasetBuilder<>(data, parts); + + EncoderTrainer<Integer, Object[]> strEncoderTrainer = new EncoderTrainer<Integer, Object[]>() + .withEncoderType(EncoderType.ONE_HOT_ENCODER) + .encodeFeature(0) + .encodeFeature(1); + + EncoderPreprocessor<Integer, Object[]> preprocessor = strEncoderTrainer.fit( + datasetBuilder, + (k, v) -> v + ); + assertArrayEquals(new double[]{0.0, 2.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0}, preprocessor.apply(7, new Double[]{3.0, 0.0}).asArray(), 1e-8); + assertArrayEquals(new double[]{1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0}, preprocessor.apply(8, new Double[]{2.0, 12.0}).asArray(), 1e-8); + 
} + + /** Tests {@code fit()} method. */ + @Test + public void testFitWithUnknownStringValueInTheGivenData() { + Map<Integer, Object[]> data = new HashMap<>(); + data.put(1, new Object[]{3.0, 0.0}); + data.put(2, new Object[]{3.0, 12.0}); + data.put(3, new Object[]{3.0, 12.0}); + data.put(4, new Object[]{2.0, 45.0}); + data.put(5, new Object[]{2.0, 45.0}); + data.put(6, new Object[]{14.0, 12.0}); + + DatasetBuilder<Integer, Object[]> datasetBuilder = new LocalDatasetBuilder<>(data, parts); + + EncoderTrainer<Integer, Object[]> strEncoderTrainer = new EncoderTrainer<Integer, Object[]>() + .withEncoderType(EncoderType.STRING_ENCODER) + .encodeFeature(0) + .encodeFeature(1); + + EncoderPreprocessor<Integer, Object[]> preprocessor = strEncoderTrainer.fit( + datasetBuilder, + (k, v) -> v + ); + + try { + preprocessor.apply(7, new String[]{"Monday", "September"}).asArray(); + fail("UnknownCategorialFeatureValue"); + } catch (UnknownCategorialFeatureValue e) { + return; + } + fail("UnknownCategorialFeatureValue"); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/OneHotEncoderPreprocessorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/OneHotEncoderPreprocessorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/OneHotEncoderPreprocessorTest.java new file mode 100644 index 0000000..294cfa0 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/OneHotEncoderPreprocessorTest.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.preprocessing.encoding; + +import java.util.HashMap; +import java.util.HashSet; +import org.apache.ignite.ml.math.exceptions.preprocessing.UnknownCategorialFeatureValue; +import org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor; +import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderPreprocessor; +import org.junit.Test; + +import static junit.framework.TestCase.fail; +import static org.junit.Assert.assertArrayEquals; + +/** + * Tests for {@link OneHotEncoderPreprocessor}. + */ +public class OneHotEncoderPreprocessorTest { + /** Tests {@code apply()} method. 
*/ + @Test + public void testApplyWithStringValues() { + String[][] data = new String[][]{ + {"1", "Moscow", "A"}, + {"2", "Moscow", "A"}, + {"2", "Moscow", "B"}, + }; + + OneHotEncoderPreprocessor<Integer, String[]> preprocessor = new OneHotEncoderPreprocessor<Integer, String[]>( + new HashMap[]{new HashMap() { + { + put("1", 1); + put("2", 0); + } + }, new HashMap() { + { + put("Moscow", 0); + } + }, new HashMap() { + { + put("A", 0); + put("B", 1); + } + }}, + (k, v) -> v, + new HashSet() { + { + add(0); + add(1); + add(2); + } + }); + + double[][] postProcessedData = new double[][]{ + {1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0}, + {0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0}, + {0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0}, + }; + + for (int i = 0; i < data.length; i++) + assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]).asArray(), 1e-8); + } + + + /** + * The {@code apply()} method fails with an UnknownCategorialFeatureValue exception. + * + * The reason is missing information in encodingValues. 
+ * + * @see UnknownCategorialFeatureValue + */ + @Test + public void testApplyWithUnknownGategorialValues() { + String[][] data = new String[][]{ + {"1", "Moscow", "A"}, + {"2", "Moscow", "A"}, + {"2", "Moscow", "B"}, + }; + + OneHotEncoderPreprocessor<Integer, String[]> preprocessor = new OneHotEncoderPreprocessor<Integer, String[]>( + new HashMap[]{new HashMap() { + { + put("2", 0); + } + }, new HashMap() { + { + put("Moscow", 0); + } + }, new HashMap() { + { + put("A", 0); + put("B", 1); + } + }}, + (k, v) -> v, + new HashSet() { + { + add(0); + add(1); + add(2); + } + }); + + double[][] postProcessedData = new double[][]{ + {1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0}, + {0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0}, + {0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0}, + }; + + try { + for (int i = 0; i < data.length; i++) + assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]).asArray(), 1e-8); + + fail("UnknownCategorialFeatureValue"); + } catch (UnknownCategorialFeatureValue e) { + return; + } + fail("UnknownCategorialFeatureValue"); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java index f480209..1bfd7ee 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java @@ -68,7 +68,7 @@ public class StringEncoderPreprocessorTest { {0.0, 0.0, 0.0}, }; - for (int i = 0; i < data.length; i++) - assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]).asArray(), 1e-8); + for (int i 
= 0; i < data.length; i++) + assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]).asArray(), 1e-8); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/8cf9aa27/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java deleted file mode 100644 index 4f9d757..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.ml.preprocessing.encoding; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; -import org.apache.ignite.ml.dataset.DatasetBuilder; -import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; -import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderPreprocessor; -import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; - -import static org.junit.Assert.assertArrayEquals; - -/** - * Tests for {@link StringEncoderTrainer}. - */ -@RunWith(Parameterized.class) -public class StringEncoderTrainerTest { - /** Parameters. */ - @Parameterized.Parameters(name = "Data divided on {0} partitions") - public static Iterable<Integer[]> data() { - return Arrays.asList( - new Integer[] {1}, - new Integer[] {2}, - new Integer[] {3}, - new Integer[] {5}, - new Integer[] {7}, - new Integer[] {100}, - new Integer[] {1000} - ); - } - - /** Number of partitions. */ - @Parameterized.Parameter - public int parts; - - /** Tests {@code fit()} method. 
*/ - @Test - public void testFit() { - Map<Integer, String[]> data = new HashMap<>(); - data.put(1, new String[] {"Monday", "September"}); - data.put(2, new String[] {"Monday", "August"}); - data.put(3, new String[] {"Monday", "August"}); - data.put(4, new String[] {"Friday", "June"}); - data.put(5, new String[] {"Friday", "June"}); - data.put(6, new String[] {"Sunday", "August"}); - - DatasetBuilder<Integer, String[]> datasetBuilder = new LocalDatasetBuilder<>(data, parts); - - StringEncoderTrainer<Integer, String[]> strEncoderTrainer = new StringEncoderTrainer<Integer, String[]>() - .encodeFeature(0) - .encodeFeature(1); - - StringEncoderPreprocessor<Integer, String[]> preprocessor = strEncoderTrainer.fit( - datasetBuilder, - (k, v) -> v - ); - - assertArrayEquals(new double[] {0.0, 2.0}, preprocessor.apply(7, new String[] {"Monday", "September"}).asArray(), 1e-8); - } -}
